Self-contained checkpoint `--resume` (#8839)
* Single checkpoint resume * Update train.py * Add hyp * Add hyp * Add hyp * FIX * avoid resume on url data * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * avoid resume on url data * avoid resume on url data * Update Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/9101/head^2
parent
4d8d84b0ea
commit
a75a1105a1
27
train.py
27
train.py
|
@ -43,7 +43,7 @@ from utils.autoanchor import check_anchors
|
|||
from utils.autobatch import check_train_batch_size
|
||||
from utils.callbacks import Callbacks
|
||||
from utils.dataloaders import create_dataloader
|
||||
from utils.downloads import attempt_download
|
||||
from utils.downloads import attempt_download, is_url
|
||||
from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
|
||||
check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
|
||||
init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
|
||||
|
@ -77,6 +77,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|||
with open(hyp, errors='ignore') as f:
|
||||
hyp = yaml.safe_load(f) # load hyps dict
|
||||
LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
||||
opt.hyp = hyp.copy() # for saving hyps to checkpoints
|
||||
|
||||
# Save run settings
|
||||
if not evolve:
|
||||
|
@ -377,6 +378,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|||
'updates': ema.updates,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
|
||||
'opt': vars(opt),
|
||||
'date': datetime.now().isoformat()}
|
||||
|
||||
# Save last, best and delete
|
||||
|
@ -472,8 +474,7 @@ def parse_opt(known=False):
|
|||
parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
|
||||
parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
|
||||
|
||||
opt = parser.parse_known_args()[0] if known else parser.parse_args()
|
||||
return opt
|
||||
return parser.parse_known_args()[0] if known else parser.parse_args()
|
||||
|
||||
|
||||
def main(opt, callbacks=Callbacks()):
|
||||
|
@ -484,12 +485,20 @@ def main(opt, callbacks=Callbacks()):
|
|||
check_requirements(exclude=['thop'])
|
||||
|
||||
# Resume
|
||||
if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
|
||||
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
|
||||
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
|
||||
with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
|
||||
opt = argparse.Namespace(**yaml.safe_load(f)) # replace
|
||||
opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
|
||||
if opt.resume and not (check_wandb_resume(opt) or opt.evolve): # resume an interrupted run
|
||||
last = Path(opt.resume if isinstance(opt.resume, str) else get_latest_run()) # specified or most recent last.pt
|
||||
assert last.is_file(), f'ERROR: --resume checkpoint {last} does not exist'
|
||||
opt_yaml = last.parent.parent / 'opt.yaml' # train options yaml
|
||||
opt_data = opt.data # original dataset
|
||||
if opt_yaml.is_file():
|
||||
with open(opt_yaml, errors='ignore') as f:
|
||||
d = yaml.safe_load(f)
|
||||
else:
|
||||
d = torch.load(last, map_location='cpu')['opt']
|
||||
opt = argparse.Namespace(**d) # replace
|
||||
opt.cfg, opt.weights, opt.resume = '', str(last), True # reinstate
|
||||
if is_url(opt.data):
|
||||
opt.data = str(opt_data) # avoid HUB resume auth timeout
|
||||
else:
|
||||
opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
|
||||
check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
|
||||
|
|
|
@ -16,12 +16,14 @@ import requests
|
|||
import torch
|
||||
|
||||
|
||||
def is_url(url):
|
||||
def is_url(url, check_online=True):
|
||||
# Check if online file exists
|
||||
try:
|
||||
r = urllib.request.urlopen(url) # response
|
||||
return r.getcode() == 200
|
||||
except urllib.request.HTTPError:
|
||||
url = str(url)
|
||||
result = urllib.parse.urlparse(url)
|
||||
assert all([result.scheme, result.netloc, result.path]) # check if is url
|
||||
return (urllib.request.urlopen(url).getcode() == 200) if check_online else True # check if exists online
|
||||
except (AssertionError, urllib.request.HTTPError):
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
@ -317,8 +317,9 @@ def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, re
|
|||
ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
|
||||
ema.updates = ckpt['updates']
|
||||
if resume:
|
||||
assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
|
||||
LOGGER.info(f'Resuming training from {weights} for {epochs - start_epoch} more epochs to {epochs} total epochs')
|
||||
assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.\n' \
|
||||
f"Start a new training without --resume, i.e. 'python train.py --weights {weights}'"
|
||||
LOGGER.info(f'Resuming training from {weights} from epoch {start_epoch} to {epochs} total epochs')
|
||||
if epochs < start_epoch:
|
||||
LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
|
||||
epochs += ckpt['epoch'] # finetune additional epochs
|
||||
|
|
Loading…
Reference in New Issue