[Fix] Fix EMAHook trigger train loop and AveragedModel sync buffer. (#467)

* [Fix] Fix EMAHook trigger train loop init during testing.

* fix sync buffer

* update ut

* fix sync buffer

* fix sync buffer
RangiLyu 2022-08-26 14:21:56 +08:00 committed by GitHub
parent 18a0338c91
commit 8d25dbdeda
3 changed files with 57 additions and 38 deletions
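Note on the user-facing behaviour: EMA can be enabled from a later point of training via `begin_epoch` or `begin_iter`, and the bug showed up when validation, testing, or checkpointing ran before that point (per the commit message, the old `_ema_started` checks also triggered train loop initialization during testing). A minimal, hedged sketch of the configuration involved, mirroring the updated unit tests and not part of the diff itself:

    # EMA starts averaging only once training reaches begin_epoch / begin_iter.
    # With this fix, the hooks that swap or save EMA weights also behave
    # correctly before that point: the EMA model simply mirrors the source model.
    custom_hooks = [dict(type='EMAHook', begin_epoch=5)]
    # or, for iteration-based training loops:
    custom_hooks = [dict(type='EMAHook', begin_iter=5)]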


@@ -71,6 +71,12 @@ class EMAHook(Hook):
         self.ema_model = MODELS.build(
             self.ema_cfg, default_args=dict(model=self.src_model))

+    def before_train(self, runner) -> None:
+        """Check the begin_epoch/iter is smaller than max_epochs/iters.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
         if self.enabled_by_epoch:
             assert self.begin_epoch <= runner.max_epochs, (
                 'self.begin_epoch should be smaller than runner.max_epochs: '
@@ -96,6 +102,11 @@ class EMAHook(Hook):
         """
         if self._ema_started(runner):
             self.ema_model.update_parameters(self.src_model)
+        else:
+            ema_params = self.ema_model.module.state_dict()
+            src_params = self.src_model.state_dict()
+            for k, p in ema_params.items():
+                p.data.copy_(src_params[k].data)

     def before_val_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before
@@ -104,8 +115,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()

     def after_val_epoch(self,
                         runner,
@@ -118,8 +128,7 @@ class EMAHook(Hook):
             metrics on validation dataset. The keys are the names of the
             metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()

     def before_test_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before test.
@@ -127,8 +136,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()

     def after_test_epoch(self,
                          runner,
@@ -141,8 +149,7 @@ class EMAHook(Hook):
             metrics on test dataset. The keys are the names of the
             metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()

     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint.
@@ -150,14 +157,13 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if self._ema_started(runner):
-            checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-            # Save ema parameters to the source model's state dict so that we
-            # can directly load the averaged model weights for deployment.
-            # Swapping the state_dict key-values instead of swapping model
-            # parameters because the state_dict is a shallow copy of model
-            # parameters.
-            self._swap_ema_state_dict(checkpoint)
+        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
+        # Save ema parameters to the source model's state dict so that we
+        # can directly load the averaged model weights for deployment.
+        # Swapping the state_dict key-values instead of swapping model
+        # parameters because the state_dict is a shallow copy of model
+        # parameters.
+        self._swap_ema_state_dict(checkpoint)

     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint.
@@ -165,23 +171,22 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if self._ema_started(runner):
-            if 'ema_state_dict' in checkpoint:
-                # The original model parameters are actually saved in ema
-                # field swap the weights back to resume ema state.
-                self._swap_ema_state_dict(checkpoint)
-                self.ema_model.load_state_dict(
-                    checkpoint['ema_state_dict'], strict=self.strict_load)
-
-            # Support load checkpoint without ema state dict.
-            else:
-                print_log(
-                    'There is no `ema_state_dict` in checkpoint. '
-                    '`EMAHook` will make a copy of `state_dict` as the '
-                    'initial `ema_state_dict`', 'current', logging.WARNING)
-                self.ema_model.module.load_state_dict(
-                    copy.deepcopy(checkpoint['state_dict']),
-                    strict=self.strict_load)
+        if 'ema_state_dict' in checkpoint:
+            # The original model parameters are actually saved in ema
+            # field swap the weights back to resume ema state.
+            self._swap_ema_state_dict(checkpoint)
+            self.ema_model.load_state_dict(
+                checkpoint['ema_state_dict'], strict=self.strict_load)
+
+        # Support load checkpoint without ema state dict.
+        else:
+            print_log(
+                'There is no `ema_state_dict` in checkpoint. '
+                '`EMAHook` will make a copy of `state_dict` as the '
+                'initial `ema_state_dict`', 'current', logging.WARNING)
+            self.ema_model.module.load_state_dict(
+                copy.deepcopy(checkpoint['state_dict']),
+                strict=self.strict_load)

     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""


@@ -106,6 +106,11 @@ class BaseAveragedModel(nn.Module):
                     self.avg_func(p_avg.data,
                                   src_parameters[k].data.to(device),
                                   self.steps)
+        if not self.update_buffers:
+            # If not update the buffers,
+            # keep the buffers in sync with the source model.
+            for b_avg, b_src in zip(self.module.buffers(), model.buffers()):
+                b_avg.data.copy_(b_src.data.to(b_avg.device))
         self.steps += 1
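This hunk is the buffer-sync half of the fix: with `update_buffers=False`, buffers such as BatchNorm running statistics are now copied from the source model on every `update_parameters` call instead of staying frozen at their initial values. A small sketch of the expected behaviour, assuming the public `ExponentialMovingAverage` wrapper from `mmengine.model`:

    import torch
    import torch.nn as nn
    from mmengine.model import ExponentialMovingAverage

    model = nn.BatchNorm1d(4)
    # Only the parameters are averaged, but the buffers should still track the source.
    ema = ExponentialMovingAverage(model, update_buffers=False)

    model(torch.randn(8, 4))      # forward in train mode updates running stats
    ema.update_parameters(model)  # averages params and, after this fix, copies buffers

    assert torch.allclose(ema.module.running_mean, model.running_mean)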


@@ -14,6 +14,7 @@ from mmengine.model import BaseModel, ExponentialMovingAverage
 from mmengine.optim import OptimWrapper
 from mmengine.registry import DATASETS, MODEL_WRAPPERS
 from mmengine.runner import Runner
+from mmengine.testing import assert_allclose


 class ToyModel(nn.Module):
@@ -225,9 +226,13 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_epoch=5)],
             experiment_name='test6')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)

         # Test enable ema at 5 iterations.
@@ -255,7 +260,11 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_iter=5)],
             experiment_name='test7')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)