[Fix] Fix EMAHook trigger train loop and AveragedModel sync buffer. (#467)
* [Fix] Fix EMAHook trigger train loop init during testing.
* fix sync buffer
* update ut
* fix sync buffer
* fix sync buffer
commit 8d25dbdeda (parent 18a0338c91)
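Summary of the change: EMAHook previously ran its begin_epoch/begin_iter sanity checks in before_run, where reading runner.max_epochs or runner.max_iters forced the train loop to be built even when only testing; those checks now live in before_train. The hook also keeps the EMA weights identical to the source weights until the begin point is reached, always writes ema_state_dict into checkpoints instead of gating on _ema_started, and BaseAveragedModel with update_buffers=False now copies the source model's buffers on every update so they no longer go stale.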
mmengine/hooks/ema_hook.py

@@ -71,6 +71,12 @@ class EMAHook(Hook):
         self.ema_model = MODELS.build(
             self.ema_cfg, default_args=dict(model=self.src_model))
 
+    def before_train(self, runner) -> None:
+        """Check the begin_epoch/iter is smaller than max_epochs/iters.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
         if self.enabled_by_epoch:
             assert self.begin_epoch <= runner.max_epochs, (
                 'self.begin_epoch should be smaller than runner.max_epochs: '
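Why the checks moved: runner.max_epochs and runner.max_iters are resolved through runner.train_loop, which the Runner builds lazily on first access, so running the asserts in before_run (which fires for runner.test() as well) constructed the train loop during testing. The sketch below illustrates the side effect with a hypothetical ToyRunner; it is not mmengine's actual Runner.

from types import SimpleNamespace


class ToyRunner:
    """Toy stand-in for a runner with a lazily built train loop."""

    def __init__(self, train_cfg=None):
        self._train_loop = train_cfg  # stays a config dict until accessed

    @property
    def train_loop(self):
        if isinstance(self._train_loop, dict):
            print('building train loop ...')  # the side effect to avoid
            self._train_loop = SimpleNamespace(**self._train_loop)
        return self._train_loop

    @property
    def max_epochs(self):
        return self.train_loop.max_epochs


runner = ToyRunner(train_cfg=dict(max_epochs=10))
# A before_run hook also fires for testing; touching max_epochs there
# builds the train loop as a side effect. before_train only fires for
# actual training, so the asserts are now safe.
assert runner.max_epochs == 10  # prints 'building train loop ...'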
@@ -96,6 +102,11 @@ class EMAHook(Hook):
         """
         if self._ema_started(runner):
             self.ema_model.update_parameters(self.src_model)
+        else:
+            ema_params = self.ema_model.module.state_dict()
+            src_params = self.src_model.state_dict()
+            for k, p in ema_params.items():
+                p.data.copy_(src_params[k].data)
 
     def before_val_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before
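The new else branch keeps the averaged model an exact copy of the source model until begin_epoch/begin_iter is reached, so checkpoints saved before EMA starts still carry meaningful ema_state_dict weights. A self-contained sketch of that copy, with plain nn.Linear modules standing in for the real models:

import torch
import torch.nn as nn

src = nn.Linear(2, 2)  # stands in for self.src_model
ema = nn.Linear(2, 2)  # stands in for self.ema_model.module

# Same copy the hook performs while EMA has not started yet; state_dict
# tensors share storage with the parameters, so copy_ updates ema in place.
src_params = src.state_dict()
for k, p in ema.state_dict().items():
    p.data.copy_(src_params[k].data)

for k in src_params:
    assert torch.equal(ema.state_dict()[k], src_params[k])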
@@ -104,8 +115,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def after_val_epoch(self,
                         runner,
@@ -118,8 +128,7 @@ class EMAHook(Hook):
             metrics on validation dataset. The keys are the names of the
             metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def before_test_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before test.
@@ -127,8 +136,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def after_test_epoch(self,
                         runner,
@@ -141,8 +149,7 @@ class EMAHook(Hook):
             metrics on test dataset. The keys are the names of the
             metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint.
@@ -150,14 +157,13 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if self._ema_started(runner):
-            checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-            # Save ema parameters to the source model's state dict so that we
-            # can directly load the averaged model weights for deployment.
-            # Swapping the state_dict key-values instead of swapping model
-            # parameters because the state_dict is a shallow copy of model
-            # parameters.
-            self._swap_ema_state_dict(checkpoint)
+        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
+        # Save ema parameters to the source model's state dict so that we
+        # can directly load the averaged model weights for deployment.
+        # Swapping the state_dict key-values instead of swapping model
+        # parameters because the state_dict is a shallow copy of model
+        # parameters.
+        self._swap_ema_state_dict(checkpoint)
 
     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint.
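The comment about the shallow copy is the crux of _swap_ema_state_dict: state_dict() returns references to the live parameter tensors rather than copies, so swapping values between the two state dicts stored in the checkpoint is enough to exchange the actual weights. A quick demonstration:

import torch
import torch.nn as nn

m = nn.Linear(2, 2)
sd = m.state_dict()   # shallow: tensors share storage with m's parameters
sd['weight'].zero_()  # mutates the live parameter in place
assert torch.equal(m.weight, torch.zeros(2, 2))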
@@ -165,23 +171,22 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if self._ema_started(runner):
-            if 'ema_state_dict' in checkpoint:
-                # The original model parameters are actually saved in ema
-                # field swap the weights back to resume ema state.
-                self._swap_ema_state_dict(checkpoint)
-                self.ema_model.load_state_dict(
-                    checkpoint['ema_state_dict'], strict=self.strict_load)
-
-            # Support load checkpoint without ema state dict.
-            else:
-                print_log(
-                    'There is no `ema_state_dict` in checkpoint. '
-                    '`EMAHook` will make a copy of `state_dict` as the '
-                    'initial `ema_state_dict`', 'current', logging.WARNING)
-                self.ema_model.module.load_state_dict(
-                    copy.deepcopy(checkpoint['state_dict']),
-                    strict=self.strict_load)
+        if 'ema_state_dict' in checkpoint:
+            # The original model parameters are actually saved in ema
+            # field swap the weights back to resume ema state.
+            self._swap_ema_state_dict(checkpoint)
+            self.ema_model.load_state_dict(
+                checkpoint['ema_state_dict'], strict=self.strict_load)
+
+        # Support load checkpoint without ema state dict.
+        else:
+            print_log(
+                'There is no `ema_state_dict` in checkpoint. '
+                '`EMAHook` will make a copy of `state_dict` as the '
+                'initial `ema_state_dict`', 'current', logging.WARNING)
+            self.ema_model.module.load_state_dict(
+                copy.deepcopy(checkpoint['state_dict']),
+                strict=self.strict_load)
 
     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""
mmengine/model/averaged_model.py

@@ -106,6 +106,11 @@ class BaseAveragedModel(nn.Module):
                 self.avg_func(p_avg.data,
                               src_parameters[k].data.to(device),
                               self.steps)
+        if not self.update_buffers:
+            # If not update the buffers,
+            # keep the buffers in sync with the source model.
+            for b_avg, b_src in zip(self.module.buffers(), model.buffers()):
+                b_avg.data.copy_(b_src.data.to(b_avg.device))
         self.steps += 1
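What the buffer sync fixes: with update_buffers=False, the averaged model's buffers (for example BatchNorm running statistics) previously kept their initial values forever, so evaluating with EMA weights used stale statistics. A minimal reproduction, with two BatchNorm1d modules standing in for the source and averaged models:

import torch
import torch.nn as nn

src = nn.BatchNorm1d(3)  # stands in for the source model
avg = nn.BatchNorm1d(3)  # stands in for the averaged model's module

src(torch.randn(8, 3))   # training-mode forward updates src's running stats

# Without the sync, avg's running_mean is still the initial zeros.
assert not torch.equal(avg.running_mean, src.running_mean)

# The added loop copies every buffer from the source model.
for b_avg, b_src in zip(avg.buffers(), src.buffers()):
    b_avg.data.copy_(b_src.data.to(b_avg.device))
assert torch.equal(avg.running_mean, src.running_mean)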
tests/test_hooks/test_ema_hook.py

@@ -14,6 +14,7 @@ from mmengine.model import BaseModel, ExponentialMovingAverage
 from mmengine.optim import OptimWrapper
 from mmengine.registry import DATASETS, MODEL_WRAPPERS
 from mmengine.runner import Runner
+from mmengine.testing import assert_allclose
 
 
 class ToyModel(nn.Module):
@@ -225,9 +226,13 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_epoch=5)],
             experiment_name='test6')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)
 
         # Test enable ema at 5 iterations.
@@ -255,7 +260,11 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_iter=5)],
             experiment_name='test7')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)
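The updated tests reflect the new behavior: checkpoints saved before the EMA begin point now also contain an ema_state_dict, and its entries must match the source state_dict up to the module. prefix, which is exactly what the added assert_allclose loops verify. Passing map_location='cpu' to torch.load keeps the comparisons device-agnostic when the tests run on GPU.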