# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp

import torch
import torch.nn as nn

from mmengine.config import ConfigDict
from mmengine.hooks import EMAHook
from mmengine.model import BaseModel, ExponentialMovingAverage
from mmengine.registry import MODELS
from mmengine.testing import RunnerTestCase, assert_allclose
from mmengine.testing.runner_test_case import ToyModel


class DummyWrapper(BaseModel):
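    """Wrapper model used to test ``EMAHook`` together with model wrappers.

    The inner model is built from a config dict when necessary and exposed
    as ``self.module``.
    """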

    def __init__(self, model):
        super().__init__()
        if not isinstance(model, nn.Module):
            model = MODELS.build(model)
        self.module = model

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class ToyModel2(ToyModel):
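    """``ToyModel`` with an extra ``linear3`` layer.

    Used to test non-strict checkpoint loading when parameter names differ.
    """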

    def __init__(self):
        super().__init__()
        self.linear3 = nn.Linear(2, 1)

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


class ToyModel3(ToyModel):
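    """``ToyModel`` with ``linear2`` replaced by a two-layer ``Sequential``.

    Used to test non-strict checkpoint loading when weight shapes differ.
    """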

    def __init__(self):
        super().__init__()
        self.linear2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 1))

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


class TestEMAHook(RunnerTestCase):
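    """Unit tests for ``EMAHook``.

    Covers argument validation, EMA updates during training, parameter
    swapping around validation/testing, and checkpoint save/load/resume
    behaviour.
    """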

    def setUp(self) -> None:
        MODELS.register_module(name='DummyWrapper', module=DummyWrapper)
        MODELS.register_module(name='ToyModel2', module=ToyModel2)
        MODELS.register_module(name='ToyModel3', module=ToyModel3)
        return super().setUp()

    def tearDown(self):
        MODELS.module_dict.pop('DummyWrapper')
        MODELS.module_dict.pop('ToyModel2')
        MODELS.module_dict.pop('ToyModel3')
        return super().tearDown()

    def test_init(self):
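        """``begin_iter`` and ``begin_epoch`` must be non-negative and
        mutually exclusive."""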
        EMAHook()

        with self.assertRaisesRegex(AssertionError, '`begin_iter` must'):
            EMAHook(begin_iter=-1)

        with self.assertRaisesRegex(AssertionError, '`begin_epoch` must'):
            EMAHook(begin_epoch=-1)

        with self.assertRaisesRegex(AssertionError,
                                    '`begin_iter` and `begin_epoch`'):
            EMAHook(begin_iter=1, begin_epoch=1)

    def _get_ema_hook(self, runner):
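        """Return the ``EMAHook`` registered on ``runner``, if any."""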
        for hook in runner.hooks:
            if isinstance(hook, EMAHook):
                return hook

    def test_before_run(self):
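        """``before_run`` should build the EMA model and keep a reference to
        the source model."""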
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.custom_hooks = [dict(type='EMAHook')]
        runner = self.build_runner(cfg)
        ema_hook = self._get_ema_hook(runner)
        ema_hook.before_run(runner)
        self.assertIsInstance(ema_hook.ema_model, ExponentialMovingAverage)
        self.assertIs(ema_hook.src_model, runner.model)

    def test_before_train(self):
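        """``begin_epoch``/``begin_iter`` must not exceed the length of the
        training schedule."""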
        # A `begin_epoch` within the training schedule is accepted.
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.custom_hooks = [
            dict(type='EMAHook', begin_epoch=cfg.train_cfg.max_epochs - 1)
        ]
        runner = self.build_runner(cfg)
        ema_hook = self._get_ema_hook(runner)
        ema_hook.before_train(runner)

        # A `begin_epoch` larger than `max_epochs` should be rejected.
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.custom_hooks = [
            dict(type='EMAHook', begin_epoch=cfg.train_cfg.max_epochs + 1)
        ]
        runner = self.build_runner(cfg)
        ema_hook = self._get_ema_hook(runner)

        with self.assertRaisesRegex(AssertionError, 'self.begin_epoch'):
            ema_hook.before_train(runner)

        # A `begin_iter` larger than `max_iters` should be rejected.
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.custom_hooks = [
            dict(type='EMAHook', begin_iter=cfg.train_cfg.max_iters + 1)
        ]
        runner = self.build_runner(cfg)
        ema_hook = self._get_ema_hook(runner)

        with self.assertRaisesRegex(AssertionError, 'self.begin_iter'):
            ema_hook.before_train(runner)

    def test_after_train_iter(self):
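        """The first EMA update should copy the source weights; later
        updates should yield averaged weights that differ from the source."""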
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.custom_hooks = [dict(type='EMAHook')]
        runner = self.build_runner(cfg)
        ema_hook = self._get_ema_hook(runner)
        ema_hook.before_run(runner)
        ema_hook.before_train(runner)

        src_model = runner.model
        ema_model = ema_hook.ema_model

        # The first EMA update should copy the source parameters exactly.
        with torch.no_grad():
            for parameter in src_model.parameters():
                parameter.data.copy_(torch.randn(parameter.shape))

        ema_hook.after_train_iter(runner, 1)
        for src, ema in zip(src_model.parameters(), ema_model.parameters()):
            assert_allclose(src.data, ema.data)

        # Subsequent updates average the parameters, so the EMA weights
        # should now differ from the freshly randomized source weights.
        with torch.no_grad():
            for parameter in src_model.parameters():
                parameter.data.copy_(torch.randn(parameter.shape))

        ema_hook.after_train_iter(runner, 1)

        for src, ema in zip(src_model.parameters(), ema_model.parameters()):
            self.assertFalse((src.data == ema.data).all())

    def test_before_val_epoch(self):
        self._test_swap_parameters('before_val_epoch')

    def test_after_val_epoch(self):
        self._test_swap_parameters('after_val_epoch')

    def test_before_test_epoch(self):
        self._test_swap_parameters('before_test_epoch')

    def test_after_test_epoch(self):
        self._test_swap_parameters('after_test_epoch')

    def test_before_save_checkpoint(self):
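        """Before saving, the EMA weights should be written to ``state_dict``
        and the source weights to ``ema_state_dict``."""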
        cfg = copy.deepcopy(self.epoch_based_cfg)
        runner = self.build_runner(cfg)
        checkpoint = dict(state_dict=ToyModel().state_dict())
        ema_hook = EMAHook()
        ema_hook.before_run(runner)
        ema_hook.before_train(runner)

        ori_checkpoint = copy.deepcopy(checkpoint)
        ema_hook.before_save_checkpoint(runner, checkpoint)

        for key in ori_checkpoint['state_dict'].keys():
            assert_allclose(
                ori_checkpoint['state_dict'][key].cpu(),
                checkpoint['ema_state_dict'][f'module.{key}'].cpu())

            assert_allclose(
                ema_hook.ema_model.state_dict()[f'module.{key}'].cpu(),
                checkpoint['state_dict'][key].cpu())

    def test_after_load_checkpoint(self):
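        """Loading a checkpoint without ``ema_state_dict`` should initialize
        the EMA model from ``state_dict``; resuming without it should warn,
        and resuming with it should restore the swapped weights."""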
        # Test loading a checkpoint without `ema_state_dict`.
        cfg = copy.deepcopy(self.epoch_based_cfg)
        runner = self.build_runner(cfg)
        checkpoint = dict(state_dict=ToyModel().state_dict())
        ema_hook = EMAHook()
        ema_hook.before_run(runner)
        ema_hook.before_train(runner)
        ema_hook.after_load_checkpoint(runner, checkpoint)

        for key in checkpoint['state_dict'].keys():
            assert_allclose(
                checkpoint['state_dict'][key].cpu(),
                ema_hook.ema_model.state_dict()[f'module.{key}'].cpu())

        # A warning should be raised when resuming from a checkpoint
        # without `ema_state_dict`.
        runner._resume = True
        ema_hook.after_load_checkpoint(runner, checkpoint)
        with self.assertLogs(runner.logger, level='WARNING') as cm:
            ema_hook.after_load_checkpoint(runner, checkpoint)
            self.assertRegex(cm.records[0].msg, 'There is no `ema_state_dict`')

        # Check that the weights in `state_dict` and `ema_state_dict` have
        # been swapped back when `runner._resume` is True.
        runner._resume = True
        checkpoint = dict(
            state_dict=ToyModel().state_dict(),
            ema_state_dict=ExponentialMovingAverage(ToyModel()).state_dict())
        ori_checkpoint = copy.deepcopy(checkpoint)
        ema_hook.after_load_checkpoint(runner, checkpoint)
        for key in ori_checkpoint['state_dict'].keys():
            assert_allclose(
                ori_checkpoint['state_dict'][key].cpu(),
                ema_hook.ema_model.state_dict()[f'module.{key}'].cpu())

        # Loading (rather than resuming) a checkpoint that contains
        # `ema_state_dict` should not raise.
        runner._resume = False
        ema_hook.after_load_checkpoint(runner, checkpoint)

    def test_with_runner(self):
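        """End-to-end run with ``EMAHook``: checkpoints must contain
        ``ema_state_dict``, non-strict loading must tolerate mismatched keys
        and shapes, and EMA can be enabled part-way through training."""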
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.custom_hooks = [ConfigDict(type='EMAHook')]
        runner = self.build_runner(cfg)
        ema_hook = self._get_ema_hook(runner)
        runner.train()
        self.assertIsInstance(ema_hook.ema_model, ExponentialMovingAverage)

        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
        self.assertIn('ema_state_dict', checkpoint)
        self.assertEqual(checkpoint['ema_state_dict']['steps'], 8)

        # Load the saved checkpoint and run testing.
        cfg.load_from = osp.join(self.temp_dir.name, 'epoch_2.pth')
        runner = self.build_runner(cfg)
        runner.test()

        # Test with a wrapped model.
        cfg.model = ConfigDict(type='DummyWrapper', model=cfg.model)
        runner = self.build_runner(cfg)
        runner.test()

        # Test loading a checkpoint without `ema_state_dict`.
        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
        checkpoint.pop('ema_state_dict')
        torch.save(checkpoint,
                   osp.join(self.temp_dir.name, 'without_ema_state_dict.pth'))

        cfg.load_from = osp.join(self.temp_dir.name,
                                 'without_ema_state_dict.pth')
        runner = self.build_runner(cfg)
        runner.test()

        # Test non-strict loading of a checkpoint without `ema_state_dict`
        # when the model has extra parameter names.
        cfg.model = ConfigDict(type='ToyModel2')
        cfg.custom_hooks = [ConfigDict(type='EMAHook', strict_load=False)]
        runner = self.build_runner(cfg)
        runner.test()

        # Test non-strict loading of a checkpoint without `ema_state_dict`
        # when the model has mismatched weight shapes.
        cfg.model = ConfigDict(type='ToyModel3')
        runner = self.build_runner(cfg)
        runner.test()

        # Test enabling EMA at epoch 5. Before that epoch, the saved EMA
        # weights should equal the source weights.
        cfg.train_cfg.max_epochs = 10
        cfg.custom_hooks = [ConfigDict(type='EMAHook', begin_epoch=5)]
        runner = self.build_runner(cfg)
        runner.train()
        state_dict = torch.load(
            osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
        self.assertIn('ema_state_dict', state_dict)
        for k, v in state_dict['state_dict'].items():
            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])

        # Test enabling EMA at iteration 5.
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.train_cfg.val_interval = 1
        cfg.custom_hooks = [ConfigDict(type='EMAHook', begin_iter=5)]
        cfg.default_hooks.checkpoint.interval = 1
        runner = self.build_runner(cfg)
        runner.train()
        state_dict = torch.load(
            osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
        self.assertIn('ema_state_dict', state_dict)
        for k, v in state_dict['state_dict'].items():
            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
        state_dict = torch.load(
            osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
        self.assertIn('ema_state_dict', state_dict)

    def _test_swap_parameters(self, func_name, *args, **kwargs):
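        """Call ``func_name`` on the hook and check that the source and EMA
        parameters have been swapped."""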
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.custom_hooks = [dict(type='EMAHook')]
        runner = self.build_runner(cfg)
        ema_hook = self._get_ema_hook(runner)

        runner.train()

        with torch.no_grad():
            for parameter in ema_hook.src_model.parameters():
                parameter.data.copy_(torch.randn(parameter.shape))

        # Snapshot both models before calling the hook method under test.
        src_model = copy.deepcopy(runner.model)
        ema_model = copy.deepcopy(ema_hook.ema_model)

        func = getattr(ema_hook, func_name)
        func(runner, *args, **kwargs)

        swapped_src = ema_hook.src_model
        swapped_ema = ema_hook.ema_model

        # After the call, the source and EMA parameters should be swapped.
        for src, ema, cur_src, cur_ema in zip(
                src_model.parameters(), ema_model.parameters(),
                swapped_src.parameters(), swapped_ema.parameters()):
            self.assertTrue((src.data == cur_ema.data).all())
            self.assertTrue((ema.data == cur_src.data).all())