[Fix] Fix EMAHook triggering train loop init and keep AveragedModel buffers in sync. (#467)

* [Fix] Fix EMAHook triggering train loop init during testing.

* Fix buffer sync in AveragedModel.

* Update unit tests.

RangiLyu 2022-08-26 14:21:56 +08:00 committed by GitHub
parent 18a0338c91
commit 8d25dbdeda
3 changed files with 57 additions and 38 deletions
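In short: `EMAHook` no longer consults `runner.train_loop` (via `_ema_started`) from the val/test/checkpoint hooks, which could force the lazily-built train loop to be initialized during a pure test run. Instead, until `begin_epoch`/`begin_iter` is reached, the EMA weights are simply kept identical to the source model, so the now-unconditional parameter swaps become harmless no-ops. A minimal standalone sketch of that idea (the toy modules below are placeholders, not the hook's real attributes):

import torch
import torch.nn as nn

# Stand-ins for self.src_model and self.ema_model.module (hypothetical
# placeholders, not the actual EMAHook attributes).
src = nn.Linear(4, 2)
ema = nn.Linear(4, 2)

ema_started = False  # i.e. current epoch/iter is still below begin_epoch/iter

if ema_started:
    pass  # the usual EMA update would run here
else:
    # The fix: keep the EMA copy identical to the source model, so the later
    # parameter swap before validation/testing does not change any results.
    ema_params = ema.state_dict()
    src_params = src.state_dict()
    for k, p in ema_params.items():
        p.data.copy_(src_params[k].data)

# After the copy, swapping the two parameter sets is a no-op for evaluation.
assert all(
    torch.equal(p, q) for p, q in zip(ema.parameters(), src.parameters()))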


@@ -71,6 +71,12 @@ class EMAHook(Hook):
         self.ema_model = MODELS.build(
             self.ema_cfg, default_args=dict(model=self.src_model))

+    def before_train(self, runner) -> None:
+        """Check the begin_epoch/iter is smaller than max_epochs/iters.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
         if self.enabled_by_epoch:
             assert self.begin_epoch <= runner.max_epochs, (
                 'self.begin_epoch should be smaller than runner.max_epochs: '
@@ -96,6 +102,11 @@ class EMAHook(Hook):
         """
         if self._ema_started(runner):
             self.ema_model.update_parameters(self.src_model)
+        else:
+            ema_params = self.ema_model.module.state_dict()
+            src_params = self.src_model.state_dict()
+            for k, p in ema_params.items():
+                p.data.copy_(src_params[k].data)

     def before_val_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before
@@ -104,8 +115,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()

     def after_val_epoch(self,
                         runner,
@@ -118,8 +128,7 @@ class EMAHook(Hook):
                 metrics on validation dataset. The keys are the names of the
                 metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()

     def before_test_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before test.
@@ -127,8 +136,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()

     def after_test_epoch(self,
                          runner,
@@ -141,8 +149,7 @@ class EMAHook(Hook):
                 metrics on test dataset. The keys are the names of the
                 metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()

     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint.
@@ -150,14 +157,13 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if self._ema_started(runner):
-            checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-            # Save ema parameters to the source model's state dict so that we
-            # can directly load the averaged model weights for deployment.
-            # Swapping the state_dict key-values instead of swapping model
-            # parameters because the state_dict is a shallow copy of model
-            # parameters.
-            self._swap_ema_state_dict(checkpoint)
+        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
+        # Save ema parameters to the source model's state dict so that we
+        # can directly load the averaged model weights for deployment.
+        # Swapping the state_dict key-values instead of swapping model
+        # parameters because the state_dict is a shallow copy of model
+        # parameters.
+        self._swap_ema_state_dict(checkpoint)

     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint.
@@ -165,23 +171,22 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if self._ema_started(runner):
-            if 'ema_state_dict' in checkpoint:
-                # The original model parameters are actually saved in ema
-                # field swap the weights back to resume ema state.
-                self._swap_ema_state_dict(checkpoint)
-                self.ema_model.load_state_dict(
-                    checkpoint['ema_state_dict'], strict=self.strict_load)
+        if 'ema_state_dict' in checkpoint:
+            # The original model parameters are actually saved in ema
+            # field swap the weights back to resume ema state.
+            self._swap_ema_state_dict(checkpoint)
+            self.ema_model.load_state_dict(
+                checkpoint['ema_state_dict'], strict=self.strict_load)

-            # Support load checkpoint without ema state dict.
-            else:
-                print_log(
-                    'There is no `ema_state_dict` in checkpoint. '
-                    '`EMAHook` will make a copy of `state_dict` as the '
-                    'initial `ema_state_dict`', 'current', logging.WARNING)
-                self.ema_model.module.load_state_dict(
-                    copy.deepcopy(checkpoint['state_dict']),
-                    strict=self.strict_load)
+        # Support load checkpoint without ema state dict.
+        else:
+            print_log(
+                'There is no `ema_state_dict` in checkpoint. '
+                '`EMAHook` will make a copy of `state_dict` as the '
+                'initial `ema_state_dict`', 'current', logging.WARNING)
+            self.ema_model.module.load_state_dict(
+                copy.deepcopy(checkpoint['state_dict']),
+                strict=self.strict_load)

     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""


@@ -106,6 +106,11 @@ class BaseAveragedModel(nn.Module):
                     self.avg_func(p_avg.data,
                                   src_parameters[k].data.to(device),
                                   self.steps)
+        if not self.update_buffers:
+            # If not update the buffers,
+            # keep the buffers in sync with the source model.
+            for b_avg, b_src in zip(self.module.buffers(), model.buffers()):
+                b_avg.data.copy_(b_src.data.to(b_avg.device))
         self.steps += 1
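A quick way to see what the added block changes, sketched assuming mmengine's `ExponentialMovingAverage` (imported in the tests below) exposes `update_buffers` and `update_parameters()` as used in this diff:

import torch
import torch.nn as nn
from mmengine.model import ExponentialMovingAverage

model = nn.BatchNorm1d(4)  # has running_mean/running_var buffers
ema = ExponentialMovingAverage(model, update_buffers=False)

# A forward pass in training mode updates the source model's buffers.
model(torch.randn(8, 4))
ema.update_parameters(model)

# Even with update_buffers=False, the averaged copy's buffers now track the
# source model instead of keeping their stale initial values.
assert torch.allclose(ema.module.running_mean, model.running_mean)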


@@ -14,6 +14,7 @@ from mmengine.model import BaseModel, ExponentialMovingAverage
 from mmengine.optim import OptimWrapper
 from mmengine.registry import DATASETS, MODEL_WRAPPERS
 from mmengine.runner import Runner
+from mmengine.testing import assert_allclose


 class ToyModel(nn.Module):
@@ -225,9 +226,13 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_epoch=5)],
             experiment_name='test6')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)

         # Test enable ema at 5 iterations.
@@ -255,7 +260,11 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_iter=5)],
             experiment_name='test7')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)