Improved model+EMA checkpointing (#2292)
* Enhanced model+EMA checkpointing * update * bug fix * bug fix 2 * always save optimizer * ema half * remove model.float() * model half * carry ema/model in fp32 * rm model.float() * both to float always * cleanup * cleanuppull/2295/head
parent
ca5b10b759
commit
ec1d8496ba
1
test.py
1
test.py
|
@ -272,7 +272,6 @@ def test(data,
|
|||
if not training:
|
||||
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
|
||||
print(f"Results saved to {save_dir}{s}")
|
||||
model.float() # for training
|
||||
maps = np.zeros(nc) + map
|
||||
for i, c in enumerate(ap_class):
|
||||
maps[c] = ap[i]
|
||||
|
|
25
train.py
25
train.py
|
@ -31,7 +31,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
|
|||
from utils.google_utils import attempt_download
|
||||
from utils.loss import ComputeLoss
|
||||
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
|
||||
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
|
||||
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -136,6 +136,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
|
||||
loggers = {'wandb': wandb} # loggers dict
|
||||
|
||||
# EMA
|
||||
ema = ModelEMA(model) if rank in [-1, 0] else None
|
||||
|
||||
# Resume
|
||||
start_epoch, best_fitness = 0, 0.0
|
||||
if pretrained:
|
||||
|
@ -144,6 +147,11 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
optimizer.load_state_dict(ckpt['optimizer'])
|
||||
best_fitness = ckpt['best_fitness']
|
||||
|
||||
# EMA
|
||||
if ema and ckpt.get('ema'):
|
||||
ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
|
||||
ema.updates = ckpt['ema'][1]
|
||||
|
||||
# Results
|
||||
if ckpt.get('training_results') is not None:
|
||||
results_file.write_text(ckpt['training_results']) # write results.txt
|
||||
|
@ -173,9 +181,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
|
||||
logger.info('Using SyncBatchNorm()')
|
||||
|
||||
# EMA
|
||||
ema = ModelEMA(model) if rank in [-1, 0] else None
|
||||
|
||||
# DDP mode
|
||||
if cuda and rank != -1:
|
||||
model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
|
||||
|
@ -191,7 +196,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
|
||||
# Process 0
|
||||
if rank in [-1, 0]:
|
||||
ema.updates = start_epoch * nb // accumulate # set EMA updates
|
||||
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
|
||||
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
|
||||
world_size=opt.world_size, workers=opt.workers,
|
||||
|
@ -335,8 +339,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
# DDP process 0 or single-GPU
|
||||
if rank in [-1, 0]:
|
||||
# mAP
|
||||
if ema:
|
||||
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
|
||||
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
|
||||
final_epoch = epoch + 1 == epochs
|
||||
if not opt.notest or final_epoch: # Calculate mAP
|
||||
results, maps, times = test.test(opt.data,
|
||||
|
@ -378,8 +381,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
ckpt = {'epoch': epoch,
|
||||
'best_fitness': best_fitness,
|
||||
'training_results': results_file.read_text(),
|
||||
'model': ema.ema,
|
||||
'optimizer': None if final_epoch else optimizer.state_dict(),
|
||||
'model': (model.module if is_parallel(model) else model).half(),
|
||||
'ema': (ema.ema.half(), ema.updates),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'wandb_id': wandb_run.id if wandb else None}
|
||||
|
||||
# Save last, best and delete
|
||||
|
@ -387,6 +391,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
if best_fitness == fi:
|
||||
torch.save(ckpt, best)
|
||||
del ckpt
|
||||
|
||||
model.float(), ema.ema.float()
|
||||
|
||||
# end epoch ----------------------------------------------------------------------------------------------------
|
||||
# end training
|
||||
|
||||
|
|
|
@ -484,8 +484,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
|
|||
def strip_optimizer(f='weights/best.pt', s=''): # from utils.general import *; strip_optimizer()
|
||||
# Strip optimizer from 'f' to finalize training, optionally save as 's'
|
||||
x = torch.load(f, map_location=torch.device('cpu'))
|
||||
for key in 'optimizer', 'training_results', 'wandb_id':
|
||||
x[key] = None
|
||||
for k in 'optimizer', 'training_results', 'wandb_id', 'ema': # keys
|
||||
x[k] = None
|
||||
x['epoch'] = -1
|
||||
x['model'].half() # to FP16
|
||||
for p in x['model'].parameters():
|
||||
|
|
Loading…
Reference in New Issue