1
0
mirror of https://github.com/open-mmlab/mmengine.git synced 2025-06-03 06:19:41 +08:00
mmengine/tests/test_hook/test_iter_timer_hook.py
Mashiro e2a2b0438e
[Refactor] Refine LoggerHook ()
* rename global accessible and integrate get_instance and create_instance

* move ManagerMixin to utils

* fix docstring and separate get_instance into get_instance and get_current_instance

* fix lint

* fix docstring, rename and move test_global_meta

* rename LogBuffer to HistoryBuffer, rename MessageHub methods, MessageHub support resume

* refine MMLogger timestamp, update unit test

* MMLogger add logger_name arguments

* Fix docstring

* Add LogProcessor and some unit test

* update unit test

* complete LogProcessor unit test

* refine LoggerHook

* solve circular import

* change default logger_name to mmengine

* refactor eta

* Fix docstring comment and unit test

* Fix with runner

* fix docstring

fix docstring

* fix docstring

* Add by_epoch attribute to LoggerHook and fix docstring

* Please mypy and fix comment

* remove \ in MMLogger

* Fix lint

* roll back pre-commit-hook

* Fix hook unit test

* Fix comments

* remove \t in log and add docstring

* Fix as comment

* should not accept other arguments if corresponding instance has been created

* fix logging ddp file saving

* fix logging ddp file saving

* move log processor to logging

* move log processor to logging

* remove current dataloader

* fix docstring

* fix unit test

* add learning rate in messagehub

* Support output training/validation/testing message after iterations/epochs

* fix docstring

* Fix IterBasedRunner log string

* Fix IterBasedRunner log string

* Support parse validation loss in log processor
2022-04-24 19:23:28 +08:00

71 lines
2.3 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock, Mock, patch
from mmengine.hooks import IterTimerHook
from mmengine.logging import MessageHub
def time_patch():
    """Stand-in for ``time.time`` that ticks forward by one per call.

    The first call returns 0 and every later call returns the previous value
    plus 1. The counter is stored as the ``time`` attribute of this function
    object, so it persists for the lifetime of the module.
    """
    now = getattr(time_patch, 'time', -1) + 1
    time_patch.time = now
    return now
class TestIterTimerHook(TestCase):
    """Unit tests for :class:`IterTimerHook`.

    The runner is a ``Mock``/``MagicMock`` stand-in in every test; only the
    attributes the assertions depend on are configured explicitly.
    TestCase assertion methods are used instead of bare ``assert`` so the
    checks survive ``python -O`` and report the compared values on failure.
    """

    def setUp(self) -> None:
        # A fresh hook per test so hook state never leaks between tests.
        self.hook = IterTimerHook()

    def test_init(self):
        """A new hook starts with zero accumulated time and start_iter 0."""
        self.assertEqual(self.hook.time_sec_tot, 0)
        self.assertEqual(self.hook.start_iter, 0)

    def test_before_run(self):
        """``before_run`` records the runner's current iteration."""
        runner = MagicMock()
        runner.iter = 1
        self.hook.before_run(runner)
        self.assertEqual(self.hook.start_iter, 1)

    def test_before_epoch(self):
        """``_before_epoch`` stores a float timestamp in ``hook.t``."""
        runner = Mock()
        self.hook._before_epoch(runner)
        self.assertIsInstance(self.hook.t, float)

    @patch('time.time', MagicMock(return_value=1))
    def test_before_iter(self):
        """``_before_iter`` reports ``{mode}/data_time`` for every mode.

        ``time.time`` is frozen at 1 for the whole test, so the expected
        data time is 0.
        """
        runner = MagicMock()
        runner.log_buffer = dict()
        self.hook._before_epoch(runner)
        for mode in ('train', 'val', 'test'):
            self.hook._before_iter(runner, batch_idx=1, mode=mode)
            # Checked inside the loop so each mode's key is verified, not
            # only the last one.
            runner.message_hub.update_scalar.assert_called_with(
                f'{mode}/data_time', 0)

    @patch('time.time', time_patch)
    def test_after_iter(self):
        """``_after_iter`` records iteration time and publishes ``eta``.

        ``time.time`` is replaced by ``time_patch``, which advances by one
        second on every call.
        """
        runner = MagicMock()
        runner.log_buffer = dict()
        runner.log_processor.window_size = 10
        runner.train_loop.max_iters = 100
        runner.iter = 0
        runner.test_loop.dataloader = [0] * 20
        runner.val_loop.dataloader = [0] * 20
        self.hook._before_epoch(runner)
        self.hook.before_run(runner)

        # batch_idx=1 with window_size=10: the iteration time must be
        # recorded, but no ETA is expected yet (presumably ETA is only
        # computed every ``window_size`` iterations -- confirm against
        # IterTimerHook).
        self.hook._after_iter(runner, batch_idx=1)
        runner.message_hub.update_scalar.assert_called()
        # MessageHub reads were renamed alongside ``update_scalar``
        # (``get_log`` -> ``get_scalar``). On a MagicMock, asserting that an
        # attribute nobody uses was not called always passes, so check the
        # current name too; the legacy name is kept for safety.
        runner.message_hub.get_scalar.assert_not_called()
        runner.message_hub.get_log.assert_not_called()
        runner.message_hub.update_info.assert_not_called()

        # Switch to a real MessageHub so the published ``eta`` can be read
        # back (``get_info`` on a MagicMock would just return another mock).
        runner.message_hub = MessageHub.get_instance('test_iter_timer_hook')
        runner.iter = 9
        # Expected ETAs equal (remaining iterations) * 1 s, since time_patch
        # advances one second per call -- confirm against IterTimerHook.
        # train: 100 - (9 + 1) = 90 remaining iterations.
        self.hook._after_iter(runner, batch_idx=89)
        self.assertEqual(runner.message_hub.get_info('eta'), 90)
        # val: 20 - (9 + 1) = 10 remaining iterations.
        self.hook._after_iter(runner, batch_idx=9, mode='val')
        self.assertEqual(runner.message_hub.get_info('eta'), 10)
        # test: 20 - (19 + 1) = 0 remaining iterations.
        self.hook._after_iter(runner, batch_idx=19, mode='test')
        self.assertEqual(runner.message_hub.get_info('eta'), 0)