From 29f399441f5991f80e8ab0f53960cb05b077dafc Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Thu, 22 Dec 2022 10:55:39 +0800
Subject: [PATCH] [Refactor] Use a real runner to test RuntimeInfoHook (#810)

* Refactor RuntimeInfoHook

* Fix as comment
---
 tests/test_hooks/test_runtime_info_hook.py | 149 ++++++++++-----------
 1 file changed, 71 insertions(+), 78 deletions(-)

diff --git a/tests/test_hooks/test_runtime_info_hook.py b/tests/test_hooks/test_runtime_info_hook.py
index 8b47782d..028707dc 100644
--- a/tests/test_hooks/test_runtime_info_hook.py
+++ b/tests/test_hooks/test_runtime_info_hook.py
@@ -1,76 +1,68 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from unittest import TestCase
+import copy
 from unittest.mock import Mock
 
 import torch.nn as nn
 from torch.optim import SGD
 
 from mmengine.hooks import RuntimeInfoHook
-from mmengine.logging import MessageHub
 from mmengine.optim import OptimWrapper, OptimWrapperDict
+from mmengine.testing import RunnerTestCase
 
 
-class TestRuntimeInfoHook(TestCase):
+class TestRuntimeInfoHook(RunnerTestCase):
 
     def test_before_train(self):
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_before_train')
 
-        class ToyDataset:
+        class DatasetWithoutMetainfo:
             ...
 
-        runner = Mock()
-        runner.epoch = 7
-        runner.iter = 71
-        runner.max_epochs = 4
-        runner.max_iters = 40
-        runner.message_hub = message_hub
-        runner.train_dataloader.dataset = ToyDataset()
-        hook = RuntimeInfoHook()
-        hook.before_train(runner)
-        self.assertEqual(message_hub.get_info('epoch'), 7)
-        self.assertEqual(message_hub.get_info('iter'), 71)
-        self.assertEqual(message_hub.get_info('max_epochs'), 4)
-        self.assertEqual(message_hub.get_info('max_iters'), 40)
-        with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
-            message_hub.get_info('dataset_meta')
+            def __len__(self):
+                return 12
 
-        class ToyDatasetWithMeta:
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.train_dataloader.dataset.type = DatasetWithoutMetainfo
+        runner = self.build_runner(cfg)
+        hook = self._get_runtime_info_hook(runner)
+        hook.before_train(runner)
+        self.assertEqual(runner.message_hub.get_info('epoch'), 0)
+        self.assertEqual(runner.message_hub.get_info('iter'), 0)
+        self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
+        self.assertEqual(runner.message_hub.get_info('max_iters'), 8)
+
+        with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
+            runner.message_hub.get_info('dataset_meta')
+
+        class DatasetWithMetainfo(DatasetWithoutMetainfo):
             metainfo = dict()
 
-        runner.train_dataloader.dataset = ToyDatasetWithMeta()
+        cfg.train_dataloader.dataset.type = DatasetWithMetainfo
+        runner = self.build_runner(cfg)
         hook.before_train(runner)
-        self.assertEqual(message_hub.get_info('dataset_meta'), dict())
+        self.assertEqual(runner.message_hub.get_info('dataset_meta'), dict())
 
     def test_before_train_epoch(self):
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_before_train_epoch')
-        runner = Mock()
-        runner.epoch = 9
-        runner.message_hub = message_hub
-        hook = RuntimeInfoHook()
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        runner.train_loop._epoch = 9
+        hook = self._get_runtime_info_hook(runner)
         hook.before_train_epoch(runner)
-        self.assertEqual(message_hub.get_info('epoch'), 9)
+        self.assertEqual(runner.message_hub.get_info('epoch'), 9)
 
     def test_before_train_iter(self):
-        model = nn.Linear(1, 1)
-        optim1 = SGD(model.parameters(), lr=0.01)
-        optim2 = SGD(model.parameters(), lr=0.02)
-        optim_wrapper1 = OptimWrapper(optim1)
-        optim_wrapper2 = OptimWrapper(optim2)
-        optim_wrapper_dict = OptimWrapperDict(
-            key1=optim_wrapper1, key2=optim_wrapper2)
         # single optimizer
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_before_train_iter')
-        runner = Mock()
-        runner.iter = 9
-        runner.optim_wrapper = optim_wrapper1
-        runner.message_hub = message_hub
-        hook = RuntimeInfoHook()
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        lr = cfg.optim_wrapper.optimizer.lr
+        runner = self.build_runner(cfg)
+        # set iter
+        runner.train_loop._iter = 9
+        # build optim wrapper
+        runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)
+        hook = self._get_runtime_info_hook(runner)
         hook.before_train_iter(runner, batch_idx=2, data_batch=None)
-        self.assertEqual(message_hub.get_info('iter'), 9)
-        self.assertEqual(message_hub.get_scalar('train/lr').current(), 0.01)
+        self.assertEqual(runner.message_hub.get_info('iter'), 9)
+        self.assertEqual(
+            runner.message_hub.get_scalar('train/lr').current(), lr)
 
         with self.assertRaisesRegex(AssertionError,
                                     'runner.optim_wrapper.get_lr()'):
@@ -79,49 +71,50 @@ class TestRuntimeInfoHook(TestCase):
             hook.before_train_iter(runner, batch_idx=2, data_batch=None)
 
         # multiple optimizers
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_before_train_iter')
-        runner = Mock()
-        runner.iter = 9
-        optimizer1 = Mock()
-        optimizer1.param_groups = [{'lr': 0.01}]
-        optimizer2 = Mock()
-        optimizer2.param_groups = [{'lr': 0.02}]
-        runner.message_hub = message_hub
+        model = nn.ModuleDict(
+            dict(
+                layer1=nn.Linear(1, 1),
+                layer2=nn.Linear(1, 1),
+            ))
+        optim1 = SGD(model.layer1.parameters(), lr=0.01)
+        optim2 = SGD(model.layer2.parameters(), lr=0.02)
+        optim_wrapper1 = OptimWrapper(optim1)
+        optim_wrapper2 = OptimWrapper(optim2)
+        optim_wrapper_dict = OptimWrapperDict(
+            key1=optim_wrapper1, key2=optim_wrapper2)
         runner.optim_wrapper = optim_wrapper_dict
-        hook = RuntimeInfoHook()
         hook.before_train_iter(runner, batch_idx=2, data_batch=None)
-        self.assertEqual(message_hub.get_info('iter'), 9)
         self.assertEqual(
-            message_hub.get_scalar('train/key1.lr').current(), 0.01)
+            runner.message_hub.get_scalar('train/key1.lr').current(), 0.01)
         self.assertEqual(
-            message_hub.get_scalar('train/key2.lr').current(), 0.02)
+            runner.message_hub.get_scalar('train/key2.lr').current(), 0.02)
 
     def test_after_train_iter(self):
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_after_train_iter')
-        runner = Mock()
-        runner.message_hub = message_hub
-        hook = RuntimeInfoHook()
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        hook = self._get_runtime_info_hook(runner)
         hook.after_train_iter(
             runner, batch_idx=2, data_batch=None, outputs={'loss_cls': 1.111})
         self.assertEqual(
-            message_hub.get_scalar('train/loss_cls').current(), 1.111)
+            runner.message_hub.get_scalar('train/loss_cls').current(), 1.111)
 
     def test_after_val_epoch(self):
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_after_val_epoch')
-        runner = Mock()
-        runner.message_hub = message_hub
-        hook = RuntimeInfoHook()
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        hook = self._get_runtime_info_hook(runner)
         hook.after_val_epoch(runner, metrics={'acc': 0.8})
-        self.assertEqual(message_hub.get_scalar('val/acc').current(), 0.8)
+        self.assertEqual(
+            runner.message_hub.get_scalar('val/acc').current(), 0.8)
 
     def test_after_test_epoch(self):
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_after_test_epoch')
-        runner = Mock()
-        runner.message_hub = message_hub
-        hook = RuntimeInfoHook()
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        hook = self._get_runtime_info_hook(runner)
         hook.after_test_epoch(runner, metrics={'acc': 0.8})
-        self.assertEqual(message_hub.get_scalar('test/acc').current(), 0.8)
+        self.assertEqual(
+            runner.message_hub.get_scalar('test/acc').current(), 0.8)
+
+    def _get_runtime_info_hook(self, runner):
+        for hook in runner.hooks:
+            if isinstance(hook, RuntimeInfoHook):
+                return hook
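
The refactor above relies on the `mmengine.testing.RunnerTestCase` base class, which prepares a minimal runnable config in `setUp` (exposed as `self.epoch_based_cfg`) and provides `self.build_runner(cfg)` to construct a real `Runner` with a real `MessageHub`, instead of stubbing everything with `Mock`. Below is a rough sketch of how a new hook test would follow the same pattern, assuming only the base-class attributes the diff already uses; `MyHook` and `custom_key` are hypothetical names for illustration, not part of the patch:

import copy

from mmengine.hooks import Hook
from mmengine.testing import RunnerTestCase


class MyHook(Hook):
    """Hypothetical hook used only to illustrate the test pattern."""

    def before_train(self, runner):
        # Write into the runner's real message hub.
        runner.message_hub.update_info('custom_key', 'custom_value')


class TestMyHook(RunnerTestCase):

    def test_before_train(self):
        # Copy the ready-made config so the test cannot leak state.
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.custom_hooks = [dict(type=MyHook)]
        runner = self.build_runner(cfg)  # a real Runner, not a Mock
        runner.call_hook('before_train')
        self.assertEqual(
            runner.message_hub.get_info('custom_key'), 'custom_value')

Testing against a real runner is what lets the patch assert on values the runner actually produces (e.g. `max_iters` computed from the dataloader length) rather than on numbers planted into a mock.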