# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest.mock import patch

from mmengine.hooks import IterTimerHook
from mmengine.testing import RunnerTestCase


class patched_time:
    """Deterministic stand-in patched over ``iter_timer_hook.time``.

    Every call to :meth:`time` hands back the current counter value and then
    advances the counter by one, so consecutive readings differ by exactly 1.
    """

    # Class-level tick counter shared by every user of this class. It is not
    # reset anywhere in this file, so the tests presumably depend only on
    # differences between readings rather than on absolute values.
    count = 0

    @classmethod
    def time(cls):
        """Return the current tick, then advance the counter by one."""
        tick = cls.count
        cls.count = tick + 1
        return tick


class TestIterTimerHook(RunnerTestCase):
    """Tests for ``IterTimerHook`` driven by the fake clock above."""

    @patch('mmengine.hooks.iter_timer_hook.time', patched_time)
    def test_before_iter(self):
        # For each mode, the last logged ``<mode>/data_time`` scalar must be
        # 1 after ``_before_epoch`` followed by ``_before_iter``.
        # NOTE(review): presumably each of the two hook calls reads the fake
        # clock once, giving a difference of one tick -- confirm in the hook.
        runner = self.build_runner(self.epoch_based_cfg)
        hook = self._get_iter_timer_hook(runner)
        for mode in ('train', 'val', 'test'):
            hook._before_epoch(runner)
            hook._before_iter(runner, batch_idx=1, mode=mode)
            scalar = runner.message_hub.get_scalar(f'{mode}/data_time')
            data_time_history = scalar._log_history
            self.assertEqual(list(data_time_history)[-1], 1)

    @patch('mmengine.hooks.iter_timer_hook.time', patched_time)
    def test_after_iter(self):
        # Checks the ``eta`` runtime info published during train, val and
        # test iterations of an iteration-based run capped at 100 iterations.
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.train_cfg.max_iters = 100
        runner = self.build_runner(cfg)
        hook = self._get_iter_timer_hook(runner)
        hook.before_run(runner)
        hook._before_epoch(runner)

        # Under patched_time, before_iter costs "1s" and after_iter costs
        # "1s", so every iteration takes 2s in total.
        for step in range(10):
            hook.before_train_iter(runner, step)
            hook.after_train_iter(runner, step)
            # The iteration counter is advanced by hand because no real
            # training loop is running here.
            runner.train_loop._iter += 1

        # 10 of the 100 iterations are done; 90 remain, so ETA is 90 * 2s.
        self.assertEqual(runner.message_hub.get_info('eta'), 180)

        hook.after_train_epoch(runner)

        # Validation and test go through the same sequence of hook calls and
        # the same expectations, so both are driven by one loop.
        for mode in ('val', 'test'):
            before_iter = getattr(hook, f'before_{mode}_iter')
            after_iter = getattr(hook, f'after_{mode}_iter')
            after_epoch = getattr(hook, f'after_{mode}_epoch')

            for batch_idx in range(2):
                before_iter(runner, batch_idx)
                after_iter(runner, batch_idx=batch_idx)
            # Expected ETA of 4 matches two remaining iterations at 2s each
            # (assumes a 4-batch dataloader -- TODO confirm).
            self.assertEqual(runner.message_hub.get_info('eta'), 4)

            for batch_idx in range(2, 4):
                before_iter(runner, batch_idx)
                after_iter(runner, batch_idx=batch_idx)
            after_epoch(runner)
            self.assertEqual(runner.message_hub.get_info('eta'), 0)

    def test_with_runner(self):
        # End-to-end check: run ``runner.train()`` under the fake clock and
        # inspect the logged ``train/time`` history and the final ``eta``.
        cfg = copy.deepcopy(self.epoch_based_cfg)
        runner = self.build_runner(cfg)
        # NOTE(review): this assignment happens after the runner was built
        # from ``cfg``; if ``build_runner`` copies the config it has no effect
        # on the runner -- confirm that validation is really disabled.
        cfg.train_cfg.val_interval = 1e6  # disable validation

        with patch('mmengine.hooks.iter_timer_hook.time', patched_time):
            runner.train()

        # 4 iterations per epoch over 2 epochs gives 8 records. Under
        # patched_time, before_iter costs "1s" and after_iter costs "1s", so
        # each recorded iteration time is 2s.
        message_hub = runner.message_hub
        train_time = message_hub.log_scalars['train/time']._log_history
        self.assertEqual(len(train_time), 8)
        self.assertListEqual(list(train_time), [2] * 8)
        eta = message_hub.runtime_info['eta']
        self.assertEqual(eta, 0)

    def _get_iter_timer_hook(self, runner):
        """Return the first ``IterTimerHook`` in ``runner.hooks``.

        Args:
            runner: Object exposing an iterable ``hooks`` attribute.

        Returns:
            The first hook that is an ``IterTimerHook`` instance, or ``None``
            when no such hook is registered.
        """
        timer_hooks = (
            candidate for candidate in runner.hooks
            if isinstance(candidate, IterTimerHook))
        return next(timer_hooks, None)