mirror of https://github.com/alibaba/EasyCV.git
83 lines
2.8 KiB
Python
83 lines
2.8 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import copy
|
|
import math
|
|
|
|
import torch
|
|
from mmcv.runner import Hook
|
|
|
|
from easycv.utils import dist_utils, py_util
|
|
from .registry import HOOKS
|
|
|
|
|
|
class ModelEMA:
    """Exponential moving average (EMA) of a model's ``state_dict``.

    Adapted from https://github.com/rwightman/pytorch-image-models

    A shadow copy of the model is kept, and every floating-point entry of
    its ``state_dict`` (parameters and buffers) is blended towards the
    training model on each call to :meth:`update`. This is intended to
    allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training
    schemes to perform well. This class is sensitive to where it is
    initialized in the sequence of model init, GPU assignment and
    distributed training wrappers.

    Note: in Yolo5s, EMA helped increase mAP from 0.27 to 0.353.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        """Create the shadow model and the decay schedule.

        Args:
            model: model to track. If ``dist_utils.is_parallel(model)`` is
                true, its ``.module`` attribute is copied instead of the
                wrapper itself.
            decay: upper bound of the EMA decay factor.
            updates: number of EMA updates already performed; used as the
                starting point of the decay ramp.
        """
        source = model.module if dist_utils.is_parallel(model) else model
        # The shadow model is an independent deep copy switched to eval
        # mode. No half-precision cast is applied, so it keeps the dtype
        # of the source model.
        self.model = copy.deepcopy(source).eval()

        # Number of EMA updates performed so far.
        self.updates = updates

        def ramped_decay(step):
            # Exponential ramp: 0 at step 0, approaching ``decay`` as the
            # step count grows (helps during early epochs).
            return decay * (1 - math.exp(-step / 2000))

        self.decay = ramped_decay

        # The shadow model is never trained directly.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Blend the current weights of ``model`` into the shadow model.

        Increments ``self.updates``, evaluates the decay factor ``d`` for
        the new count, and applies ``ema = d * ema + (1 - d) * current``
        in place to every floating-point tensor of the shadow model's
        ``state_dict``. Non-floating-point entries are left untouched.

        Args:
            model: the training model; unwrapped via ``.module`` when
                ``dist_utils.is_parallel(model)`` is true.
        """
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            if dist_utils.is_parallel(model):
                current_state = model.module.state_dict()
            else:
                current_state = model.state_dict()

            for name, ema_value in self.model.state_dict().items():
                if not ema_value.dtype.is_floating_point:
                    continue
                ema_value.mul_(d)
                ema_value.add_((1. - d) * current_state[name].detach())

    def update_attr(self,
                    model,
                    include=(),
                    exclude=('process_group', 'reducer')):
        """Copy attributes from ``model`` onto the shadow model.

        Delegates to ``py_util.copy_attr`` with the shadow model as the
        first argument and ``model`` as the second.

        Args:
            model: model whose attributes are passed to ``copy_attr``.
            include: forwarded unchanged to ``py_util.copy_attr``.
            exclude: forwarded unchanged to ``py_util.copy_attr``.
        """
        py_util.copy_attr(self.model, model, include, exclude)
|
|
|
|
|
|
@HOOKS.register_module
class EMAHook(Hook):
    """Hook that maintains an Exponential Moving Average of the model.

    A ``ModelEMA`` instance is attached to the runner as ``runner.ema``
    before the run starts and is updated after every training iteration.
    """

    def __init__(self, decay=0.9999, copy_model_attr=()):
        """
        Args:
            decay: decay rate for the exponential moving average
            copy_model_attr: attributes to copy from the original model to
                the EMA model (stored on the hook; not read by any method
                of this class)
        """
        # Becomes True once the EMA update counter has been synchronised
        # with the runner's iteration count.
        self._init_updates = False
        self._copy_model_attr = copy_model_attr
        self.decay = decay

    def before_run(self, runner):
        """Create the EMA model from ``runner.model`` and attach it."""
        runner.ema = ModelEMA(runner.model, decay=self.decay)

    def before_train_epoch(self, runner):
        """Set the EMA update counter from ``runner.iter``, once only.

        Only the first call has any effect; later calls return
        immediately.
        """
        if self._init_updates:
            return
        runner.ema.updates = runner.iter
        self._init_updates = True

    def after_train_iter(self, runner):
        """Fold the current model weights into the EMA model."""
        runner.ema.update(runner.model)
|