mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
--resume EMA fix #292
This commit is contained in:
parent
2b6209a9d5
commit
24c5a941f0
4
train.py
4
train.py
@ -163,6 +163,7 @@ def train(hyp):
|
|||||||
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
||||||
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
|
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
|
||||||
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
||||||
|
nb = len(dataloader) # number of batches
|
||||||
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
|
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
|
||||||
|
|
||||||
# Testloader
|
# Testloader
|
||||||
@ -191,11 +192,10 @@ def train(hyp):
|
|||||||
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
||||||
|
|
||||||
# Exponential moving average
|
# Exponential moving average
|
||||||
ema = torch_utils.ModelEMA(model)
|
ema = torch_utils.ModelEMA(model, updates=start_epoch * nb / accumulate)
|
||||||
|
|
||||||
# Start training
|
# Start training
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
nb = len(dataloader) # number of batches
|
|
||||||
nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
|
nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
|
||||||
maps = np.zeros(nc) # mAP per class
|
maps = np.zeros(nc) # mAP per class
|
||||||
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
||||||
|
@ -191,15 +191,11 @@ class ModelEMA:
|
|||||||
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
|
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, decay=0.9999, device=''):
|
def __init__(self, model, decay=0.9999, updates=0):
|
||||||
# Create EMA
|
# Create EMA
|
||||||
self.ema = deepcopy(model.module if is_parallel(model) else model) # FP32 EMA
|
self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
|
||||||
self.ema.eval()
|
self.updates = updates # number of EMA updates
|
||||||
self.updates = 0 # number of EMA updates
|
|
||||||
self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
|
self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
|
||||||
self.device = device # perform ema on different device from model if set
|
|
||||||
if device:
|
|
||||||
self.ema.to(device)
|
|
||||||
for p in self.ema.parameters():
|
for p in self.ema.parameters():
|
||||||
p.requires_grad_(False)
|
p.requires_grad_(False)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user