EMA bug fix #279
parent
a9d20eba3e
commit
df224a0d8f
4
train.py
4
train.py
|
@ -294,7 +294,7 @@ def train(hyp):
|
|||
batch_size=batch_size,
|
||||
imgsz=imgsz_test,
|
||||
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
|
||||
model=ema.ema.module if hasattr(model, 'module') else ema.ema,
|
||||
model=ema.ema,
|
||||
single_cls=opt.single_cls,
|
||||
dataloader=testloader)
|
||||
|
||||
|
@ -324,7 +324,7 @@ def train(hyp):
|
|||
ckpt = {'epoch': epoch,
|
||||
'best_fitness': best_fitness,
|
||||
'training_results': f.read(),
|
||||
'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
|
||||
'model': ema.ema,
|
||||
'optimizer': None if final_epoch else optimizer.state_dict()}
|
||||
|
||||
# Save last, best and delete
|
||||
|
|
|
@ -175,8 +175,8 @@ class ModelEMA:
|
|||
"""
|
||||
|
||||
def __init__(self, model, decay=0.9999, device=''):
|
||||
# make a copy of the model for accumulating moving average of weights
|
||||
self.ema = deepcopy(model)
|
||||
# Create EMA
|
||||
self.ema = deepcopy(model.module if is_parallel(model) else model).half() # FP16 EMA
|
||||
self.ema.eval()
|
||||
self.updates = 0 # number of EMA updates
|
||||
self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
|
||||
|
@ -187,22 +187,19 @@ class ModelEMA:
|
|||
p.requires_grad_(False)
|
||||
|
||||
def update(self, model):
|
||||
self.updates += 1
|
||||
d = self.decay(self.updates)
|
||||
# Update EMA parameters
|
||||
with torch.no_grad():
|
||||
if is_parallel(model):
|
||||
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
||||
else:
|
||||
msd, esd = model.state_dict(), self.ema.state_dict()
|
||||
self.updates += 1
|
||||
d = self.decay(self.updates)
|
||||
|
||||
for k, v in esd.items():
|
||||
msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
|
||||
for k, v in self.ema.state_dict().items():
|
||||
if v.dtype.is_floating_point:
|
||||
v *= d
|
||||
v += (1. - d) * msd[k].detach()
|
||||
|
||||
def update_attr(self, model):
|
||||
# Update class attributes
|
||||
ema = self.ema.module if is_parallel(model) else self.ema
|
||||
# Update EMA attributes
|
||||
for k, v in model.__dict__.items():
|
||||
if not k.startswith('_') and k != 'module':
|
||||
setattr(ema, k, v)
|
||||
setattr(self.ema, k, v)
|
||||
|
|
Loading…
Reference in New Issue