mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Refactor ema hook (#804)
* Refacot ema hook unit test * Refacot ema hook unit test * Enhance test_after_load_checkpoint * refine error messsage * Refine as comment --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Fix unit test
This commit is contained in:
parent
aa69ba1a86
commit
b14c179fad
@ -50,9 +50,11 @@ class EMAHook(Hook):
|
||||
assert not (begin_iter != 0 and begin_epoch != 0), (
|
||||
'`begin_iter` and `begin_epoch` should not be both set.')
|
||||
assert begin_iter >= 0, (
|
||||
f'begin_iter must larger than 0, but got begin: {begin_iter}')
|
||||
'`begin_iter` must larger than or equal to 0, '
|
||||
f'but got begin_iter: {begin_iter}')
|
||||
assert begin_epoch >= 0, (
|
||||
f'begin_epoch must larger than 0, but got begin: {begin_epoch}')
|
||||
'`begin_epoch` must larger than or equal to 0, '
|
||||
f'but got begin_epoch: {begin_epoch}')
|
||||
self.begin_iter = begin_iter
|
||||
self.begin_epoch = begin_epoch
|
||||
# If `begin_epoch` and `begin_iter` are not set, `EMAHook` will be
|
||||
@ -80,12 +82,14 @@ class EMAHook(Hook):
|
||||
"""
|
||||
if self.enabled_by_epoch:
|
||||
assert self.begin_epoch <= runner.max_epochs, (
|
||||
'self.begin_epoch should be smaller than runner.max_epochs: '
|
||||
f'{runner.max_epochs}, but got begin: {self.begin_epoch}')
|
||||
'self.begin_epoch should be smaller than or equal to '
|
||||
f'runner.max_epochs: {runner.max_epochs}, but got '
|
||||
f'begin_epoch: {self.begin_epoch}')
|
||||
else:
|
||||
assert self.begin_iter <= runner.max_iters, (
|
||||
'self.begin_iter should be smaller than runner.max_iters: '
|
||||
f'{runner.max_iters}, but got begin: {self.begin_iter}')
|
||||
'self.begin_iter should be smaller than or equal to '
|
||||
f'runner.max_iters: {runner.max_iters}, but got '
|
||||
f'begin_iter: {self.begin_iter}')
|
||||
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
|
@ -1,58 +1,35 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import copy
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import Mock
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.config import ConfigDict
|
||||
from mmengine.hooks import EMAHook
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.model import BaseModel, ExponentialMovingAverage
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.registry import DATASETS, MODEL_WRAPPERS
|
||||
from mmengine.runner import Runner
|
||||
from mmengine.testing import assert_allclose
|
||||
from mmengine.registry import MODELS
|
||||
from mmengine.testing import RunnerTestCase, assert_allclose
|
||||
from mmengine.testing.runner_test_case import ToyModel
|
||||
|
||||
|
||||
class ToyModel(BaseModel):
|
||||
class DummyWrapper(BaseModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, inputs, data_sample, mode='tensor'):
|
||||
labels = torch.stack(data_sample)
|
||||
inputs = torch.stack(inputs)
|
||||
outputs = self.linear(inputs)
|
||||
if mode == 'tensor':
|
||||
return outputs
|
||||
elif mode == 'loss':
|
||||
loss = (labels - outputs).sum()
|
||||
outputs = dict(loss=loss)
|
||||
return outputs
|
||||
else:
|
||||
return outputs
|
||||
|
||||
|
||||
class ToyModel1(ToyModel):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
if not isinstance(model, nn.Module):
|
||||
model = MODELS.build(model)
|
||||
self.module = model
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return super().forward(*args, **kwargs)
|
||||
return self.module(*args, **kwargs)
|
||||
|
||||
|
||||
class ToyModel2(ToyModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(2, 1)
|
||||
self.linear3 = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -62,239 +39,247 @@ class ToyModel3(ToyModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(2, 2)
|
||||
self.linear2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 1))
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class DummyDataset(Dataset):
|
||||
METAINFO = dict() # type: ignore
|
||||
data = torch.randn(12, 2)
|
||||
label = torch.ones(12)
|
||||
class TestEMAHook(RunnerTestCase):
|
||||
|
||||
@property
|
||||
def metainfo(self):
|
||||
return self.METAINFO
|
||||
|
||||
def __len__(self):
|
||||
return self.data.size(0)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||
|
||||
|
||||
class TestEMAHook(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
def setUp(self) -> None:
|
||||
MODELS.register_module(name='DummyWrapper', module=DummyWrapper)
|
||||
MODELS.register_module(name='ToyModel2', module=ToyModel2)
|
||||
MODELS.register_module(name='ToyModel3', module=ToyModel3)
|
||||
return super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
# `FileHandler` should be closed in Windows, otherwise we cannot
|
||||
# delete the temporary directory
|
||||
logging.shutdown()
|
||||
MMLogger._instance_dict.clear()
|
||||
self.temp_dir.cleanup()
|
||||
MODELS.module_dict.pop('DummyWrapper')
|
||||
MODELS.module_dict.pop('ToyModel2')
|
||||
MODELS.module_dict.pop('ToyModel3')
|
||||
return super().tearDown()
|
||||
|
||||
def test_ema_hook(self):
|
||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
model = ToyModel1().to(device)
|
||||
evaluator = Evaluator([])
|
||||
evaluator.evaluate = Mock(return_value=dict(acc=0.5))
|
||||
runner = Runner(
|
||||
model=model,
|
||||
train_dataloader=dict(
|
||||
dataset=dict(type='DummyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_dataloader=dict(
|
||||
dataset=dict(type='DummyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_evaluator=evaluator,
|
||||
work_dir=self.temp_dir.name,
|
||||
optim_wrapper=OptimWrapper(
|
||||
torch.optim.Adam(ToyModel().parameters())),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
|
||||
val_cfg=dict(),
|
||||
default_hooks=dict(logger=None),
|
||||
custom_hooks=[dict(type='EMAHook', )],
|
||||
experiment_name='test1')
|
||||
runner.train()
|
||||
def test_init(self):
|
||||
EMAHook()
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, '`begin_iter` must'):
|
||||
EMAHook(begin_iter=-1)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, '`begin_epoch` must'):
|
||||
EMAHook(begin_epoch=-1)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'`begin_iter` and `begin_epoch`'):
|
||||
EMAHook(begin_iter=1, begin_epoch=1)
|
||||
|
||||
def _get_ema_hook(self, runner):
|
||||
for hook in runner.hooks:
|
||||
if isinstance(hook, EMAHook):
|
||||
self.assertTrue(
|
||||
isinstance(hook.ema_model, ExponentialMovingAverage))
|
||||
return hook
|
||||
|
||||
def test_before_run(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.custom_hooks = [dict(type='EMAHook')]
|
||||
runner = self.build_runner(cfg)
|
||||
ema_hook = self._get_ema_hook(runner)
|
||||
ema_hook.before_run(runner)
|
||||
self.assertIsInstance(ema_hook.ema_model, ExponentialMovingAverage)
|
||||
self.assertIs(ema_hook.src_model, runner.model)
|
||||
|
||||
def test_before_train(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.custom_hooks = [
|
||||
dict(type='EMAHook', begin_epoch=cfg.train_cfg.max_epochs - 1)
|
||||
]
|
||||
runner = self.build_runner(cfg)
|
||||
ema_hook = self._get_ema_hook(runner)
|
||||
ema_hook.before_train(runner)
|
||||
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.custom_hooks = [
|
||||
dict(type='EMAHook', begin_epoch=cfg.train_cfg.max_epochs + 1)
|
||||
]
|
||||
runner = self.build_runner(cfg)
|
||||
ema_hook = self._get_ema_hook(runner)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, 'self.begin_epoch'):
|
||||
ema_hook.before_train(runner)
|
||||
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.custom_hooks = [
|
||||
dict(type='EMAHook', begin_iter=cfg.train_cfg.max_iters + 1)
|
||||
]
|
||||
runner = self.build_runner(cfg)
|
||||
ema_hook = self._get_ema_hook(runner)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, 'self.begin_iter'):
|
||||
ema_hook.before_train(runner)
|
||||
|
||||
def test_after_train_iter(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.custom_hooks = [dict(type='EMAHook')]
|
||||
runner = self.build_runner(cfg)
|
||||
ema_hook = self._get_ema_hook(runner)
|
||||
|
||||
ema_hook = self._get_ema_hook(runner)
|
||||
ema_hook.before_run(runner)
|
||||
ema_hook.before_train(runner)
|
||||
|
||||
src_model = runner.model
|
||||
ema_model = ema_hook.ema_model
|
||||
|
||||
with torch.no_grad():
|
||||
for parameter in src_model.parameters():
|
||||
parameter.data.copy_(torch.randn(parameter.shape))
|
||||
|
||||
ema_hook.after_train_iter(runner, 1)
|
||||
for src, ema in zip(src_model.parameters(), ema_model.parameters()):
|
||||
assert_allclose(src.data, ema.data)
|
||||
|
||||
with torch.no_grad():
|
||||
for parameter in src_model.parameters():
|
||||
parameter.data.copy_(torch.randn(parameter.shape))
|
||||
|
||||
ema_hook.after_train_iter(runner, 1)
|
||||
|
||||
for src, ema in zip(src_model.parameters(), ema_model.parameters()):
|
||||
self.assertFalse((src.data == ema.data).all())
|
||||
|
||||
def test_before_val_epoch(self):
|
||||
self._test_swap_parameters('before_val_epoch')
|
||||
|
||||
def test_after_val_epoch(self):
|
||||
self._test_swap_parameters('after_val_epoch')
|
||||
|
||||
def test_before_test_epoch(self):
|
||||
self._test_swap_parameters('before_test_epoch')
|
||||
|
||||
def test_after_test_epoch(self):
|
||||
self._test_swap_parameters('after_test_epoch')
|
||||
|
||||
def test_before_save_checkpoint(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
checkpoint = dict(state_dict=ToyModel().state_dict())
|
||||
ema_hook = EMAHook()
|
||||
ema_hook.before_run(runner)
|
||||
ema_hook.before_train(runner)
|
||||
|
||||
ori_checkpoint = copy.deepcopy(checkpoint)
|
||||
ema_hook.before_save_checkpoint(runner, checkpoint)
|
||||
|
||||
for key in ori_checkpoint['state_dict'].keys():
|
||||
assert_allclose(
|
||||
ori_checkpoint['state_dict'][key].cpu(),
|
||||
checkpoint['ema_state_dict'][f'module.{key}'].cpu())
|
||||
|
||||
assert_allclose(
|
||||
ema_hook.ema_model.state_dict()[f'module.{key}'].cpu(),
|
||||
checkpoint['state_dict'][key].cpu())
|
||||
|
||||
def test_after_load_checkpoint(self):
|
||||
# Test load a checkpoint without ema_state_dict.
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
checkpoint = dict(state_dict=ToyModel().state_dict())
|
||||
ema_hook = EMAHook()
|
||||
ema_hook.before_run(runner)
|
||||
ema_hook.before_train(runner)
|
||||
ema_hook.after_load_checkpoint(runner, checkpoint)
|
||||
|
||||
for key in checkpoint['state_dict'].keys():
|
||||
assert_allclose(
|
||||
checkpoint['state_dict'][key].cpu(),
|
||||
ema_hook.ema_model.state_dict()[f'module.{key}'].cpu())
|
||||
|
||||
# Test a warning should be raised when resuming from a checkpoint
|
||||
# without `ema_state_dict`
|
||||
runner._resume = True
|
||||
ema_hook.after_load_checkpoint(runner, checkpoint)
|
||||
with self.assertLogs(runner.logger, level='WARNING') as cm:
|
||||
ema_hook.after_load_checkpoint(runner, checkpoint)
|
||||
self.assertRegex(cm.records[0].msg, 'There is no `ema_state_dict`')
|
||||
|
||||
# Check the weight of state_dict and ema_state_dict have been swapped.
|
||||
# when runner._resume is True
|
||||
runner._resume = True
|
||||
checkpoint = dict(
|
||||
state_dict=ToyModel().state_dict(),
|
||||
ema_state_dict=ExponentialMovingAverage(ToyModel()).state_dict())
|
||||
ori_checkpoint = copy.deepcopy(checkpoint)
|
||||
ema_hook.after_load_checkpoint(runner, checkpoint)
|
||||
for key in ori_checkpoint['state_dict'].keys():
|
||||
assert_allclose(
|
||||
ori_checkpoint['state_dict'][key].cpu(),
|
||||
ema_hook.ema_model.state_dict()[f'module.{key}'].cpu())
|
||||
|
||||
runner._resume = False
|
||||
ema_hook.after_load_checkpoint(runner, checkpoint)
|
||||
|
||||
def test_with_runner(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.custom_hooks = [ConfigDict(type='EMAHook')]
|
||||
runner = self.build_runner(cfg)
|
||||
ema_hook = self._get_ema_hook(runner)
|
||||
runner.train()
|
||||
self.assertTrue(
|
||||
osp.exists(osp.join(self.temp_dir.name, 'epoch_2.pth')))
|
||||
isinstance(ema_hook.ema_model, ExponentialMovingAverage))
|
||||
|
||||
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
|
||||
self.assertTrue('ema_state_dict' in checkpoint)
|
||||
self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)
|
||||
|
||||
# load and testing
|
||||
runner = Runner(
|
||||
model=model,
|
||||
test_dataloader=dict(
|
||||
dataset=dict(type='DummyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
test_evaluator=evaluator,
|
||||
test_cfg=dict(),
|
||||
work_dir=self.temp_dir.name,
|
||||
load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
|
||||
default_hooks=dict(logger=None),
|
||||
custom_hooks=[dict(type='EMAHook')],
|
||||
experiment_name='test2')
|
||||
cfg.load_from = osp.join(self.temp_dir.name, 'epoch_2.pth')
|
||||
runner = self.build_runner(cfg)
|
||||
runner.test()
|
||||
|
||||
@MODEL_WRAPPERS.register_module()
|
||||
class DummyWrapper(BaseModel):
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.module = model
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.module(*args, **kwargs)
|
||||
|
||||
# with model wrapper
|
||||
runner = Runner(
|
||||
model=DummyWrapper(ToyModel()),
|
||||
test_dataloader=dict(
|
||||
dataset=dict(type='DummyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
test_evaluator=evaluator,
|
||||
test_cfg=dict(),
|
||||
work_dir=self.temp_dir.name,
|
||||
load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
|
||||
default_hooks=dict(logger=None),
|
||||
custom_hooks=[dict(type='EMAHook')],
|
||||
experiment_name='test3')
|
||||
cfg.model = ConfigDict(type='DummyWrapper', model=cfg.model)
|
||||
runner = self.build_runner(cfg)
|
||||
runner.test()
|
||||
|
||||
# Test load checkpoint without ema_state_dict
|
||||
ckpt = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
|
||||
ckpt.pop('ema_state_dict')
|
||||
torch.save(ckpt,
|
||||
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
|
||||
checkpoint.pop('ema_state_dict')
|
||||
torch.save(checkpoint,
|
||||
osp.join(self.temp_dir.name, 'without_ema_state_dict.pth'))
|
||||
runner = Runner(
|
||||
model=DummyWrapper(ToyModel()),
|
||||
test_dataloader=dict(
|
||||
dataset=dict(type='DummyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
test_evaluator=evaluator,
|
||||
test_cfg=dict(),
|
||||
work_dir=self.temp_dir.name,
|
||||
load_from=osp.join(self.temp_dir.name,
|
||||
'without_ema_state_dict.pth'),
|
||||
default_hooks=dict(logger=None),
|
||||
custom_hooks=[dict(type='EMAHook')],
|
||||
experiment_name='test4')
|
||||
|
||||
cfg.load_from = osp.join(self.temp_dir.name,
|
||||
'without_ema_state_dict.pth')
|
||||
runner = self.build_runner(cfg)
|
||||
runner.test()
|
||||
|
||||
# Test does not load ckpt strict_loadly.
|
||||
# Test does not load checkpoint strictly (different name).
|
||||
# Test load checkpoint without ema_state_dict
|
||||
runner = Runner(
|
||||
model=ToyModel2(),
|
||||
test_dataloader=dict(
|
||||
dataset=dict(type='DummyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
test_evaluator=evaluator,
|
||||
test_cfg=dict(),
|
||||
work_dir=self.temp_dir.name,
|
||||
load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
|
||||
default_hooks=dict(logger=None),
|
||||
custom_hooks=[dict(type='EMAHook', strict_load=False)],
|
||||
experiment_name='test5')
|
||||
cfg.model = ConfigDict(type='ToyModel2')
|
||||
cfg.custom_hooks = [ConfigDict(type='EMAHook', strict_load=False)]
|
||||
runner = self.build_runner(cfg)
|
||||
runner.test()
|
||||
|
||||
# Test does not load ckpt strict_loadly.
|
||||
# Test does not load ckpt strictly (different weight size).
|
||||
# Test load checkpoint without ema_state_dict
|
||||
# Test with different size head.
|
||||
runner = Runner(
|
||||
model=ToyModel3(),
|
||||
test_dataloader=dict(
|
||||
dataset=dict(type='DummyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
test_evaluator=evaluator,
|
||||
test_cfg=dict(),
|
||||
work_dir=self.temp_dir.name,
|
||||
load_from=osp.join(self.temp_dir.name,
|
||||
'without_ema_state_dict.pth'),
|
||||
default_hooks=dict(logger=None),
|
||||
custom_hooks=[dict(type='EMAHook', strict_load=False)],
|
||||
experiment_name='test5.1')
|
||||
cfg.model = ConfigDict(type='ToyModel3')
|
||||
runner = self.build_runner(cfg)
|
||||
runner.test()
|
||||
|
||||
# Test enable ema at 5 epochs.
|
||||
runner = Runner(
|
||||
model=model,
|
||||
train_dataloader=dict(
|
||||
dataset=dict(type='DummyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_dataloader=dict(
|
||||
dataset=dict(type='DummyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_evaluator=evaluator,
|
||||
work_dir=self.temp_dir.name,
|
||||
optim_wrapper=OptimWrapper(
|
||||
torch.optim.Adam(ToyModel().parameters())),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
|
||||
val_cfg=dict(),
|
||||
default_hooks=dict(logger=None),
|
||||
custom_hooks=[dict(type='EMAHook', begin_epoch=5)],
|
||||
experiment_name='test6')
|
||||
cfg.train_cfg.max_epochs = 10
|
||||
cfg.custom_hooks = [ConfigDict(type='EMAHook', begin_epoch=5)]
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
state_dict = torch.load(
|
||||
osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
|
||||
self.assertIn('ema_state_dict', state_dict)
|
||||
for k, v in state_dict['state_dict'].items():
|
||||
assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
|
||||
state_dict = torch.load(
|
||||
osp.join(self.temp_dir.name, 'epoch_5.pth'), map_location='cpu')
|
||||
self.assertIn('ema_state_dict', state_dict)
|
||||
|
||||
# Test enable ema at 5 iterations.
|
||||
runner = Runner(
|
||||
model=model,
|
||||
train_dataloader=dict(
|
||||
dataset=dict(type='DummyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_dataloader=dict(
|
||||
dataset=dict(type='DummyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_evaluator=evaluator,
|
||||
work_dir=self.temp_dir.name,
|
||||
optim_wrapper=OptimWrapper(
|
||||
torch.optim.Adam(ToyModel().parameters())),
|
||||
train_cfg=dict(by_epoch=False, max_iters=10, val_interval=1),
|
||||
val_cfg=dict(),
|
||||
default_hooks=dict(
|
||||
checkpoint=dict(
|
||||
type='CheckpointHook', interval=1, by_epoch=False)),
|
||||
custom_hooks=[dict(type='EMAHook', begin_iter=5)],
|
||||
experiment_name='test7')
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.train_cfg.val_interval = 1
|
||||
cfg.custom_hooks = [ConfigDict(type='EMAHook', begin_iter=5)]
|
||||
cfg.default_hooks.checkpoint.interval = 1
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
state_dict = torch.load(
|
||||
osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
|
||||
@ -304,3 +289,30 @@ class TestEMAHook(TestCase):
|
||||
state_dict = torch.load(
|
||||
osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
|
||||
self.assertIn('ema_state_dict', state_dict)
|
||||
|
||||
def _test_swap_parameters(self, func_name, *args, **kwargs):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.custom_hooks = [dict(type='EMAHook')]
|
||||
runner = self.build_runner(cfg)
|
||||
ema_hook = self._get_ema_hook(runner)
|
||||
|
||||
runner.train()
|
||||
|
||||
with torch.no_grad():
|
||||
for parameter in ema_hook.src_model.parameters():
|
||||
parameter.data.copy_(torch.randn(parameter.shape))
|
||||
|
||||
src_model = copy.deepcopy(runner.model)
|
||||
ema_model = copy.deepcopy(ema_hook.ema_model)
|
||||
|
||||
func = getattr(ema_hook, func_name)
|
||||
func(runner, *args, **kwargs)
|
||||
|
||||
swapped_src = ema_hook.src_model
|
||||
swapped_ema = ema_hook.ema_model
|
||||
|
||||
for src, ema, swapped_src, swapped_ema in zip(
|
||||
src_model.parameters(), ema_model.parameters(),
|
||||
swapped_src.parameters(), swapped_ema.parameters()):
|
||||
self.assertTrue((src.data == swapped_ema.data).all())
|
||||
self.assertTrue((ema.data == swapped_src.data).all())
|
||||
|
@ -7,21 +7,37 @@ from torch.optim import SGD
|
||||
|
||||
from mmengine.hooks import RuntimeInfoHook
|
||||
from mmengine.optim import OptimWrapper, OptimWrapperDict
|
||||
from mmengine.registry import DATASETS
|
||||
from mmengine.testing import RunnerTestCase
|
||||
|
||||
|
||||
class TestRuntimeInfoHook(RunnerTestCase):
|
||||
|
||||
def test_before_train(self):
|
||||
|
||||
class DatasetWithoutMetainfo:
|
||||
...
|
||||
|
||||
def __len__(self):
|
||||
return 12
|
||||
|
||||
|
||||
class DatasetWithMetainfo(DatasetWithoutMetainfo):
|
||||
metainfo: dict = dict()
|
||||
|
||||
|
||||
class TestRuntimeInfoHook(RunnerTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
DATASETS.register_module(module=DatasetWithoutMetainfo, force=True)
|
||||
DATASETS.register_module(module=DatasetWithMetainfo, force=True)
|
||||
return super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
DATASETS.module_dict.pop('DatasetWithoutMetainfo')
|
||||
DATASETS.module_dict.pop('DatasetWithMetainfo')
|
||||
return super().tearDown()
|
||||
|
||||
def test_before_train(self):
|
||||
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.train_dataloader.dataset.type = DatasetWithoutMetainfo
|
||||
cfg.train_dataloader.dataset.type = 'DatasetWithoutMetainfo'
|
||||
runner = self.build_runner(cfg)
|
||||
hook = self._get_runtime_info_hook(runner)
|
||||
hook.before_train(runner)
|
||||
@ -33,10 +49,7 @@ class TestRuntimeInfoHook(RunnerTestCase):
|
||||
with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
|
||||
runner.message_hub.get_info('dataset_meta')
|
||||
|
||||
class DatasetWithMetainfo(DatasetWithoutMetainfo):
|
||||
metainfo = dict()
|
||||
|
||||
cfg.train_dataloader.dataset.type = DatasetWithMetainfo
|
||||
cfg.train_dataloader.dataset.type = 'DatasetWithMetainfo'
|
||||
runner = self.build_runner(cfg)
|
||||
hook.before_train(runner)
|
||||
self.assertEqual(runner.message_hub.get_info('dataset_meta'), dict())
|
||||
|
Loading…
x
Reference in New Issue
Block a user