# EasyCV/easycv/hooks/ema_hook.py
#
# 83 lines
# 2.8 KiB
# Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import math
import torch
from mmcv.runner import Hook
from easycv.utils import dist_utils, py_util
from .registry import HOOKS
class ModelEMA:
    """Exponential moving average (EMA) of a model's floating-point state.

    Adapted from https://github.com/rwightman/pytorch-image-models

    A deep copy of the model is kept in eval mode with gradients disabled, and
    every floating-point entry of its ``state_dict`` (parameters and buffers)
    is blended towards the live model on each :meth:`update` call.
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes
    to perform well. This class is sensitive to where it is initialized in the
    sequence of model init, GPU assignment and distributed training wrappers.

    In Yolo5s, EMA helps increase mAP from 0.27 to 0.353.

    Attributes:
        model: the averaged copy of the tracked model.
        updates: number of EMA updates performed so far.
        decay: callable mapping an update count to the effective decay factor.
    """

    def __init__(self, model, decay=0.9999, updates=0, tau=2000):
        """Create the EMA copy of ``model``.

        Args:
            model: model to track. When ``dist_utils.is_parallel(model)`` is
                true, its ``.module`` is copied instead of the wrapper.
            decay: asymptotic decay rate that the warm-up ramp approaches.
            updates: number of EMA updates already performed; this sets the
                starting point on the warm-up ramp.
            tau: time constant of the warm-up ramp, measured in updates. The
                effective decay is ``decay * (1 - exp(-updates / tau))``.
                Must be positive.

        Raises:
            ValueError: if ``tau`` is not positive.
        """
        if tau <= 0:
            # Fail here rather than with a ZeroDivisionError (tau == 0) or a
            # meaningless negative decay (tau < 0) on a later update() call.
            raise ValueError('tau must be positive, got %r' % (tau, ))

        source = model.module if dist_utils.is_parallel(model) else model
        # The copy keeps the source dtype (no .half() conversion is applied)
        # and is switched to eval mode.
        self.model = copy.deepcopy(source).eval()
        self.updates = updates  # number of EMA updates
        # Exponential ramp: the effective decay starts near 0 and approaches
        # `decay`, so early updates follow the live weights closely
        # (helps early epochs).
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.model.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend the current state of ``model`` into the EMA copy, in place.

        Increments ``self.updates`` and then applies
        ``ema = d * ema + (1 - d) * live`` to every floating-point entry of
        the EMA ``state_dict``, where ``d = self.decay(self.updates)``.
        Non-floating-point entries are left unchanged.

        Args:
            model: the live model; ``.module`` is used when
                ``dist_utils.is_parallel(model)`` is true.
        """
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            blend = 1. - d  # loop-invariant weight of the live model

            live = model.module if dist_utils.is_parallel(model) else model
            msd = live.state_dict()  # model state_dict

            # The tensors returned by state_dict() share storage with the
            # module, so the in-place ops below modify the EMA model itself.
            for k, v in self.model.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += blend * msd[k].detach()

    def update_attr(self,
                    model,
                    include=(),
                    exclude=('process_group', 'reducer')):
        """Copy attributes from ``model`` onto the EMA model.

        Delegates to ``py_util.copy_attr(self.model, model, include, exclude)``.

        Args:
            model: source object whose attributes are copied.
            include: forwarded to ``py_util.copy_attr`` unchanged.
            exclude: forwarded to ``py_util.copy_attr`` unchanged.
        """
        py_util.copy_attr(self.model, model, include, exclude)
@HOOKS.register_module
class EMAHook(Hook):
    """Runner hook that maintains an exponential moving average of the model.

    The tracker is created in ``before_run`` and stored as ``runner.ema``; it
    is refreshed from ``runner.model`` after every training iteration.
    """

    def __init__(self, decay=0.9999, copy_model_attr=()):
        """
        Args:
            decay: decay rate for the exponential moving average.
            copy_model_attr: attributes to copy from the original model to the
                EMA model. Stored on the hook; it is not read by the methods
                shown here.
        """
        # Becomes True once the EMA update counter has been aligned with the
        # runner's iteration count.
        self._init_updates = False
        self._copy_model_attr = copy_model_attr
        self.decay = decay

    def before_run(self, runner):
        """Build the EMA tracker for ``runner.model`` and attach it to the runner."""
        ema = ModelEMA(runner.model, decay=self.decay)
        runner.ema = ema

    def before_train_epoch(self, runner):
        """Align the EMA update counter with ``runner.iter``, first time only."""
        if self._init_updates:
            return
        runner.ema.updates = runner.iter
        self._init_updates = True

    def after_train_iter(self, runner):
        """Fold the weights of ``runner.model`` into the EMA after each iteration."""
        tracker = runner.ema
        tracker.update(runner.model)