mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Add train.run()
method (#3700)
* Update train.py explicit arguments * Update train.py * Add run method
This commit is contained in:
parent
c1af67dcd4
commit
fbf41e0913
81
train.py
81
train.py
@ -46,8 +46,9 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||||||
opt,
|
opt,
|
||||||
device,
|
device,
|
||||||
):
|
):
|
||||||
save_dir, epochs, batch_size, weights, single_cls = \
|
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, notest, nosave, workers, = \
|
||||||
opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls
|
opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
|
||||||
|
opt.resume, opt.notest, opt.nosave, opt.workers
|
||||||
|
|
||||||
# Directories
|
# Directories
|
||||||
save_dir = Path(save_dir)
|
save_dir = Path(save_dir)
|
||||||
@ -70,34 +71,34 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||||||
yaml.safe_dump(vars(opt), f, sort_keys=False)
|
yaml.safe_dump(vars(opt), f, sort_keys=False)
|
||||||
|
|
||||||
# Configure
|
# Configure
|
||||||
plots = not opt.evolve # create plots
|
plots = not evolve # create plots
|
||||||
cuda = device.type != 'cpu'
|
cuda = device.type != 'cpu'
|
||||||
init_seeds(2 + RANK)
|
init_seeds(2 + RANK)
|
||||||
with open(opt.data) as f:
|
with open(data) as f:
|
||||||
data_dict = yaml.safe_load(f) # data dict
|
data_dict = yaml.safe_load(f) # data dict
|
||||||
|
|
||||||
# Loggers
|
# Loggers
|
||||||
loggers = {'wandb': None, 'tb': None} # loggers dict
|
loggers = {'wandb': None, 'tb': None} # loggers dict
|
||||||
if RANK in [-1, 0]:
|
if RANK in [-1, 0]:
|
||||||
# TensorBoard
|
# TensorBoard
|
||||||
if not opt.evolve:
|
if not evolve:
|
||||||
prefix = colorstr('tensorboard: ')
|
prefix = colorstr('tensorboard: ')
|
||||||
logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
|
logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
|
||||||
loggers['tb'] = SummaryWriter(opt.save_dir)
|
loggers['tb'] = SummaryWriter(str(save_dir))
|
||||||
|
|
||||||
# W&B
|
# W&B
|
||||||
opt.hyp = hyp # add hyperparameters
|
opt.hyp = hyp # add hyperparameters
|
||||||
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
|
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
|
||||||
wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
|
wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
|
||||||
loggers['wandb'] = wandb_logger.wandb
|
loggers['wandb'] = wandb_logger.wandb
|
||||||
data_dict = wandb_logger.data_dict
|
if loggers['wandb']:
|
||||||
if wandb_logger.wandb:
|
data_dict = wandb_logger.data_dict
|
||||||
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # may update weights, epochs if resuming
|
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # may update weights, epochs if resuming
|
||||||
|
|
||||||
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
||||||
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
||||||
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data) # check
|
||||||
is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset
|
is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
pretrained = weights.endswith('.pt')
|
pretrained = weights.endswith('.pt')
|
||||||
@ -105,14 +106,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||||||
with torch_distributed_zero_first(RANK):
|
with torch_distributed_zero_first(RANK):
|
||||||
weights = attempt_download(weights) # download if not found locally
|
weights = attempt_download(weights) # download if not found locally
|
||||||
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
||||||
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
||||||
exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
|
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
|
||||||
state_dict = ckpt['model'].float().state_dict() # to FP32
|
state_dict = ckpt['model'].float().state_dict() # to FP32
|
||||||
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
|
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
|
||||||
model.load_state_dict(state_dict, strict=False) # load
|
model.load_state_dict(state_dict, strict=False) # load
|
||||||
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
|
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
|
||||||
else:
|
else:
|
||||||
model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
||||||
with torch_distributed_zero_first(RANK):
|
with torch_distributed_zero_first(RANK):
|
||||||
check_dataset(data_dict) # check
|
check_dataset(data_dict) # check
|
||||||
train_path = data_dict['train']
|
train_path = data_dict['train']
|
||||||
@ -182,7 +183,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||||||
|
|
||||||
# Epochs
|
# Epochs
|
||||||
start_epoch = ckpt['epoch'] + 1
|
start_epoch = ckpt['epoch'] + 1
|
||||||
if opt.resume:
|
if resume:
|
||||||
assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
|
assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
|
||||||
if epochs < start_epoch:
|
if epochs < start_epoch:
|
||||||
logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
|
logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
|
||||||
@ -210,20 +211,20 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||||||
# Trainloader
|
# Trainloader
|
||||||
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
|
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
|
||||||
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
|
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
|
||||||
workers=opt.workers,
|
workers=workers,
|
||||||
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
|
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
|
||||||
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
||||||
nb = len(dataloader) # number of batches
|
nb = len(dataloader) # number of batches
|
||||||
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
|
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, data, nc - 1)
|
||||||
|
|
||||||
# Process 0
|
# Process 0
|
||||||
if RANK in [-1, 0]:
|
if RANK in [-1, 0]:
|
||||||
testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
|
testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
|
||||||
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
|
hyp=hyp, cache=opt.cache_images and not notest, rect=True, rank=-1,
|
||||||
workers=opt.workers,
|
workers=workers,
|
||||||
pad=0.5, prefix=colorstr('val: '))[0]
|
pad=0.5, prefix=colorstr('val: '))[0]
|
||||||
|
|
||||||
if not opt.resume:
|
if not resume:
|
||||||
labels = np.concatenate(dataset.labels, 0)
|
labels = np.concatenate(dataset.labels, 0)
|
||||||
c = torch.tensor(labels[:, 0]) # classes
|
c = torch.tensor(labels[:, 0]) # classes
|
||||||
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
|
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
|
||||||
@ -356,8 +357,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter('ignore') # suppress jit trace warning
|
warnings.simplefilter('ignore') # suppress jit trace warning
|
||||||
loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
|
loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
|
||||||
elif plots and ni == 10 and wandb_logger.wandb:
|
elif plots and ni == 10 and loggers['wandb']:
|
||||||
wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
|
wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
|
||||||
save_dir.glob('train*.jpg') if x.exists()]})
|
save_dir.glob('train*.jpg') if x.exists()]})
|
||||||
|
|
||||||
# end batch ------------------------------------------------------------------------------------------------
|
# end batch ------------------------------------------------------------------------------------------------
|
||||||
@ -371,7 +372,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||||||
# mAP
|
# mAP
|
||||||
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
|
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
|
||||||
final_epoch = epoch + 1 == epochs
|
final_epoch = epoch + 1 == epochs
|
||||||
if not opt.notest or final_epoch: # Calculate mAP
|
if not notest or final_epoch: # Calculate mAP
|
||||||
wandb_logger.current_epoch = epoch + 1
|
wandb_logger.current_epoch = epoch + 1
|
||||||
results, maps, _ = test.test(data_dict,
|
results, maps, _ = test.test(data_dict,
|
||||||
batch_size=batch_size // WORLD_SIZE * 2,
|
batch_size=batch_size // WORLD_SIZE * 2,
|
||||||
@ -398,7 +399,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||||||
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
|
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
|
||||||
if loggers['tb']:
|
if loggers['tb']:
|
||||||
loggers['tb'].add_scalar(tag, x, epoch) # TensorBoard
|
loggers['tb'].add_scalar(tag, x, epoch) # TensorBoard
|
||||||
if wandb_logger.wandb:
|
if loggers['wandb']:
|
||||||
wandb_logger.log({tag: x}) # W&B
|
wandb_logger.log({tag: x}) # W&B
|
||||||
|
|
||||||
# Update best mAP
|
# Update best mAP
|
||||||
@ -408,7 +409,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||||||
wandb_logger.end_epoch(best_result=best_fitness == fi)
|
wandb_logger.end_epoch(best_result=best_fitness == fi)
|
||||||
|
|
||||||
# Save model
|
# Save model
|
||||||
if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
|
if (not nosave) or (final_epoch and not evolve): # if save
|
||||||
ckpt = {'epoch': epoch,
|
ckpt = {'epoch': epoch,
|
||||||
'best_fitness': best_fitness,
|
'best_fitness': best_fitness,
|
||||||
'training_results': results_file.read_text(),
|
'training_results': results_file.read_text(),
|
||||||
@ -416,13 +417,13 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||||||
'ema': deepcopy(ema.ema).half(),
|
'ema': deepcopy(ema.ema).half(),
|
||||||
'updates': ema.updates,
|
'updates': ema.updates,
|
||||||
'optimizer': optimizer.state_dict(),
|
'optimizer': optimizer.state_dict(),
|
||||||
'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
|
'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}
|
||||||
|
|
||||||
# Save last, best and delete
|
# Save last, best and delete
|
||||||
torch.save(ckpt, last)
|
torch.save(ckpt, last)
|
||||||
if best_fitness == fi:
|
if best_fitness == fi:
|
||||||
torch.save(ckpt, best)
|
torch.save(ckpt, best)
|
||||||
if wandb_logger.wandb:
|
if loggers['wandb']:
|
||||||
if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
|
if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
|
||||||
wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
|
wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
|
||||||
del ckpt
|
del ckpt
|
||||||
@ -433,15 +434,15 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||||||
logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
|
logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
|
||||||
if plots:
|
if plots:
|
||||||
plot_results(save_dir=save_dir) # save as results.png
|
plot_results(save_dir=save_dir) # save as results.png
|
||||||
if wandb_logger.wandb:
|
if loggers['wandb']:
|
||||||
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
|
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
|
||||||
wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
|
wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
|
||||||
if (save_dir / f).exists()]})
|
if (save_dir / f).exists()]})
|
||||||
|
|
||||||
if not opt.evolve:
|
if not evolve:
|
||||||
if is_coco: # COCO dataset
|
if is_coco: # COCO dataset
|
||||||
for m in [last, best] if best.exists() else [last]: # speed, mAP tests
|
for m in [last, best] if best.exists() else [last]: # speed, mAP tests
|
||||||
results, _, _ = test.test(opt.data,
|
results, _, _ = test.test(data,
|
||||||
batch_size=batch_size // WORLD_SIZE * 2,
|
batch_size=batch_size // WORLD_SIZE * 2,
|
||||||
imgsz=imgsz_test,
|
imgsz=imgsz_test,
|
||||||
conf_thres=0.001,
|
conf_thres=0.001,
|
||||||
@ -457,17 +458,17 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||||||
for f in last, best:
|
for f in last, best:
|
||||||
if f.exists():
|
if f.exists():
|
||||||
strip_optimizer(f) # strip optimizers
|
strip_optimizer(f) # strip optimizers
|
||||||
if wandb_logger.wandb: # Log the stripped model
|
if loggers['wandb']: # Log the stripped model
|
||||||
wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model',
|
loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
|
||||||
name='run_' + wandb_logger.wandb_run.id + '_model',
|
name='run_' + wandb_logger.wandb_run.id + '_model',
|
||||||
aliases=['latest', 'best', 'stripped'])
|
aliases=['latest', 'best', 'stripped'])
|
||||||
wandb_logger.finish_run()
|
wandb_logger.finish_run()
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def parse_opt():
|
def parse_opt(known=False):
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
|
parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
|
||||||
parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
|
parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
|
||||||
@ -503,7 +504,7 @@ def parse_opt():
|
|||||||
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
|
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
|
||||||
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
||||||
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_known_args()[0] if known else parser.parse_args()
|
||||||
return opt
|
return opt
|
||||||
|
|
||||||
|
|
||||||
@ -633,6 +634,14 @@ def main(opt):
|
|||||||
f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
|
f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
|
||||||
|
|
||||||
|
|
||||||
|
def run(**kwargs):
|
||||||
|
# Usage: import train; train.run(imgsz=320, weights='yolov5m.pt')
|
||||||
|
opt = parse_opt(True)
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
setattr(opt, k, v)
|
||||||
|
main(opt)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
opt = parse_opt()
|
opt = parse_opt()
|
||||||
main(opt)
|
main(opt)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user