From fbf41e09134b113f8e79ae01b4eee40d00797b2d Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Sun, 20 Jun 2021 15:06:58 +0200
Subject: [PATCH] Add `train.run()` method (#3700)

* Update train.py explicit arguments

* Update train.py

* Add run method
---
 train.py | 81 +++++++++++++++++++++++++++++++-------------------------
 1 file changed, 45 insertions(+), 36 deletions(-)

diff --git a/train.py b/train.py
index 68cd7fab5..fbda73208 100644
--- a/train.py
+++ b/train.py
@@ -46,8 +46,9 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
           ):
-    save_dir, epochs, batch_size, weights, single_cls = \
-        opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls
+    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, notest, nosave, workers, = \
+        opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
+        opt.resume, opt.notest, opt.nosave, opt.workers
 
     # Directories
     save_dir = Path(save_dir)
@@ -70,34 +71,34 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         yaml.safe_dump(vars(opt), f, sort_keys=False)
 
     # Configure
-    plots = not opt.evolve  # create plots
+    plots = not evolve  # create plots
     cuda = device.type != 'cpu'
     init_seeds(2 + RANK)
-    with open(opt.data) as f:
+    with open(data) as f:
         data_dict = yaml.safe_load(f)  # data dict
 
     # Loggers
     loggers = {'wandb': None, 'tb': None}  # loggers dict
     if RANK in [-1, 0]:
         # TensorBoard
-        if not opt.evolve:
+        if not evolve:
             prefix = colorstr('tensorboard: ')
             logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
-            loggers['tb'] = SummaryWriter(opt.save_dir)
+            loggers['tb'] = SummaryWriter(str(save_dir))
 
         # W&B
         opt.hyp = hyp  # add hyperparameters
         run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
         wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
         loggers['wandb'] = wandb_logger.wandb
-        data_dict = wandb_logger.data_dict
-        if wandb_logger.wandb:
+        if loggers['wandb']:
+            data_dict = wandb_logger.data_dict
             weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # may update weights, epochs if resuming
 
     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
-    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
-    is_coco = opt.data.endswith('coco.yaml') and nc == 80  # COCO dataset
+    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data)  # check
+    is_coco = data.endswith('coco.yaml') and nc == 80  # COCO dataset
 
     # Model
     pretrained = weights.endswith('.pt')
@@ -105,14 +106,14 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         with torch_distributed_zero_first(RANK):
             weights = attempt_download(weights)  # download if not found locally
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
-        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
-        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
+        model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+        exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
         state_dict = ckpt['model'].float().state_dict()  # to FP32
         state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
         model.load_state_dict(state_dict, strict=False)  # load
         logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
     else:
-        model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+        model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
     with torch_distributed_zero_first(RANK):
         check_dataset(data_dict)  # check
     train_path = data_dict['train']
@@ -182,7 +183,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
         # Epochs
         start_epoch = ckpt['epoch'] + 1
-        if opt.resume:
+        if resume:
             assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
         if epochs < start_epoch:
             logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
@@ -210,20 +211,20 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
-                                            workers=opt.workers,
+                                            workers=workers,
                                             image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
-    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
+    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, data, nc - 1)
 
     # Process 0
     if RANK in [-1, 0]:
         testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
-                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
-                                       workers=opt.workers,
+                                       hyp=hyp, cache=opt.cache_images and not notest, rect=True, rank=-1,
+                                       workers=workers,
                                        pad=0.5, prefix=colorstr('val: '))[0]
 
-        if not opt.resume:
+        if not resume:
             labels = np.concatenate(dataset.labels, 0)
             c = torch.tensor(labels[:, 0])  # classes
             # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
@@ -356,8 +357,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                         with warnings.catch_warnings():
                             warnings.simplefilter('ignore')  # suppress jit trace warning
                             loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
-                elif plots and ni == 10 and wandb_logger.wandb:
-                    wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
+                elif plots and ni == 10 and loggers['wandb']:
+                    wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
                                                   save_dir.glob('train*.jpg') if x.exists()]})
 
             # end batch ------------------------------------------------------------------------------------------------
@@ -371,7 +372,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             # mAP
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
-            if not opt.notest or final_epoch:  # Calculate mAP
+            if not notest or final_epoch:  # Calculate mAP
                 wandb_logger.current_epoch = epoch + 1
                 results, maps, _ = test.test(data_dict,
                                              batch_size=batch_size // WORLD_SIZE * 2,
@@ -398,7 +399,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                 if loggers['tb']:
                     loggers['tb'].add_scalar(tag, x, epoch)  # TensorBoard
-                if wandb_logger.wandb:
+                if loggers['wandb']:
                     wandb_logger.log({tag: x})  # W&B
 
             # Update best mAP
@@ -408,7 +409,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             wandb_logger.end_epoch(best_result=best_fitness == fi)
 
             # Save model
-            if (not opt.nosave) or (final_epoch and not opt.evolve):  # if save
+            if (not nosave) or (final_epoch and not evolve):  # if save
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
@@ -416,13 +417,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
-                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
+                        'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}
 
                 # Save last, best and delete
                 torch.save(ckpt, last)
                 if best_fitness == fi:
                     torch.save(ckpt, best)
-                if wandb_logger.wandb:
+                if loggers['wandb']:
                     if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
                         wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
@@ -433,15 +434,15 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
-            if wandb_logger.wandb:
+            if loggers['wandb']:
                 files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
-                wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
+                wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
                                               if (save_dir / f).exists()]})
 
-        if not opt.evolve:
+        if not evolve:
             if is_coco:  # COCO dataset
                 for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
-                    results, _, _ = test.test(opt.data,
+                    results, _, _ = test.test(data,
                                               batch_size=batch_size // WORLD_SIZE * 2,
                                               imgsz=imgsz_test,
                                               conf_thres=0.001,
@@ -457,17 +458,17 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             for f in last, best:
                 if f.exists():
                     strip_optimizer(f)  # strip optimizers
-            if wandb_logger.wandb:  # Log the stripped model
-                wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model',
-                                                name='run_' + wandb_logger.wandb_run.id + '_model',
-                                                aliases=['latest', 'best', 'stripped'])
+            if loggers['wandb']:  # Log the stripped model
+                loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
+                                              name='run_' + wandb_logger.wandb_run.id + '_model',
+                                              aliases=['latest', 'best', 'stripped'])
         wandb_logger.finish_run()
 
     torch.cuda.empty_cache()
     return results
 
 
-def parse_opt():
+def parse_opt(known=False):
     parser = argparse.ArgumentParser()
     parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
     parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
@@ -503,7 +504,7 @@ def parse_opt():
     parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
     parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
     parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
-    opt = parser.parse_args()
+    opt = parser.parse_known_args()[0] if known else parser.parse_args()
     return opt
 
 
@@ -633,6 +634,14 @@ def main(opt):
               f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
 
 
+def run(**kwargs):
+    # Usage: import train; train.run(imgsz=320, weights='yolov5m.pt')
+    opt = parse_opt(True)
+    for k, v in kwargs.items():
+        setattr(opt, k, v)
+    main(opt)
+
+
 if __name__ == "__main__":
     opt = parse_opt()
     main(opt)