diff --git a/train.py b/train.py
index 99a43f861..17d16dba1 100644
--- a/train.py
+++ b/train.py
@@ -43,7 +43,7 @@ from utils.autoanchor import check_anchors
 from utils.autobatch import check_train_batch_size
 from utils.callbacks import Callbacks
 from utils.dataloaders import create_dataloader
-from utils.downloads import attempt_download
+from utils.downloads import attempt_download, is_url
 from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
                            check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
                            init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
@@ -77,6 +77,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
     with open(hyp, errors='ignore') as f:
         hyp = yaml.safe_load(f)  # load hyps dict
     LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
+    opt.hyp = hyp.copy()  # for saving hyps to checkpoints
 
     # Save run settings
     if not evolve:
@@ -377,6 +378,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
                 'updates': ema.updates,
                 'optimizer': optimizer.state_dict(),
                 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
+                'opt': vars(opt),
                 'date': datetime.now().isoformat()}
 
             # Save last, best and delete
@@ -472,8 +474,7 @@ def parse_opt(known=False):
     parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
     parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
 
-    opt = parser.parse_known_args()[0] if known else parser.parse_args()
-    return opt
+    return parser.parse_known_args()[0] if known else parser.parse_args()
 
 
 def main(opt, callbacks=Callbacks()):
@@ -484,12 +485,20 @@ def main(opt, callbacks=Callbacks()):
         check_requirements(exclude=['thop'])
 
     # Resume
-    if opt.resume and not check_wandb_resume(opt) and not opt.evolve:  # resume an interrupted run
-        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
-        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
-        with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
-            opt = argparse.Namespace(**yaml.safe_load(f))  # replace
-        opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
+    if opt.resume and not (check_wandb_resume(opt) or opt.evolve):  # resume an interrupted run
+        last = Path(opt.resume if isinstance(opt.resume, str) else get_latest_run())  # specified or most recent last.pt
+        assert last.is_file(), f'ERROR: --resume checkpoint {last} does not exist'
+        opt_yaml = last.parent.parent / 'opt.yaml'  # train options yaml
+        opt_data = opt.data  # original dataset
+        if opt_yaml.is_file():
+            with open(opt_yaml, errors='ignore') as f:
+                d = yaml.safe_load(f)
+        else:
+            d = torch.load(last, map_location='cpu')['opt']
+        opt = argparse.Namespace(**d)  # replace
+        opt.cfg, opt.weights, opt.resume = '', str(last), True  # reinstate
+        if is_url(opt.data):
+            opt.data = str(opt_data)  # avoid HUB resume auth timeout
     else:
         opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
             check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
diff --git a/utils/downloads.py b/utils/downloads.py
index ebe5bd36e..9d4780ad2 100644
--- a/utils/downloads.py
+++ b/utils/downloads.py
@@ -16,12 +16,14 @@ import requests
 import torch
 
 
-def is_url(url):
+def is_url(url, check_online=True):
     # Check if online file exists
     try:
-        r = urllib.request.urlopen(url)  # response
-        return r.getcode() == 200
-    except urllib.request.HTTPError:
+        url = str(url)
+        result = urllib.parse.urlparse(url)
+        assert all([result.scheme, result.netloc, result.path])  # check if is url
+        return (urllib.request.urlopen(url).getcode() == 200) if check_online else True  # check if exists online
+    except (AssertionError, urllib.request.HTTPError):
         return False
 
 
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 391ddead2..d5615c263 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -317,8 +317,9 @@ def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, re
         ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
         ema.updates = ckpt['updates']
     if resume:
-        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
-        LOGGER.info(f'Resuming training from {weights} for {epochs - start_epoch} more epochs to {epochs} total epochs')
+        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.\n' \
+                                f"Start a new training without --resume, i.e. 'python train.py --weights {weights}'"
+        LOGGER.info(f'Resuming training from {weights} from epoch {start_epoch} to {epochs} total epochs')
     if epochs < start_epoch:
         LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
         epochs += ckpt['epoch']  # finetune additional epochs
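Below is a minimal usage sketch (not part of the patch) of the updated is_url() signature, assuming utils/downloads.py is importable; the URLs are placeholders and the default check_online=True path needs network access:

    from utils.downloads import is_url

    # check_online=False: syntax-only test; urlparse() must yield a scheme, netloc and path
    is_url('https://example.com/weights/last.pt', check_online=False)  # True, no network call
    is_url('runs/train/exp/weights/last.pt', check_online=False)       # False, plain filesystem path

    # default check_online=True additionally requires urlopen() to return HTTP 200
    is_url('https://example.com/missing.pt')  # False if the server answers with an HTTP error

Since each checkpoint now stores 'opt': vars(opt), running python train.py --resume path/to/last.pt can rebuild the training options from the checkpoint itself when the run's opt.yaml is missing, and a remote --data URL is swapped back to the original local dataset path to avoid a HUB auth timeout on resume.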