From b14c179fad6bc8ff4867692fbd79374f90c53000 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Tue, 21 Feb 2023 20:45:11 +0800
Subject: [PATCH] [Refactor] Refactor ema hook (#804)

* Refactor ema hook unit test

* Refactor ema hook unit test

* Enhance test_after_load_checkpoint

* Refine error message

* Refine as comment

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Fix unit test
---
 mmengine/hooks/ema_hook.py                 |  16 +-
 tests/test_hooks/test_ema_hook.py          | 464 +++++++++++----------
 tests/test_hooks/test_runtime_info_hook.py |  35 +-
 3 files changed, 272 insertions(+), 243 deletions(-)

diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py
index bad7d8f8..5bc1051d 100644
--- a/mmengine/hooks/ema_hook.py
+++ b/mmengine/hooks/ema_hook.py
@@ -50,9 +50,11 @@ class EMAHook(Hook):
         assert not (begin_iter != 0 and begin_epoch != 0), (
             '`begin_iter` and `begin_epoch` should not be both set.')
         assert begin_iter >= 0, (
-            f'begin_iter must larger than 0, but got begin: {begin_iter}')
+            '`begin_iter` must be larger than or equal to 0, '
+            f'but got begin_iter: {begin_iter}')
         assert begin_epoch >= 0, (
-            f'begin_epoch must larger than 0, but got begin: {begin_epoch}')
+            '`begin_epoch` must be larger than or equal to 0, '
+            f'but got begin_epoch: {begin_epoch}')
         self.begin_iter = begin_iter
         self.begin_epoch = begin_epoch
         # If `begin_epoch` and `begin_iter` are not set, `EMAHook` will be
@@ -80,12 +82,14 @@
         """
         if self.enabled_by_epoch:
             assert self.begin_epoch <= runner.max_epochs, (
-                'self.begin_epoch should be smaller than runner.max_epochs: '
-                f'{runner.max_epochs}, but got begin: {self.begin_epoch}')
+                'self.begin_epoch should be smaller than or equal to '
+                f'runner.max_epochs: {runner.max_epochs}, but got '
+                f'begin_epoch: {self.begin_epoch}')
         else:
             assert self.begin_iter <= runner.max_iters, (
-                'self.begin_iter should be smaller than runner.max_iters: '
-                f'{runner.max_iters}, but got begin: {self.begin_iter}')
+                'self.begin_iter should be smaller than or equal to '
+                f'runner.max_iters: {runner.max_iters}, but got '
+                f'begin_iter: {self.begin_iter}')
 
     def after_train_iter(self,
                          runner,
diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py
index dd7fc91a..4ceebe90 100644
--- a/tests/test_hooks/test_ema_hook.py
+++ b/tests/test_hooks/test_ema_hook.py
@@ -1,58 +1,35 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import logging
+import copy
 import os.path as osp
-import tempfile
-from unittest import TestCase
-from unittest.mock import Mock
 
 import torch
 import torch.nn as nn
-from torch.utils.data import Dataset
 
-from mmengine.evaluator import Evaluator
+from mmengine.config import ConfigDict
 from mmengine.hooks import EMAHook
-from mmengine.logging import MMLogger
 from mmengine.model import BaseModel, ExponentialMovingAverage
-from mmengine.optim import OptimWrapper
-from mmengine.registry import DATASETS, MODEL_WRAPPERS
-from mmengine.runner import Runner
-from mmengine.testing import assert_allclose
+from mmengine.registry import MODELS
+from mmengine.testing import RunnerTestCase, assert_allclose
+from mmengine.testing.runner_test_case import ToyModel
 
 
-class ToyModel(BaseModel):
+class DummyWrapper(BaseModel):
 
-    def __init__(self):
-        super().__init__()
-        self.linear = nn.Linear(2, 1)
-
-    def forward(self, inputs, data_sample, mode='tensor'):
-        labels = torch.stack(data_sample)
-        inputs = torch.stack(inputs)
-        outputs = self.linear(inputs)
-        if mode == 'tensor':
-            return outputs
-        elif mode == 'loss':
-            loss = (labels - outputs).sum()
-            outputs = dict(loss=loss)
-            return outputs
-        else:
-            return outputs
-
-
-class ToyModel1(ToyModel):
-
-    def __init__(self):
+    def __init__(self, model):
         super().__init__()
+        if not isinstance(model, nn.Module):
+            model = MODELS.build(model)
+        self.module = model
 
     def forward(self, *args, **kwargs):
-        return super().forward(*args, **kwargs)
+        return self.module(*args, **kwargs)
 
 
 class ToyModel2(ToyModel):
 
     def __init__(self):
         super().__init__()
-        self.linear1 = nn.Linear(2, 1)
+        self.linear3 = nn.Linear(2, 1)
 
     def forward(self, *args, **kwargs):
         return super().forward(*args, **kwargs)
@@ -62,239 +39,247 @@ class ToyModel3(ToyModel):
 
     def __init__(self):
         super().__init__()
-        self.linear = nn.Linear(2, 2)
+        self.linear2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 1))
 
     def forward(self, *args, **kwargs):
         return super().forward(*args, **kwargs)
 
 
-@DATASETS.register_module()
-class DummyDataset(Dataset):
-    METAINFO = dict()  # type: ignore
-    data = torch.randn(12, 2)
-    label = torch.ones(12)
+class TestEMAHook(RunnerTestCase):
 
-    @property
-    def metainfo(self):
-        return self.METAINFO
-
-    def __len__(self):
-        return self.data.size(0)
-
-    def __getitem__(self, index):
-        return dict(inputs=self.data[index], data_sample=self.label[index])
-
-
-class TestEMAHook(TestCase):
-
-    def setUp(self):
-        self.temp_dir = tempfile.TemporaryDirectory()
+    def setUp(self) -> None:
+        MODELS.register_module(name='DummyWrapper', module=DummyWrapper)
+        MODELS.register_module(name='ToyModel2', module=ToyModel2)
+        MODELS.register_module(name='ToyModel3', module=ToyModel3)
+        return super().setUp()
 
     def tearDown(self):
-        # `FileHandler` should be closed in Windows, otherwise we cannot
-        # delete the temporary directory
-        logging.shutdown()
-        MMLogger._instance_dict.clear()
-        self.temp_dir.cleanup()
+        MODELS.module_dict.pop('DummyWrapper')
+        MODELS.module_dict.pop('ToyModel2')
+        MODELS.module_dict.pop('ToyModel3')
+        return super().tearDown()
 
-    def test_ema_hook(self):
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-        model = ToyModel1().to(device)
-        evaluator = Evaluator([])
-        evaluator.evaluate = Mock(return_value=dict(acc=0.5))
-        runner = Runner(
-            model=model,
-            train_dataloader=dict(
-                dataset=dict(type='DummyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=True),
-                batch_size=3,
-                num_workers=0),
-            val_dataloader=dict(
-                dataset=dict(type='DummyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=False),
-                batch_size=3,
-                num_workers=0),
-            val_evaluator=evaluator,
-            work_dir=self.temp_dir.name,
-            optim_wrapper=OptimWrapper(
-                torch.optim.Adam(ToyModel().parameters())),
-            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
-            val_cfg=dict(),
-            default_hooks=dict(logger=None),
-            custom_hooks=[dict(type='EMAHook', )],
-            experiment_name='test1')
-        runner.train()
+    def test_init(self):
+        EMAHook()
+
+        with self.assertRaisesRegex(AssertionError, '`begin_iter` must'):
+            EMAHook(begin_iter=-1)
+
+        with self.assertRaisesRegex(AssertionError, '`begin_epoch` must'):
+            EMAHook(begin_epoch=-1)
+
+        with self.assertRaisesRegex(AssertionError,
+                                    '`begin_iter` and `begin_epoch`'):
+            EMAHook(begin_iter=1, begin_epoch=1)
+
+    def _get_ema_hook(self, runner):
         for hook in runner.hooks:
             if isinstance(hook, EMAHook):
-                self.assertTrue(
-                    isinstance(hook.ema_model, ExponentialMovingAverage))
+                return hook
 
+    def test_before_run(self):
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.custom_hooks = [dict(type='EMAHook')]
+        runner = self.build_runner(cfg)
+        ema_hook = self._get_ema_hook(runner)
+        ema_hook.before_run(runner)
+        self.assertIsInstance(ema_hook.ema_model, ExponentialMovingAverage)
+        self.assertIs(ema_hook.src_model, runner.model)
+
+    def test_before_train(self):
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.custom_hooks = [
+            dict(type='EMAHook', begin_epoch=cfg.train_cfg.max_epochs - 1)
+        ]
+        runner = self.build_runner(cfg)
+        ema_hook = self._get_ema_hook(runner)
+        ema_hook.before_train(runner)
+
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.custom_hooks = [
+            dict(type='EMAHook', begin_epoch=cfg.train_cfg.max_epochs + 1)
+        ]
+        runner = self.build_runner(cfg)
+        ema_hook = self._get_ema_hook(runner)
+
+        with self.assertRaisesRegex(AssertionError, 'self.begin_epoch'):
+            ema_hook.before_train(runner)
+
+        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg.custom_hooks = [
+            dict(type='EMAHook', begin_iter=cfg.train_cfg.max_iters + 1)
+        ]
+        runner = self.build_runner(cfg)
+        ema_hook = self._get_ema_hook(runner)
+
+        with self.assertRaisesRegex(AssertionError, 'self.begin_iter'):
+            ema_hook.before_train(runner)
+
+    def test_after_train_iter(self):
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.custom_hooks = [dict(type='EMAHook')]
+        runner = self.build_runner(cfg)
+        ema_hook = self._get_ema_hook(runner)
+
+        ema_hook = self._get_ema_hook(runner)
+        ema_hook.before_run(runner)
+        ema_hook.before_train(runner)
+
+        src_model = runner.model
+        ema_model = ema_hook.ema_model
+
+        with torch.no_grad():
+            for parameter in src_model.parameters():
+                parameter.data.copy_(torch.randn(parameter.shape))
+
+        ema_hook.after_train_iter(runner, 1)
+        for src, ema in zip(src_model.parameters(), ema_model.parameters()):
+            assert_allclose(src.data, ema.data)
+
+        with torch.no_grad():
+            for parameter in src_model.parameters():
+                parameter.data.copy_(torch.randn(parameter.shape))
+
+        ema_hook.after_train_iter(runner, 1)
+
+        for src, ema in zip(src_model.parameters(), ema_model.parameters()):
+            self.assertFalse((src.data == ema.data).all())
+
+    def test_before_val_epoch(self):
+        self._test_swap_parameters('before_val_epoch')
+
+    def test_after_val_epoch(self):
+        self._test_swap_parameters('after_val_epoch')
+
+    def test_before_test_epoch(self):
+        self._test_swap_parameters('before_test_epoch')
+
+    def test_after_test_epoch(self):
+        self._test_swap_parameters('after_test_epoch')
+
+    def test_before_save_checkpoint(self):
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        checkpoint = dict(state_dict=ToyModel().state_dict())
+        ema_hook = EMAHook()
+        ema_hook.before_run(runner)
+        ema_hook.before_train(runner)
+
+        ori_checkpoint = copy.deepcopy(checkpoint)
+        ema_hook.before_save_checkpoint(runner, checkpoint)
+
+        for key in ori_checkpoint['state_dict'].keys():
+            assert_allclose(
+                ori_checkpoint['state_dict'][key].cpu(),
+                checkpoint['ema_state_dict'][f'module.{key}'].cpu())
+
+            assert_allclose(
+                ema_hook.ema_model.state_dict()[f'module.{key}'].cpu(),
+                checkpoint['state_dict'][key].cpu())
+
+    def test_after_load_checkpoint(self):
+        # Test loading a checkpoint without ema_state_dict.
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        checkpoint = dict(state_dict=ToyModel().state_dict())
+        ema_hook = EMAHook()
+        ema_hook.before_run(runner)
+        ema_hook.before_train(runner)
+        ema_hook.after_load_checkpoint(runner, checkpoint)
+
+        for key in checkpoint['state_dict'].keys():
+            assert_allclose(
+                checkpoint['state_dict'][key].cpu(),
+                ema_hook.ema_model.state_dict()[f'module.{key}'].cpu())
+
+        # Test that a warning is raised when resuming from a checkpoint
+        # without `ema_state_dict`
+        runner._resume = True
+        ema_hook.after_load_checkpoint(runner, checkpoint)
+        with self.assertLogs(runner.logger, level='WARNING') as cm:
+            ema_hook.after_load_checkpoint(runner, checkpoint)
+            self.assertRegex(cm.records[0].msg, 'There is no `ema_state_dict`')
+
+        # Check that the weights of state_dict and ema_state_dict have been
+        # swapped when runner._resume is True
+        runner._resume = True
+        checkpoint = dict(
+            state_dict=ToyModel().state_dict(),
+            ema_state_dict=ExponentialMovingAverage(ToyModel()).state_dict())
+        ori_checkpoint = copy.deepcopy(checkpoint)
+        ema_hook.after_load_checkpoint(runner, checkpoint)
+        for key in ori_checkpoint['state_dict'].keys():
+            assert_allclose(
+                ori_checkpoint['state_dict'][key].cpu(),
+                ema_hook.ema_model.state_dict()[f'module.{key}'].cpu())
+
+        runner._resume = False
+        ema_hook.after_load_checkpoint(runner, checkpoint)
+
+    def test_with_runner(self):
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.custom_hooks = [ConfigDict(type='EMAHook')]
+        runner = self.build_runner(cfg)
+        ema_hook = self._get_ema_hook(runner)
+        runner.train()
         self.assertTrue(
-            osp.exists(osp.join(self.temp_dir.name, 'epoch_2.pth')))
+            isinstance(ema_hook.ema_model, ExponentialMovingAverage))
+        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
         self.assertTrue('ema_state_dict' in checkpoint)
         self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)
 
         # load and testing
-        runner = Runner(
-            model=model,
-            test_dataloader=dict(
-                dataset=dict(type='DummyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=True),
-                batch_size=3,
-                num_workers=0),
-            test_evaluator=evaluator,
-            test_cfg=dict(),
-            work_dir=self.temp_dir.name,
-            load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
-            default_hooks=dict(logger=None),
-            custom_hooks=[dict(type='EMAHook')],
-            experiment_name='test2')
+        cfg.load_from = osp.join(self.temp_dir.name, 'epoch_2.pth')
+        runner = self.build_runner(cfg)
         runner.test()
 
-        @MODEL_WRAPPERS.register_module()
-        class DummyWrapper(BaseModel):
-
-            def __init__(self, model):
-                super().__init__()
-                self.module = model
-
-            def forward(self, *args, **kwargs):
-                return self.module(*args, **kwargs)
-
         # with model wrapper
-        runner = Runner(
-            model=DummyWrapper(ToyModel()),
-            test_dataloader=dict(
-                dataset=dict(type='DummyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=True),
-                batch_size=3,
-                num_workers=0),
-            test_evaluator=evaluator,
-            test_cfg=dict(),
-            work_dir=self.temp_dir.name,
-            load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
-            default_hooks=dict(logger=None),
-            custom_hooks=[dict(type='EMAHook')],
-            experiment_name='test3')
+        cfg.model = ConfigDict(type='DummyWrapper', model=cfg.model)
+        runner = self.build_runner(cfg)
         runner.test()
 
         # Test load checkpoint without ema_state_dict
-        ckpt = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
-        ckpt.pop('ema_state_dict')
-        torch.save(ckpt,
+        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
+        checkpoint.pop('ema_state_dict')
+        torch.save(checkpoint,
                    osp.join(self.temp_dir.name, 'without_ema_state_dict.pth'))
-        runner = Runner(
-            model=DummyWrapper(ToyModel()),
-            test_dataloader=dict(
-                dataset=dict(type='DummyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=True),
-                batch_size=3,
-                num_workers=0),
-            test_evaluator=evaluator,
-            test_cfg=dict(),
-            work_dir=self.temp_dir.name,
-            load_from=osp.join(self.temp_dir.name,
-                               'without_ema_state_dict.pth'),
-            default_hooks=dict(logger=None),
-            custom_hooks=[dict(type='EMAHook')],
-            experiment_name='test4')
+
+        cfg.load_from = osp.join(self.temp_dir.name,
+                                 'without_ema_state_dict.pth')
+        runner = self.build_runner(cfg)
         runner.test()
 
-        # Test does not load ckpt strict_loadly.
+        # Test loading the checkpoint non-strictly (different name).
         # Test load checkpoint without ema_state_dict
-        runner = Runner(
-            model=ToyModel2(),
-            test_dataloader=dict(
-                dataset=dict(type='DummyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=True),
-                batch_size=3,
-                num_workers=0),
-            test_evaluator=evaluator,
-            test_cfg=dict(),
-            work_dir=self.temp_dir.name,
-            load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
-            default_hooks=dict(logger=None),
-            custom_hooks=[dict(type='EMAHook', strict_load=False)],
-            experiment_name='test5')
+        cfg.model = ConfigDict(type='ToyModel2')
+        cfg.custom_hooks = [ConfigDict(type='EMAHook', strict_load=False)]
+        runner = self.build_runner(cfg)
         runner.test()
 
-        # Test does not load ckpt strict_loadly.
+        # Test loading the checkpoint non-strictly (different weight size).
         # Test load checkpoint without ema_state_dict
-        # Test with different size head.
-        runner = Runner(
-            model=ToyModel3(),
-            test_dataloader=dict(
-                dataset=dict(type='DummyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=True),
-                batch_size=3,
-                num_workers=0),
-            test_evaluator=evaluator,
-            test_cfg=dict(),
-            work_dir=self.temp_dir.name,
-            load_from=osp.join(self.temp_dir.name,
-                               'without_ema_state_dict.pth'),
-            default_hooks=dict(logger=None),
-            custom_hooks=[dict(type='EMAHook', strict_load=False)],
-            experiment_name='test5.1')
+        cfg.model = ConfigDict(type='ToyModel3')
+        runner = self.build_runner(cfg)
         runner.test()
 
         # Test enable ema at 5 epochs.
-        runner = Runner(
-            model=model,
-            train_dataloader=dict(
-                dataset=dict(type='DummyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=True),
-                batch_size=3,
-                num_workers=0),
-            val_dataloader=dict(
-                dataset=dict(type='DummyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=False),
-                batch_size=3,
-                num_workers=0),
-            val_evaluator=evaluator,
-            work_dir=self.temp_dir.name,
-            optim_wrapper=OptimWrapper(
-                torch.optim.Adam(ToyModel().parameters())),
-            train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
-            val_cfg=dict(),
-            default_hooks=dict(logger=None),
-            custom_hooks=[dict(type='EMAHook', begin_epoch=5)],
-            experiment_name='test6')
+        cfg.train_cfg.max_epochs = 10
+        cfg.custom_hooks = [ConfigDict(type='EMAHook', begin_epoch=5)]
+        runner = self.build_runner(cfg)
         runner.train()
         state_dict = torch.load(
             osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)
         for k, v in state_dict['state_dict'].items():
             assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
-        state_dict = torch.load(
-            osp.join(self.temp_dir.name, 'epoch_5.pth'), map_location='cpu')
-        self.assertIn('ema_state_dict', state_dict)
 
         # Test enable ema at 5 iterations.
-        runner = Runner(
-            model=model,
-            train_dataloader=dict(
-                dataset=dict(type='DummyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=True),
-                batch_size=3,
-                num_workers=0),
-            val_dataloader=dict(
-                dataset=dict(type='DummyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=False),
-                batch_size=3,
-                num_workers=0),
-            val_evaluator=evaluator,
-            work_dir=self.temp_dir.name,
-            optim_wrapper=OptimWrapper(
-                torch.optim.Adam(ToyModel().parameters())),
-            train_cfg=dict(by_epoch=False, max_iters=10, val_interval=1),
-            val_cfg=dict(),
-            default_hooks=dict(
-                checkpoint=dict(
-                    type='CheckpointHook', interval=1, by_epoch=False)),
-            custom_hooks=[dict(type='EMAHook', begin_iter=5)],
-            experiment_name='test7')
+        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg.train_cfg.val_interval = 1
+        cfg.custom_hooks = [ConfigDict(type='EMAHook', begin_iter=5)]
+        cfg.default_hooks.checkpoint.interval = 1
+        runner = self.build_runner(cfg)
         runner.train()
         state_dict = torch.load(
             osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
@@ -304,3 +289,30 @@ class TestEMAHook(TestCase):
         state_dict = torch.load(
             osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)
+
+    def _test_swap_parameters(self, func_name, *args, **kwargs):
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.custom_hooks = [dict(type='EMAHook')]
+        runner = self.build_runner(cfg)
+        ema_hook = self._get_ema_hook(runner)
+
+        runner.train()
+
+        with torch.no_grad():
+            for parameter in ema_hook.src_model.parameters():
+                parameter.data.copy_(torch.randn(parameter.shape))
+
+        src_model = copy.deepcopy(runner.model)
+        ema_model = copy.deepcopy(ema_hook.ema_model)
+
+        func = getattr(ema_hook, func_name)
+        func(runner, *args, **kwargs)
+
+        swapped_src = ema_hook.src_model
+        swapped_ema = ema_hook.ema_model
+
+        for src, ema, swapped_src, swapped_ema in zip(
+                src_model.parameters(), ema_model.parameters(),
+                swapped_src.parameters(), swapped_ema.parameters()):
+            self.assertTrue((src.data == swapped_ema.data).all())
+            self.assertTrue((ema.data == swapped_src.data).all())
diff --git a/tests/test_hooks/test_runtime_info_hook.py b/tests/test_hooks/test_runtime_info_hook.py
index 028707dc..7593f845 100644
--- a/tests/test_hooks/test_runtime_info_hook.py
+++ b/tests/test_hooks/test_runtime_info_hook.py
@@ -7,21 +7,37 @@ from torch.optim import SGD
 
 from mmengine.hooks import RuntimeInfoHook
 from mmengine.optim import OptimWrapper, OptimWrapperDict
+from mmengine.registry import DATASETS
 from mmengine.testing import RunnerTestCase
 
 
+class DatasetWithoutMetainfo:
+    ...
+
+    def __len__(self):
+        return 12
+
+
+class DatasetWithMetainfo(DatasetWithoutMetainfo):
+    metainfo: dict = dict()
+
+
 class TestRuntimeInfoHook(RunnerTestCase):
 
+    def setUp(self) -> None:
+        DATASETS.register_module(module=DatasetWithoutMetainfo, force=True)
+        DATASETS.register_module(module=DatasetWithMetainfo, force=True)
+        return super().setUp()
+
+    def tearDown(self):
+        DATASETS.module_dict.pop('DatasetWithoutMetainfo')
+        DATASETS.module_dict.pop('DatasetWithMetainfo')
+        return super().tearDown()
+
     def test_before_train(self):
-
-        class DatasetWithoutMetainfo:
-            ...
-
-            def __len__(self):
-                return 12
-
         cfg = copy.deepcopy(self.epoch_based_cfg)
-        cfg.train_dataloader.dataset.type = DatasetWithoutMetainfo
+        cfg.train_dataloader.dataset.type = 'DatasetWithoutMetainfo'
         runner = self.build_runner(cfg)
         hook = self._get_runtime_info_hook(runner)
         hook.before_train(runner)
@@ -33,10 +49,7 @@ class TestRuntimeInfoHook(RunnerTestCase):
         with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
             runner.message_hub.get_info('dataset_meta')
 
-        class DatasetWithMetainfo(DatasetWithoutMetainfo):
-            metainfo = dict()
-
-        cfg.train_dataloader.dataset.type = DatasetWithMetainfo
+        cfg.train_dataloader.dataset.type = 'DatasetWithMetainfo'
         runner = self.build_runner(cfg)
         hook.before_train(runner)
         self.assertEqual(runner.message_hub.get_info('dataset_meta'), dict())