2022-03-01 12:02:34 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
2022-04-26 00:37:16 +08:00
|
|
|
from unittest import TestCase
|
|
|
|
from unittest.mock import MagicMock, Mock, patch
|
2022-03-01 12:02:34 +08:00
|
|
|
|
|
|
|
from mmengine.hooks import IterTimerHook
|
2022-04-26 00:37:16 +08:00
|
|
|
from mmengine.logging import MessageHub
|
2022-03-01 12:02:34 +08:00
|
|
|
|
|
|
|
|
2022-04-26 00:37:16 +08:00
|
|
|
def time_patch():
    """Stand-in for ``time.time`` that returns a deterministic counter.

    The first call returns 0 and each subsequent call returns one more than
    the previous call. The counter is stored on the function object itself
    (``time_patch.time``), so it persists for the life of the process and is
    never reset between tests.
    """
    # A getattr default of -1 makes the very first call produce 0.
    time_patch.time = getattr(time_patch, 'time', -1) + 1
    return time_patch.time
|
|
|
|
|
|
|
|
|
|
|
|
class TestIterTimerHook(TestCase):
    """Unit tests for ``mmengine.hooks.IterTimerHook``.

    TestCase assertion methods are used instead of bare ``assert`` so the
    checks still run under ``python -O`` and report both operands on failure.
    """

    def setUp(self) -> None:
        # Build a fresh hook for every test method so timing state
        # accumulated in one test cannot leak into another.
        self.hook = IterTimerHook()

    def test_init(self):
        """A newly constructed hook starts with zeroed timing counters."""
        self.assertEqual(self.hook.time_sec_tot, 0)
        self.assertEqual(self.hook.start_iter, 0)

    def test_before_run(self):
        """``before_run`` records the runner's current iteration."""
        runner = MagicMock()
        runner.iter = 1
        self.hook.before_run(runner)
        self.assertEqual(self.hook.start_iter, 1)

    def test_before_epoch(self):
        """``_before_epoch`` stores a float timestamp in ``hook.t``."""
        runner = Mock()
        self.hook._before_epoch(runner)
        self.assertIsInstance(self.hook.t, float)

    @patch('time.time', MagicMock(return_value=1))
    def test_before_iter(self):
        """With a frozen clock, the logged data time is 0 in every mode."""
        runner = MagicMock()
        self.hook._before_epoch(runner)
        for mode in ('train', 'val', 'test'):
            self.hook._before_iter(runner, batch_idx=1, mode=mode)
            runner.message_hub.update_scalar.assert_called_with(
                f'{mode}/data_time', 0)

    @patch('time.time', time_patch)
    def test_after_iter(self):
        """``_after_iter`` logs iteration time and publishes an ETA.

        ``time.time`` is replaced by ``time_patch``, which advances by 1 on
        every call. The expected ETA values below are consistent with each
        iteration being measured as 1 second.
        """
        runner = MagicMock()
        runner.log_processor.window_size = 10
        runner.train_loop.max_iters = 100
        runner.iter = 0
        runner.test_loop.dataloader = [0] * 20
        runner.val_loop.dataloader = [0] * 20
        self.hook._before_epoch(runner)
        self.hook.before_run(runner)

        # For batch_idx=1 the test expects a scalar to be logged, but no
        # ETA to be read or published. NOTE(review): presumably because
        # batch_idx + 1 is not a multiple of window_size; confirm against
        # the hook implementation.
        self.hook._after_iter(runner, batch_idx=1)
        runner.message_hub.update_scalar.assert_called()
        runner.message_hub.get_log.assert_not_called()
        runner.message_hub.update_info.assert_not_called()

        # Switch to a real MessageHub so the published ETA can be read back.
        runner.message_hub = MessageHub.get_instance('test_iter_timer_hook')
        runner.iter = 9

        # train: (max_iters - iter - 1) iterations left at 1 s each
        #        = (100 - 9 - 1) * 1 = 90
        self.hook._after_iter(runner, batch_idx=89)
        self.assertEqual(runner.message_hub.get_info('eta'), 90)

        # val: (len(val dataloader) - batch_idx - 1) * 1 = (20 - 9 - 1) = 10
        self.hook._after_iter(runner, batch_idx=9, mode='val')
        self.assertEqual(runner.message_hub.get_info('eta'), 10)

        # test: (len(test dataloader) - batch_idx - 1) * 1 = (20 - 19 - 1) = 0
        self.hook._after_iter(runner, batch_idx=19, mode='test')
        self.assertEqual(runner.message_hub.get_info('eta'), 0)