[Refactor] Use a real runner to test RuntimeInfoHook (#810)

* Refactor RuntimeInfoHook

* Fix as per review comments
Author: Mashiro, 2022-12-22 10:55:39 +08:00 (committed by Zaida Zhou)
Parent: c4efda4186
Commit: 29f399441f
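The gist of the refactor: instead of hand-assembling a Mock() runner and a standalone MessageHub instance per test, each test now deep-copies the epoch_based_cfg provided by mmengine.testing.RunnerTestCase, builds a real runner from it, and pulls the already-registered RuntimeInfoHook out of runner.hooks. A minimal sketch of the pattern, using only names that appear in the diff below (TestMyHook and test_something are illustrative placeholders, not part of the commit):

    import copy

    from mmengine.hooks import RuntimeInfoHook
    from mmengine.testing import RunnerTestCase


    class TestMyHook(RunnerTestCase):

        def test_something(self):
            # Deep-copy the base config so per-test tweaks don't leak
            # into other tests.
            cfg = copy.deepcopy(self.epoch_based_cfg)
            # Build a real Runner; default hooks such as RuntimeInfoHook
            # are already registered on it.
            runner = self.build_runner(cfg)
            hook = next(h for h in runner.hooks
                        if isinstance(h, RuntimeInfoHook))
            # The hook writes runtime state into the runner's message hub.
            hook.before_train(runner)
            assert runner.message_hub.get_info('epoch') == 0

Testing against a real runner means the hook sees genuine train-loop counters and a real message hub, so the assertions exercise the same code paths as production training rather than whatever the mocks happened to stub out.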


@@ -1,76 +1,68 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from unittest import TestCase
+import copy
 from unittest.mock import Mock

 import torch.nn as nn
 from torch.optim import SGD

 from mmengine.hooks import RuntimeInfoHook
-from mmengine.logging import MessageHub
 from mmengine.optim import OptimWrapper, OptimWrapperDict
+from mmengine.testing import RunnerTestCase


-class TestRuntimeInfoHook(TestCase):
+class TestRuntimeInfoHook(RunnerTestCase):

     def test_before_train(self):
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_before_train')

-        class ToyDataset:
+        class DatasetWithoutMetainfo:
             ...

-        runner = Mock()
-        runner.epoch = 7
-        runner.iter = 71
-        runner.max_epochs = 4
-        runner.max_iters = 40
-        runner.message_hub = message_hub
-        runner.train_dataloader.dataset = ToyDataset()
-        hook = RuntimeInfoHook()
-        hook.before_train(runner)
-        self.assertEqual(message_hub.get_info('epoch'), 7)
-        self.assertEqual(message_hub.get_info('iter'), 71)
-        self.assertEqual(message_hub.get_info('max_epochs'), 4)
-        self.assertEqual(message_hub.get_info('max_iters'), 40)
-        with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
-            message_hub.get_info('dataset_meta')
+            def __len__(self):
+                return 12
+
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.train_dataloader.dataset.type = DatasetWithoutMetainfo
+        runner = self.build_runner(cfg)
+        hook = self._get_runtime_info_hook(runner)
+        hook.before_train(runner)
+        self.assertEqual(runner.message_hub.get_info('epoch'), 0)
+        self.assertEqual(runner.message_hub.get_info('iter'), 0)
+        self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
+        self.assertEqual(runner.message_hub.get_info('max_iters'), 8)
+        with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
+            runner.message_hub.get_info('dataset_meta')

-        class ToyDatasetWithMeta:
+        class DatasetWithMetainfo(DatasetWithoutMetainfo):
             metainfo = dict()

-        runner.train_dataloader.dataset = ToyDatasetWithMeta()
+        cfg.train_dataloader.dataset.type = DatasetWithMetainfo
+        runner = self.build_runner(cfg)
         hook.before_train(runner)
-        self.assertEqual(message_hub.get_info('dataset_meta'), dict())
+        self.assertEqual(runner.message_hub.get_info('dataset_meta'), dict())

     def test_before_train_epoch(self):
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_before_train_epoch')
-        runner = Mock()
-        runner.epoch = 9
-        runner.message_hub = message_hub
-        hook = RuntimeInfoHook()
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        runner.train_loop._epoch = 9
+        hook = self._get_runtime_info_hook(runner)
         hook.before_train_epoch(runner)
-        self.assertEqual(message_hub.get_info('epoch'), 9)
+        self.assertEqual(runner.message_hub.get_info('epoch'), 9)

     def test_before_train_iter(self):
-        model = nn.Linear(1, 1)
-        optim1 = SGD(model.parameters(), lr=0.01)
-        optim2 = SGD(model.parameters(), lr=0.02)
-        optim_wrapper1 = OptimWrapper(optim1)
-        optim_wrapper2 = OptimWrapper(optim2)
-        optim_wrapper_dict = OptimWrapperDict(
-            key1=optim_wrapper1, key2=optim_wrapper2)
         # single optimizer
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_before_train_iter')
-        runner = Mock()
-        runner.iter = 9
-        runner.optim_wrapper = optim_wrapper1
-        runner.message_hub = message_hub
-        hook = RuntimeInfoHook()
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        lr = cfg.optim_wrapper.optimizer.lr
+        runner = self.build_runner(cfg)
+        # set iter
+        runner.train_loop._iter = 9
+        # build optim wrapper
+        runner.optim_wrapper = runner.build_optim_wrapper(
+            runner.optim_wrapper)
+        hook = self._get_runtime_info_hook(runner)
         hook.before_train_iter(runner, batch_idx=2, data_batch=None)
-        self.assertEqual(message_hub.get_info('iter'), 9)
-        self.assertEqual(message_hub.get_scalar('train/lr').current(), 0.01)
+        self.assertEqual(runner.message_hub.get_info('iter'), 9)
+        self.assertEqual(
+            runner.message_hub.get_scalar('train/lr').current(), lr)

         with self.assertRaisesRegex(AssertionError,
                                     'runner.optim_wrapper.get_lr()'):
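Two details in the hunk above are easy to miss: the test sets runner.train_loop._iter directly, presumably because iteration counters on a real runner are read-only views of the train loop rather than plain attributes, and it asserts against the lr read from the config instead of a hard-coded 0.01. The recorded state can be read back with the same message-hub calls the assertions use (a sketch; the variable names are illustrative):

    # After hook.before_train_iter(runner, batch_idx=2, data_batch=None):
    iter_idx = runner.message_hub.get_info('iter')  # -> 9, mirrors _iter
    current_lr = runner.message_hub.get_scalar('train/lr').current()
    # -> the value taken from cfg.optim_wrapper.optimizer.lr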
@@ -79,49 +71,50 @@ class TestRuntimeInfoHook(TestCase):
             hook.before_train_iter(runner, batch_idx=2, data_batch=None)

         # multiple optimizers
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_before_train_iter')
-        runner = Mock()
-        runner.iter = 9
-        optimizer1 = Mock()
-        optimizer1.param_groups = [{'lr': 0.01}]
-        optimizer2 = Mock()
-        optimizer2.param_groups = [{'lr': 0.02}]
-        runner.message_hub = message_hub
+        model = nn.ModuleDict(
+            dict(
+                layer1=nn.Linear(1, 1),
+                layer2=nn.Linear(1, 1),
+            ))
+        optim1 = SGD(model.layer1.parameters(), lr=0.01)
+        optim2 = SGD(model.layer2.parameters(), lr=0.02)
+        optim_wrapper1 = OptimWrapper(optim1)
+        optim_wrapper2 = OptimWrapper(optim2)
+        optim_wrapper_dict = OptimWrapperDict(
+            key1=optim_wrapper1, key2=optim_wrapper2)
         runner.optim_wrapper = optim_wrapper_dict
-        hook = RuntimeInfoHook()
         hook.before_train_iter(runner, batch_idx=2, data_batch=None)
-        self.assertEqual(message_hub.get_info('iter'), 9)
         self.assertEqual(
-            message_hub.get_scalar('train/key1.lr').current(), 0.01)
+            runner.message_hub.get_scalar('train/key1.lr').current(), 0.01)
         self.assertEqual(
-            message_hub.get_scalar('train/key2.lr').current(), 0.02)
+            runner.message_hub.get_scalar('train/key2.lr').current(), 0.02)

     def test_after_train_iter(self):
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_after_train_iter')
-        runner = Mock()
-        runner.message_hub = message_hub
-        hook = RuntimeInfoHook()
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        hook = self._get_runtime_info_hook(runner)
         hook.after_train_iter(
             runner, batch_idx=2, data_batch=None, outputs={'loss_cls': 1.111})
         self.assertEqual(
-            message_hub.get_scalar('train/loss_cls').current(), 1.111)
+            runner.message_hub.get_scalar('train/loss_cls').current(), 1.111)

     def test_after_val_epoch(self):
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_after_val_epoch')
-        runner = Mock()
-        runner.message_hub = message_hub
-        hook = RuntimeInfoHook()
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        hook = self._get_runtime_info_hook(runner)
         hook.after_val_epoch(runner, metrics={'acc': 0.8})
-        self.assertEqual(message_hub.get_scalar('val/acc').current(), 0.8)
+        self.assertEqual(
+            runner.message_hub.get_scalar('val/acc').current(), 0.8)

     def test_after_test_epoch(self):
-        message_hub = MessageHub.get_instance(
-            'runtime_info_hook_test_after_test_epoch')
-        runner = Mock()
-        runner.message_hub = message_hub
-        hook = RuntimeInfoHook()
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        hook = self._get_runtime_info_hook(runner)
         hook.after_test_epoch(runner, metrics={'acc': 0.8})
-        self.assertEqual(message_hub.get_scalar('test/acc').current(), 0.8)
+        self.assertEqual(
+            runner.message_hub.get_scalar('test/acc').current(), 0.8)
+
+    def _get_runtime_info_hook(self, runner):
+        for hook in runner.hooks:
+            if isinstance(hook, RuntimeInfoHook):
+                return hook
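One takeaway from the second hunk: when runner.optim_wrapper is an OptimWrapperDict, the hook records one scalar per named wrapper under train/<name>.lr, so each optimizer's learning rate is tracked separately. A self-contained sketch of that naming scheme, reusing only constructs from the diff (the 0.01/0.02 values mirror the test):

    import torch.nn as nn
    from torch.optim import SGD

    from mmengine.optim import OptimWrapper, OptimWrapperDict

    # Two sub-modules, each with its own optimizer and wrapper.
    model = nn.ModuleDict(
        dict(layer1=nn.Linear(1, 1), layer2=nn.Linear(1, 1)))
    optim_wrapper_dict = OptimWrapperDict(
        key1=OptimWrapper(SGD(model.layer1.parameters(), lr=0.01)),
        key2=OptimWrapper(SGD(model.layer2.parameters(), lr=0.02)))
    # Once RuntimeInfoHook.before_train_iter runs with this wrapper dict:
    #   runner.message_hub.get_scalar('train/key1.lr').current() -> 0.01
    #   runner.message_hub.get_scalar('train/key2.lr').current() -> 0.02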