mmengine/tests/test_hook/test_runtime_info_hook.py
Mashiro f04fec736d
[Feature]: add base model, ddp model wrapper and unit test (#268)
* add base model, ddp model and unit test

* add unit test

* fix unit test

* fix docstring

* fix cpu unit test

* refine base data preprocessor

* refine base data preprocessor

* refine interface of ddp module

* remove optimizer hook

* add forward

* fix as comment

* fix unit test

* fix as comment

* fix build optimizer wrapper

* rebase main and fix unit test

* stack_batch support stacking ndim tensor, add docstring for merge dict

* fix lint

* fix test loop

* make precision_context effective to data_preprocessor

* fix as comment

* fix as comment

* refine docstring

* change collate_data output typehints

* rename to_rgb to bgr_to_rgb and rgb_to_bgr

* support build basemodel with built DataPreprocessor

* fix as comment

* fix docstring
2022-06-07 22:13:53 +08:00

126 lines
4.8 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
import torch.nn as nn
from torch.optim import SGD
from mmengine.hooks import RuntimeInfoHook
from mmengine.logging import MessageHub
from mmengine.optim import OptimWrapper, OptimWrapperDict
class TestRuntimeInfoHook(TestCase):
def test_before_run(self):
message_hub = MessageHub.get_instance(
'runtime_info_hook_test_before_run')
runner = Mock()
runner.epoch = 3
runner.iter = 30
runner.max_epochs = 4
runner.max_iters = 40
runner.message_hub = message_hub
hook = RuntimeInfoHook()
hook.before_run(runner)
self.assertEqual(message_hub.get_info('epoch'), 3)
self.assertEqual(message_hub.get_info('iter'), 30)
self.assertEqual(message_hub.get_info('max_epochs'), 4)
self.assertEqual(message_hub.get_info('max_iters'), 40)
def test_before_train(self):
message_hub = MessageHub.get_instance(
'runtime_info_hook_test_before_train')
runner = Mock()
runner.epoch = 7
runner.iter = 71
runner.message_hub = message_hub
hook = RuntimeInfoHook()
hook.before_train(runner)
self.assertEqual(message_hub.get_info('epoch'), 7)
self.assertEqual(message_hub.get_info('iter'), 71)
def test_before_train_epoch(self):
message_hub = MessageHub.get_instance(
'runtime_info_hook_test_before_train_epoch')
runner = Mock()
runner.epoch = 9
runner.message_hub = message_hub
hook = RuntimeInfoHook()
hook.before_train_epoch(runner)
self.assertEqual(message_hub.get_info('epoch'), 9)
def test_before_train_iter(self):
model = nn.Linear(1, 1)
optim1 = SGD(model.parameters(), lr=0.01)
optim2 = SGD(model.parameters(), lr=0.02)
optim_wrapper1 = OptimWrapper(optim1)
optim_wrapper2 = OptimWrapper(optim2)
optim_wrapper_dict = OptimWrapperDict(
key1=optim_wrapper1, key2=optim_wrapper2)
# single optimizer
message_hub = MessageHub.get_instance(
'runtime_info_hook_test_before_train_iter')
runner = Mock()
runner.iter = 9
runner.optim_wrapper = optim_wrapper1
runner.message_hub = message_hub
hook = RuntimeInfoHook()
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
self.assertEqual(message_hub.get_info('iter'), 9)
self.assertEqual(message_hub.get_scalar('train/lr').current(), 0.01)
with self.assertRaisesRegex(AssertionError,
'runner.optim_wrapper.get_lr()'):
runner.optim_wrapper = Mock()
runner.optim_wrapper.get_lr = Mock(return_value='error type')
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
# multiple optimizers
message_hub = MessageHub.get_instance(
'runtime_info_hook_test_before_train_iter')
runner = Mock()
runner.iter = 9
optimizer1 = Mock()
optimizer1.param_groups = [{'lr': 0.01}]
optimizer2 = Mock()
optimizer2.param_groups = [{'lr': 0.02}]
runner.message_hub = message_hub
runner.optim_wrapper = optim_wrapper_dict
hook = RuntimeInfoHook()
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
self.assertEqual(message_hub.get_info('iter'), 9)
self.assertEqual(
message_hub.get_scalar('train/key1.lr').current(), 0.01)
self.assertEqual(
message_hub.get_scalar('train/key2.lr').current(), 0.02)
def test_after_train_iter(self):
message_hub = MessageHub.get_instance(
'runtime_info_hook_test_after_train_iter')
runner = Mock()
runner.message_hub = message_hub
hook = RuntimeInfoHook()
hook.after_train_iter(
runner, batch_idx=2, data_batch=None, outputs={'loss_cls': 1.111})
self.assertEqual(
message_hub.get_scalar('train/loss_cls').current(), 1.111)
def test_after_val_epoch(self):
message_hub = MessageHub.get_instance(
'runtime_info_hook_test_after_val_epoch')
runner = Mock()
runner.message_hub = message_hub
hook = RuntimeInfoHook()
hook.after_val_epoch(runner, metrics={'acc': 0.8})
self.assertEqual(message_hub.get_scalar('val/acc').current(), 0.8)
def test_after_test_epoch(self):
message_hub = MessageHub.get_instance(
'runtime_info_hook_test_after_test_epoch')
runner = Mock()
runner.message_hub = message_hub
hook = RuntimeInfoHook()
hook.after_test_epoch(runner, metrics={'acc': 0.8})
self.assertEqual(message_hub.get_scalar('test/acc').current(), 0.8)