New `smart_resume()` (#8838)
* New `smart_resume()`
* Update torch_utils.py
* Update torch_utils.py
* Update torch_utils.py
* fix

branch pull/8788/head
parent 2e10909905
commit 08c8c3e00a
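In short: the inline resume block in train() (restoring optimizer state, EMA weights and updates, best_fitness, and the start epoch, plus the finished-training and fine-tuning epoch handling) moves into a new smart_resume() helper in utils/torch_utils.py, and train.py now calls it with a single line. The diff also hoists labels = np.concatenate(dataset.labels, 0) so plot_labels() can reuse it, moves nb = len(train_loader) down to the start-of-training section, and drops the 'Resuming training from {ckpt}' log in main(), since smart_resume() now logs the resume message.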
train.py (33 changed lines)

@@ -54,7 +54,7 @@ from utils.loss import ComputeLoss
 from utils.metrics import fitness
 from utils.plots import plot_evolve, plot_labels
 from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
-                               torch_distributed_zero_first)
+                               smart_resume, torch_distributed_zero_first)
 
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv('RANK', -1))
@@ -163,26 +163,9 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
     ema = ModelEMA(model) if RANK in {-1, 0} else None
 
     # Resume
-    start_epoch, best_fitness = 0, 0.0
+    best_fitness, start_epoch = 0.0, 0
     if pretrained:
-        # Optimizer
-        if ckpt['optimizer'] is not None:
-            optimizer.load_state_dict(ckpt['optimizer'])
-            best_fitness = ckpt['best_fitness']
-
-        # EMA
-        if ema and ckpt.get('ema'):
-            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
-            ema.updates = ckpt['updates']
-
-        # Epochs
-        start_epoch = ckpt['epoch'] + 1
-        if resume:
-            assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
-        if epochs < start_epoch:
-            LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
-            epochs += ckpt['epoch']  # finetune additional epochs
-
+        best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
         del ckpt, csd
 
     # DP mode
@@ -212,8 +195,8 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
                                               quad=opt.quad,
                                               prefix=colorstr('train: '),
                                               shuffle=True)
-    mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max())  # max label class
-    nb = len(train_loader)  # number of batches
+    labels = np.concatenate(dataset.labels, 0)
+    mlc = int(labels[:, 0].max())  # max label class
     assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
 
     # Process 0
@@ -232,10 +215,6 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
                                        prefix=colorstr('val: '))[0]
 
         if not resume:
-            labels = np.concatenate(dataset.labels, 0)
-            # c = torch.tensor(labels[:, 0])  # classes
-            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
-            # model._initialize_biases(cf.to(device))
             if plots:
                 plot_labels(labels, names, save_dir)
 
@@ -263,6 +242,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
 
     # Start training
     t0 = time.time()
+    nb = len(train_loader)  # number of batches
     nw = max(round(hyp['warmup_epochs'] * nb), 100)  # number of warmup iterations, max(3 epochs, 100 iterations)
     # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
     last_opt_step = -1
@@ -510,7 +490,6 @@ def main(opt, callbacks=Callbacks()):
         with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
             opt = argparse.Namespace(**yaml.safe_load(f))  # replace
         opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
-        LOGGER.info(f'Resuming training from {ckpt}')
     else:
         opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
             check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
utils/torch_utils.py

@@ -306,6 +306,25 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-
     return optimizer
 
 
+def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
+    # Resume training from a partially trained checkpoint
+    best_fitness = 0.0
+    start_epoch = ckpt['epoch'] + 1
+    if ckpt['optimizer'] is not None:
+        optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
+        best_fitness = ckpt['best_fitness']
+    if ema and ckpt.get('ema'):
+        ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
+        ema.updates = ckpt['updates']
+    if resume:
+        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
+        LOGGER.info(f'Resuming training from {weights} for {epochs - start_epoch} more epochs to {epochs} total epochs')
+    if epochs < start_epoch:
+        LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
+        epochs += ckpt['epoch']  # finetune additional epochs
+    return best_fitness, start_epoch, epochs
+
+
 class EarlyStopping:
     # YOLOv5 simple early stopper
     def __init__(self, patience=30):
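For reference, a minimal self-contained sketch of how the new helper is called. The toy Linear model, SGD optimizer, and hand-built checkpoint dict below are illustrative assumptions, not part of this commit; in train.py the ckpt dict comes from the loaded weights file.

# Illustrative only: a toy model/optimizer and a hand-built checkpoint dict with the
# keys smart_resume() reads ('epoch', 'best_fitness', 'optimizer', 'ema', 'updates').
import torch

from utils.torch_utils import ModelEMA, smart_resume  # requires the YOLOv5 repo on the path

model = torch.nn.Linear(10, 2)  # stand-in for the detection model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
ema = ModelEMA(model)

ckpt = {
    'epoch': 99,  # last completed epoch
    'best_fitness': 0.5,
    'optimizer': optimizer.state_dict(),
    'ema': ema.ema,  # module whose state_dict is restored into the EMA
    'updates': 1000}

best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights='last.pt', epochs=300, resume=True)
print(best_fitness, start_epoch, epochs)  # -> 0.5 100 300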