mirror of https://github.com/open-mmlab/mmengine.git
[Fix] Fix EMAHook trigger train loop and AveragedModel sync buffer. (#467)
* [Fix] Fix EMAHook trigger train loop init during testing.
* Fix sync buffer.
* Update unit tests.
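For context: an exponential moving average (EMA) hook maintains a shadow copy of the model whose learnable parameters track the source model, but buffers (e.g. BatchNorm running statistics) are not averaged, so they must be copied from the source model or they go stale. A minimal, hedged sketch of that update rule, independent of mmengine (the function name ema_update and the momentum convention here are illustrative, not mmengine API):

    import torch.nn as nn

    def ema_update(ema_model: nn.Module, src_model: nn.Module,
                   momentum: float = 0.0002) -> None:
        # Average learnable parameters: avg = (1 - m) * avg + m * src.
        for p_ema, p_src in zip(ema_model.parameters(),
                                src_model.parameters()):
            p_ema.data.mul_(1 - momentum).add_(p_src.data, alpha=momentum)
        # Buffers are not averaged; keep them in sync by copying,
        # which is what the AveragedModel change below adds.
        for b_ema, b_src in zip(ema_model.buffers(), src_model.buffers()):
            b_ema.data.copy_(b_src.data.to(b_ema.device))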
parent 18a0338c91
commit 8d25dbdeda
@@ -71,6 +71,12 @@ class EMAHook(Hook):
         self.ema_model = MODELS.build(
             self.ema_cfg, default_args=dict(model=self.src_model))
 
+    def before_train(self, runner) -> None:
+        """Check the begin_epoch/iter is smaller than max_epochs/iters.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
         if self.enabled_by_epoch:
             assert self.begin_epoch <= runner.max_epochs, (
                 'self.begin_epoch should be smaller than runner.max_epochs: '
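The new before_train hook validates begin_epoch/begin_iter only once training actually starts, so a test-only run never touches these training attributes. A hedged usage sketch (the hook fields match the diff and the unit test below; the surrounding Runner config is omitted):

    # Enable EMA only from epoch 5 onwards; before_train will assert
    # that begin_epoch=5 does not exceed the configured max_epochs.
    custom_hooks = [dict(type='EMAHook', begin_epoch=5)]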
@@ -96,6 +102,11 @@ class EMAHook(Hook):
         """
         if self._ema_started(runner):
             self.ema_model.update_parameters(self.src_model)
+        else:
+            ema_params = self.ema_model.module.state_dict()
+            src_params = self.src_model.state_dict()
+            for k, p in ema_params.items():
+                p.data.copy_(src_params[k].data)
 
     def before_val_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before
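Before begin_epoch/begin_iter is reached, this else branch keeps the EMA weights identical to the source weights. That invariant is what lets the val/test hooks below swap unconditionally: swapping two identical parameter sets is a no-op, so _ema_started (which reads runner.train_loop and used to lazily build the train loop during testing) no longer needs to be called there. A tiny hedged demo of the invariant:

    import torch
    from torch.testing import assert_close

    ema = torch.ones(4)
    src = ema.clone()           # before EMA starts, ema mirrors src
    ema_old = ema.clone()
    ema, src = src, ema         # the unconditional swap
    assert_close(ema, ema_old)  # swapping identical weights changes nothing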
@@ -104,8 +115,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def after_val_epoch(self,
                         runner,
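_swap_ema_parameters itself is unchanged and not shown in this diff; a hedged sketch of what such a swap typically does (illustrative, not the exact mmengine implementation):

    import torch
    import torch.nn as nn

    def swap_parameters(a: nn.Module, b: nn.Module) -> None:
        # Exchange parameter values in place so that, e.g., the source
        # model temporarily carries the EMA weights during validation.
        for p_a, p_b in zip(a.parameters(), b.parameters()):
            tmp = p_a.data.clone()
            p_a.data.copy_(p_b.data)
            p_b.data.copy_(tmp)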
@@ -118,8 +128,7 @@ class EMAHook(Hook):
                 metrics on validation dataset. The keys are the names of the
                 metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def before_test_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before test.
@@ -127,8 +136,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def after_test_epoch(self,
                          runner,
@@ -141,8 +149,7 @@ class EMAHook(Hook):
                 metrics on test dataset. The keys are the names of the
                 metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint.
@@ -150,14 +157,13 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if self._ema_started(runner):
-            checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-            # Save ema parameters to the source model's state dict so that we
-            # can directly load the averaged model weights for deployment.
-            # Swapping the state_dict key-values instead of swapping model
-            # parameters because the state_dict is a shallow copy of model
-            # parameters.
-            self._swap_ema_state_dict(checkpoint)
+        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
+        # Save ema parameters to the source model's state dict so that we
+        # can directly load the averaged model weights for deployment.
+        # Swapping the state_dict key-values instead of swapping model
+        # parameters because the state_dict is a shallow copy of model
+        # parameters.
+        self._swap_ema_state_dict(checkpoint)
 
     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint.
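The comment above carries the key point: state_dict() returns a shallow copy, so swapping the values stored in the checkpoint dict does not disturb the live model tensors. A hedged sketch of such a key-value swap (the helper name is illustrative; the 'module.' key prefix for EMA entries matches the unit test below):

    def swap_ema_state_dict(checkpoint: dict) -> None:
        # Exchange each source weight with its EMA counterpart inside
        # the checkpoint only; the models themselves are untouched.
        model_state = checkpoint['state_dict']
        ema_state = checkpoint['ema_state_dict']
        for k in list(model_state.keys()):
            ema_k = 'module.' + k
            model_state[k], ema_state[ema_k] = ema_state[ema_k], model_state[k]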
@@ -165,23 +171,22 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if self._ema_started(runner):
-            if 'ema_state_dict' in checkpoint:
-                # The original model parameters are actually saved in ema
-                # field swap the weights back to resume ema state.
-                self._swap_ema_state_dict(checkpoint)
-                self.ema_model.load_state_dict(
-                    checkpoint['ema_state_dict'], strict=self.strict_load)
+        if 'ema_state_dict' in checkpoint:
+            # The original model parameters are actually saved in ema
+            # field swap the weights back to resume ema state.
+            self._swap_ema_state_dict(checkpoint)
+            self.ema_model.load_state_dict(
+                checkpoint['ema_state_dict'], strict=self.strict_load)
 
-            # Support load checkpoint without ema state dict.
-            else:
-                print_log(
-                    'There is no `ema_state_dict` in checkpoint. '
-                    '`EMAHook` will make a copy of `state_dict` as the '
-                    'initial `ema_state_dict`', 'current', logging.WARNING)
-                self.ema_model.module.load_state_dict(
-                    copy.deepcopy(checkpoint['state_dict']),
-                    strict=self.strict_load)
+        # Support load checkpoint without ema state dict.
+        else:
+            print_log(
+                'There is no `ema_state_dict` in checkpoint. '
+                '`EMAHook` will make a copy of `state_dict` as the '
+                'initial `ema_state_dict`', 'current', logging.WARNING)
+            self.ema_model.module.load_state_dict(
+                copy.deepcopy(checkpoint['state_dict']),
+                strict=self.strict_load)
 
     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""
@@ -106,6 +106,11 @@ class BaseAveragedModel(nn.Module):
                     self.avg_func(p_avg.data,
                                   src_parameters[k].data.to(device),
                                   self.steps)
+        if not self.update_buffers:
+            # If not update the buffers,
+            # keep the buffers in sync with the source model.
+            for b_avg, b_src in zip(self.module.buffers(), model.buffers()):
+                b_avg.data.copy_(b_src.data.to(b_avg.device))
         self.steps += 1
 
 
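With this change, an averaged model created with update_buffers=False (the default) copies buffers from the source model on every update_parameters call instead of leaving them stale. A hedged usage sketch (the momentum value and toy model are illustrative):

    import torch.nn as nn
    from mmengine.model import ExponentialMovingAverage

    model = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2))
    ema = ExponentialMovingAverage(model, momentum=0.0002)
    # ... train `model` for a step, then:
    ema.update_parameters(model)  # averages params; now also syncs buffers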
@@ -14,6 +14,7 @@ from mmengine.model import BaseModel, ExponentialMovingAverage
 from mmengine.optim import OptimWrapper
 from mmengine.registry import DATASETS, MODEL_WRAPPERS
 from mmengine.runner import Runner
+from mmengine.testing import assert_allclose
 
 
 class ToyModel(nn.Module):
@@ -225,9 +226,13 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_epoch=5)],
             experiment_name='test6')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)
 
         # Test enable ema at 5 iterations.
@@ -255,7 +260,11 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_iter=5)],
             experiment_name='test7')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)