mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Use a real runner to test RuntimeInfohook (#810)
* Refactor RuntimeInfoHook * Fix as comment
This commit is contained in:
parent
c4efda4186
commit
29f399441f
@ -1,76 +1,68 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
import copy
|
||||
from unittest.mock import Mock
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import SGD
|
||||
|
||||
from mmengine.hooks import RuntimeInfoHook
|
||||
from mmengine.logging import MessageHub
|
||||
from mmengine.optim import OptimWrapper, OptimWrapperDict
|
||||
from mmengine.testing import RunnerTestCase
|
||||
|
||||
|
||||
class TestRuntimeInfoHook(TestCase):
|
||||
class TestRuntimeInfoHook(RunnerTestCase):
|
||||
|
||||
def test_before_train(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_before_train')
|
||||
|
||||
class ToyDataset:
|
||||
class DatasetWithoutMetainfo:
|
||||
...
|
||||
|
||||
runner = Mock()
|
||||
runner.epoch = 7
|
||||
runner.iter = 71
|
||||
runner.max_epochs = 4
|
||||
runner.max_iters = 40
|
||||
runner.message_hub = message_hub
|
||||
runner.train_dataloader.dataset = ToyDataset()
|
||||
hook = RuntimeInfoHook()
|
||||
hook.before_train(runner)
|
||||
self.assertEqual(message_hub.get_info('epoch'), 7)
|
||||
self.assertEqual(message_hub.get_info('iter'), 71)
|
||||
self.assertEqual(message_hub.get_info('max_epochs'), 4)
|
||||
self.assertEqual(message_hub.get_info('max_iters'), 40)
|
||||
with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
|
||||
message_hub.get_info('dataset_meta')
|
||||
def __len__(self):
|
||||
return 12
|
||||
|
||||
class ToyDatasetWithMeta:
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.train_dataloader.dataset.type = DatasetWithoutMetainfo
|
||||
runner = self.build_runner(cfg)
|
||||
hook = self._get_runtime_info_hook(runner)
|
||||
hook.before_train(runner)
|
||||
self.assertEqual(runner.message_hub.get_info('epoch'), 0)
|
||||
self.assertEqual(runner.message_hub.get_info('iter'), 0)
|
||||
self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
|
||||
self.assertEqual(runner.message_hub.get_info('max_iters'), 8)
|
||||
|
||||
with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
|
||||
runner.message_hub.get_info('dataset_meta')
|
||||
|
||||
class DatasetWithMetainfo(DatasetWithoutMetainfo):
|
||||
metainfo = dict()
|
||||
|
||||
runner.train_dataloader.dataset = ToyDatasetWithMeta()
|
||||
cfg.train_dataloader.dataset.type = DatasetWithMetainfo
|
||||
runner = self.build_runner(cfg)
|
||||
hook.before_train(runner)
|
||||
self.assertEqual(message_hub.get_info('dataset_meta'), dict())
|
||||
self.assertEqual(runner.message_hub.get_info('dataset_meta'), dict())
|
||||
|
||||
def test_before_train_epoch(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_before_train_epoch')
|
||||
runner = Mock()
|
||||
runner.epoch = 9
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train_loop._epoch = 9
|
||||
hook = self._get_runtime_info_hook(runner)
|
||||
hook.before_train_epoch(runner)
|
||||
self.assertEqual(message_hub.get_info('epoch'), 9)
|
||||
self.assertEqual(runner.message_hub.get_info('epoch'), 9)
|
||||
|
||||
def test_before_train_iter(self):
|
||||
model = nn.Linear(1, 1)
|
||||
optim1 = SGD(model.parameters(), lr=0.01)
|
||||
optim2 = SGD(model.parameters(), lr=0.02)
|
||||
optim_wrapper1 = OptimWrapper(optim1)
|
||||
optim_wrapper2 = OptimWrapper(optim2)
|
||||
optim_wrapper_dict = OptimWrapperDict(
|
||||
key1=optim_wrapper1, key2=optim_wrapper2)
|
||||
# single optimizer
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_before_train_iter')
|
||||
runner = Mock()
|
||||
runner.iter = 9
|
||||
runner.optim_wrapper = optim_wrapper1
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
lr = cfg.optim_wrapper.optimizer.lr
|
||||
runner = self.build_runner(cfg)
|
||||
# set iter
|
||||
runner.train_loop._iter = 9
|
||||
# build optim wrapper
|
||||
runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)
|
||||
hook = self._get_runtime_info_hook(runner)
|
||||
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
|
||||
self.assertEqual(message_hub.get_info('iter'), 9)
|
||||
self.assertEqual(message_hub.get_scalar('train/lr').current(), 0.01)
|
||||
self.assertEqual(runner.message_hub.get_info('iter'), 9)
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_scalar('train/lr').current(), lr)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'runner.optim_wrapper.get_lr()'):
|
||||
@ -79,49 +71,50 @@ class TestRuntimeInfoHook(TestCase):
|
||||
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
|
||||
|
||||
# multiple optimizers
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_before_train_iter')
|
||||
runner = Mock()
|
||||
runner.iter = 9
|
||||
optimizer1 = Mock()
|
||||
optimizer1.param_groups = [{'lr': 0.01}]
|
||||
optimizer2 = Mock()
|
||||
optimizer2.param_groups = [{'lr': 0.02}]
|
||||
runner.message_hub = message_hub
|
||||
model = nn.ModuleDict(
|
||||
dict(
|
||||
layer1=nn.Linear(1, 1),
|
||||
layer2=nn.Linear(1, 1),
|
||||
))
|
||||
optim1 = SGD(model.layer1.parameters(), lr=0.01)
|
||||
optim2 = SGD(model.layer2.parameters(), lr=0.02)
|
||||
optim_wrapper1 = OptimWrapper(optim1)
|
||||
optim_wrapper2 = OptimWrapper(optim2)
|
||||
optim_wrapper_dict = OptimWrapperDict(
|
||||
key1=optim_wrapper1, key2=optim_wrapper2)
|
||||
runner.optim_wrapper = optim_wrapper_dict
|
||||
hook = RuntimeInfoHook()
|
||||
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
|
||||
self.assertEqual(message_hub.get_info('iter'), 9)
|
||||
self.assertEqual(
|
||||
message_hub.get_scalar('train/key1.lr').current(), 0.01)
|
||||
runner.message_hub.get_scalar('train/key1.lr').current(), 0.01)
|
||||
self.assertEqual(
|
||||
message_hub.get_scalar('train/key2.lr').current(), 0.02)
|
||||
runner.message_hub.get_scalar('train/key2.lr').current(), 0.02)
|
||||
|
||||
def test_after_train_iter(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_after_train_iter')
|
||||
runner = Mock()
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
hook = self._get_runtime_info_hook(runner)
|
||||
hook.after_train_iter(
|
||||
runner, batch_idx=2, data_batch=None, outputs={'loss_cls': 1.111})
|
||||
self.assertEqual(
|
||||
message_hub.get_scalar('train/loss_cls').current(), 1.111)
|
||||
runner.message_hub.get_scalar('train/loss_cls').current(), 1.111)
|
||||
|
||||
def test_after_val_epoch(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_after_val_epoch')
|
||||
runner = Mock()
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
hook = self._get_runtime_info_hook(runner)
|
||||
hook.after_val_epoch(runner, metrics={'acc': 0.8})
|
||||
self.assertEqual(message_hub.get_scalar('val/acc').current(), 0.8)
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_scalar('val/acc').current(), 0.8)
|
||||
|
||||
def test_after_test_epoch(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_after_test_epoch')
|
||||
runner = Mock()
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
hook = self._get_runtime_info_hook(runner)
|
||||
hook.after_test_epoch(runner, metrics={'acc': 0.8})
|
||||
self.assertEqual(message_hub.get_scalar('test/acc').current(), 0.8)
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_scalar('test/acc').current(), 0.8)
|
||||
|
||||
def _get_runtime_info_hook(self, runner):
|
||||
for hook in runner.hooks:
|
||||
if isinstance(hook, RuntimeInfoHook):
|
||||
return hook
|
||||
|
Loading…
x
Reference in New Issue
Block a user