From 9a61b389e7c65f9d6b27bca7ed5b0971c02bbd65 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 29 Mar 2022 11:40:38 +0800 Subject: [PATCH] [Refactor] Add batch_idx to hook input. (#140) * [Refactor] Add batch_idx to hook input. * update --- mmengine/hooks/checkpoint_hook.py | 16 +++-- mmengine/hooks/empty_cache_hook.py | 2 + mmengine/hooks/hook.py | 68 +++++++++++++++---- mmengine/hooks/iter_timer_hook.py | 13 ++-- mmengine/hooks/logger_hook.py | 15 ++-- mmengine/hooks/naive_visualization_hook.py | 2 + mmengine/hooks/optimizer_hook.py | 4 +- mmengine/hooks/param_scheduler_hook.py | 4 +- mmengine/runner/loops.py | 34 ++++++---- mmengine/runner/runner.py | 6 -- tests/test_hook/test_checkpoint_hook.py | 7 +- tests/test_hook/test_empty_cache_hook.py | 10 +-- tests/test_hook/test_hook.py | 12 ++-- tests/test_hook/test_iter_timer_hook.py | 32 ++++----- tests/test_hook/test_logger_hook.py | 23 ++++--- .../test_naive_visualization_hook.py | 15 ++-- tests/test_hook/test_optimizer_hook.py | 6 +- tests/test_hook/test_param_scheduler_hook.py | 16 ++--- tests/test_hook/test_sync_buffers_hook.py | 8 +-- tests/test_runner/test_runner.py | 21 +++--- 20 files changed, 186 insertions(+), 128 deletions(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 17e19737..81e7f36d 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -20,31 +20,31 @@ class CheckpointHook(Hook): Args: interval (int): The saving period. If ``by_epoch=True``, interval indicates epochs, otherwise it indicates iterations. - Default: -1, which means "never". + Defaults to -1, which means "never". by_epoch (bool): Saving checkpoints by epoch or by iteration. Default: True. save_optimizer (bool): Whether to save optimizer state_dict in the checkpoint. It is usually used for resuming experiments. - Default: True. + Defaults to True. save_param_scheduler (bool): Whether to save param_scheduler state_dict in the checkpoint. It is usually used for resuming experiments. - Default: True. + Defaults to True. out_dir (str, optional | Path): The root directory to save checkpoints. If not specified, ``runner.work_dir`` will be used by default. If specified, the ``out_dir`` will be the concatenation of ``out_dir`` and the last level directory of ``runner.work_dir``. For example, if the input ``our_dir`` is ``./tmp`` and ``runner.work_dir`` is ``./work_dir/cur_exp``, then the ckpt will be saved in - ``./tmp/cur_exp``. Deafule to None. + ``./tmp/cur_exp``. Defaults to None. max_keep_ckpts (int): The maximum checkpoints to keep. In some cases we want only the latest few checkpoints and would like to delete old ones to save the disk space. - Default: -1, which means unlimited. + Defaults to -1, which means unlimited. save_last (bool): Whether to force the last checkpoint to be - saved regardless of interval. Default: True. + saved regardless of interval. Defaults to True. file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmcv.fileio.FileClient` for details. - Default: None. + Defaults to None. """ out_dir: str @@ -177,12 +177,14 @@ class CheckpointHook(Hook): def after_train_iter(self, runner, + batch_idx: int, data_batch: DATA_BATCH = None, outputs=Optional[dict]) -> None: """Save the checkpoint and synchronize buffers after each iteration. Args: runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py index cb20a744..45a1df11 100644 --- a/mmengine/hooks/empty_cache_hook.py +++ b/mmengine/hooks/empty_cache_hook.py @@ -36,6 +36,7 @@ class EmptyCacheHook(Hook): def _after_iter(self, runner, + batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[Union[dict, Sequence[BaseDataSample]]] = None, @@ -44,6 +45,7 @@ class EmptyCacheHook(Hook): Args: runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (dict or sequence, optional): Outputs from model. diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 7eeb5e54..bafe5f31 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -164,41 +164,57 @@ class Hook: """ self._after_epoch(runner, mode='test') - def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None: + def before_train_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any operations before each training iteration. Args: runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ - self._before_iter(runner, data_batch=data_batch, mode='train') + self._before_iter( + runner, batch_idx=batch_idx, data_batch=data_batch, mode='train') - def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None: + def before_val_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any operations before each validation iteration. Args: runner (Runner): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ - self._before_iter(runner, data_batch=data_batch, mode='val') + self._before_iter( + runner, batch_idx=batch_idx, data_batch=data_batch, mode='val') - def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None: + def before_test_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any operations before each test iteration. Args: runner (Runner): The runner of the testing process. + batch_idx (int): The index of the current batch in the test loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ - self._before_iter(runner, data_batch=data_batch, mode='test') + self._before_iter( + runner, batch_idx=batch_idx, data_batch=data_batch, mode='test') def after_train_iter(self, runner, + batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None: """All subclasses should override this method, if they need any @@ -206,16 +222,22 @@ class Hook: Args: runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ self._after_iter( - runner, data_batch=data_batch, outputs=outputs, mode='train') + runner, + batch_idx=batch_idx, + data_batch=data_batch, + outputs=outputs, + mode='train') def after_val_iter(self, runner, + batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[Sequence[BaseDataSample]] = None) \ -> None: @@ -224,17 +246,23 @@ class Hook: Args: runner (Runner): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. """ self._after_iter( - runner, data_batch=data_batch, outputs=outputs, mode='val') + runner, + batch_idx=batch_idx, + data_batch=data_batch, + outputs=outputs, + mode='val') def after_test_iter( self, runner, + batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """All subclasses should override this method, if they need any @@ -242,13 +270,18 @@ class Hook: Args: runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the test loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ self._after_iter( - runner, data_batch=data_batch, outputs=outputs, mode='test') + runner, + batch_idx=batch_idx, + data_batch=data_batch, + outputs=outputs, + mode='test') def _before_epoch(self, runner, mode: str = 'train') -> None: """All subclasses should override this method, if they need any @@ -274,6 +307,7 @@ class Hook: def _before_iter(self, runner, + batch_idx: int, data_batch: DATA_BATCH = None, mode: str = 'train') -> None: """All subclasses should override this method, if they need any @@ -282,6 +316,7 @@ class Hook: Args: runner (Runner): The runner of the training, validation or testing process. + batch_idx (int): The index of the current batch in the loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. @@ -290,6 +325,7 @@ class Hook: def _after_iter(self, runner, + batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[Union[Sequence[BaseDataSample], dict]] = None, @@ -300,6 +336,7 @@ class Hook: Args: runner (Runner): The runner of the training, validation or testing process. + batch_idx (int): The index of the current batch in the loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (Sequence[BaseDataSample], optional): Outputs from model. @@ -321,12 +358,12 @@ class Hook: """ return (runner.epoch + 1) % n == 0 if n > 0 else False - def every_n_inner_iters(self, runner, n: int) -> bool: + def every_n_inner_iters(self, inner_iter: int, n: int) -> bool: """Test whether current inner iteration can be evenly divided by n. Args: - runner (Runner): The runner of the training, validation or testing - process. + inner_iter (int): Current inner_iter of the training, validation + or testing loop. n (int): Whether current inner iteration can be evenly divided by n. @@ -334,7 +371,7 @@ class Hook: bool: Whether current inner iteration can be evenly divided by n. """ - return (runner.inner_iter + 1) % n == 0 if n > 0 else False + return (inner_iter + 1) % n == 0 if n > 0 else False def every_n_iters(self, runner, n: int) -> bool: """Test whether current iteration can be evenly divided by n. @@ -350,18 +387,19 @@ class Hook: """ return (runner.iter + 1) % n == 0 if n > 0 else False - def end_of_epoch(self, runner) -> bool: + def end_of_epoch(self, runner, batch_idx: int) -> bool: """Check whether the current iteration reaches the last iteration of current dataloader. Args: runner (Runner): The runner of the training, validation or testing process. + batch_idx (int): The index of the current batch in the loop. Returns: bool: Whether reaches the end of current epoch or not. """ - return runner.inner_iter + 1 == len(runner.cur_dataloader) + return batch_idx + 1 == len(runner.cur_dataloader) def is_last_train_epoch(self, runner) -> bool: """Test whether current epoch is the last train epoch. diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index 824774ed..a18231bd 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -13,7 +13,7 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] class IterTimerHook(Hook): """A hook that logs the time spent during iteration. - Eg. ``data_time`` for loading data and ``time`` for a model train step. + E.g. ``data_time`` for loading data and ``time`` for a model train step. """ priority = 'NORMAL' @@ -29,12 +29,14 @@ class IterTimerHook(Hook): def _before_iter(self, runner, + batch_idx: int, data_batch: DATA_BATCH = None, mode: str = 'train') -> None: """Logging time for loading data and update the time flag. Args: runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. @@ -45,15 +47,16 @@ class IterTimerHook(Hook): def _after_iter(self, runner, + batch_idx: int, data_batch: DATA_BATCH = None, - outputs: - Optional[Union[dict, Sequence[BaseDataSample]]] = None, - mode: str = 'train') \ - -> None: + outputs: Optional[Union[dict, + Sequence[BaseDataSample]]] = None, + mode: str = 'train') -> None: """Logging time for a iteration and update the time flag. Args: runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index 64c7868e..492716c5 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -121,6 +121,7 @@ class LoggerHook(Hook): keep_local=True, file_client_args=None, ): + self._inner_iter = 0 self.by_epoch = by_epoch self.interval = interval self.custom_keys = custom_keys if custom_keys is not None else dict() @@ -174,27 +175,31 @@ class LoggerHook(Hook): def after_train_iter(self, runner, + batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None: """Record training logs. Args: runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. data_batch (Sequence[BaseDataSample], optional): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ + self._inner_iter = batch_idx if runner.meta is not None and 'exp_name' in runner.meta: if (self.every_n_iters(runner, self.interval_exp_name)) or ( - self.by_epoch and self.end_of_epoch(runner)): + self.by_epoch and self.end_of_epoch(runner, batch_idx)): exp_info = f'Exp name: {runner.meta["exp_name"]}' runner.logger.info(exp_info) - if self.by_epoch and self.every_n_inner_iters(runner, self.interval): + if self.by_epoch and self.every_n_inner_iters(batch_idx, + self.interval): self._log_train(runner) elif not self.by_epoch and self.every_n_iters(runner, self.interval): self._log_train(runner) - elif self.end_of_epoch(runner) and not self.ignore_last: + elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last: # `runner.max_iters` may not be divisible by `self.interval`. if # `self.ignore_last==True`, the log of remaining iterations will # be recorded (Epoch [4][1000/1007], the logs of 998-1007 @@ -346,7 +351,7 @@ class LoggerHook(Hook): 'The value of windows size must equal to LoggerHook.interval' return window_size elif window_size == 'epoch': - return runner.inner_iter + 1 + return self._inner_iter + 1 elif window_size == 'global': return runner.iter + 1 else: @@ -505,7 +510,7 @@ class LoggerHook(Hook): int: The current global iter or inner iter. """ if self.by_epoch and inner_iter: - current_iter = runner.inner_iter + 1 + current_iter = self._inner_iter + 1 else: current_iter = runner.iter + 1 return current_iter diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py index 434de95f..a3a3b092 100644 --- a/mmengine/hooks/naive_visualization_hook.py +++ b/mmengine/hooks/naive_visualization_hook.py @@ -40,12 +40,14 @@ class NaiveVisualizationHook(Hook): def after_test_iter( self, runner, + batch_idx: int, data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """Show or Write the predicted results. Args: runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the test loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (Sequence[BaseDataSample], optional): Outputs from model. diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py index 418bf06d..03b7d606 100644 --- a/mmengine/hooks/optimizer_hook.py +++ b/mmengine/hooks/optimizer_hook.py @@ -58,6 +58,7 @@ class OptimizerHook(Hook): def after_train_iter(self, runner, + batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None: """All operations need to be finished after each training iteration. @@ -69,12 +70,13 @@ class OptimizerHook(Hook): - Compute the gradient of model parameters. - - Clip the gradidents of each parameters. (optional) + - Clip the gradients of each parameter. (optional) - Update model parameters with gradients. Args: runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index cb2e5fbd..f3531cd2 100644 --- a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -10,19 +10,21 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] @HOOKS.register_module() class ParamSchedulerHook(Hook): - """A hook to update some hyper-parameters in optimizer, e.g learning rate + """A hook to update some hyper-parameters in optimizer, e.g., learning rate and momentum.""" priority = 'LOW' def after_train_iter(self, runner, + batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None: """Call step function for each scheduler after each iteration. Args: runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 6403a3a4..b725dfcc 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -70,9 +70,8 @@ class EpochBasedTrainLoop(BaseLoop): data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data from dataloader. """ - self.runner._inner_iter = idx - - self.runner.call_hook('before_train_iter', data_batch=data_batch) + self.runner.call_hook( + 'before_train_iter', batch_idx=idx, data_batch=data_batch) # outputs should be a dict containing one or multiple loss tensors self.runner.outputs = self.runner.model(data_batch, return_loss=True) @@ -82,6 +81,7 @@ class EpochBasedTrainLoop(BaseLoop): self.runner.call_hook( 'after_train_iter', + batch_idx=idx, data_batch=data_batch, outputs=self.runner.outputs) @@ -126,9 +126,6 @@ class IterBasedTrainLoop(BaseLoop): if (self.runner.val_loop is not None and self.runner._iter % self.runner.val_loop.interval == 0): self.runner.val_loop.run() - # reset inner_iter to 0 to ensure it counts as expected in - # train loop - self.runner._inner_iter = 0 self.runner.call_hook('after_train_epoch') self.runner.call_hook('after_train') @@ -141,7 +138,10 @@ class IterBasedTrainLoop(BaseLoop): data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data from dataloader. """ - self.runner.call_hook('before_train_iter', data_batch=data_batch) + self.runner.call_hook( + 'before_train_iter', + batch_idx=self.runner._iter, + data_batch=data_batch) # outputs should be a dict containing loss tensor self.runner.outputs = self.runner.model(data_batch, return_loss=True) @@ -151,10 +151,10 @@ class IterBasedTrainLoop(BaseLoop): self.runner.call_hook( 'after_train_iter', + batch_idx=self.runner._iter, data_batch=data_batch, outputs=self.runner.outputs) self.runner._iter += 1 - self.runner._inner_iter += 1 @LOOPS.register_module() @@ -208,13 +208,16 @@ class ValLoop(BaseLoop): data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data from dataloader. """ - self.runner._inner_iter = idx - self.runner.call_hook('before_val_iter', data_batch=data_batch) + self.runner.call_hook( + 'before_val_iter', batch_idx=idx, data_batch=data_batch) # outputs should be sequence of BaseDataSample outputs = self.runner.model(data_batch) self.evaluator.process(data_batch, outputs) self.runner.call_hook( - 'after_val_iter', data_batch=data_batch, outputs=outputs) + 'after_val_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) @LOOPS.register_module() @@ -263,10 +266,13 @@ class TestLoop(BaseLoop): data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data from dataloader. """ - self.runner._inner_iter = idx - self.runner.call_hook('before_test_iter', data_batch=data_batch) + self.runner.call_hook( + 'before_test_iter', batch_idx=idx, data_batch=data_batch) # predictions should be sequence of BaseDataSample predictions = self.runner.model(data_batch) self.evaluator.process(data_batch, predictions) self.runner.call_hook( - 'after_test_iter', data_batch=data_batch, outputs=predictions) + 'after_test_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=predictions) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 6d458c02..88fe08ff 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -229,7 +229,6 @@ class Runner: self._epoch = 0 self._iter = 0 - self._inner_iter = 0 # lazy initialization training_related = [ @@ -400,11 +399,6 @@ class Runner: """int: Current epoch.""" return self._iter - @property - def inner_iter(self): - """int: Current iteration.""" - return self._inner_iter - @property def launcher(self): """str: Way to launcher multi processes.""" diff --git a/tests/test_hook/test_checkpoint_hook.py b/tests/test_hook/test_checkpoint_hook.py index 3cf661c0..7fabecd5 100644 --- a/tests/test_hook/test_checkpoint_hook.py +++ b/tests/test_hook/test_checkpoint_hook.py @@ -100,19 +100,20 @@ class TestCheckpointHook: runner = Mock() runner.work_dir = './tmp' runner.iter = 9 + batch_idx = 9 runner.meta = dict() runner.model = Mock() # by epoch is True checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) checkpoint_hook.before_run(runner) - checkpoint_hook.after_train_iter(runner) + checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) assert runner.meta.get('hook_msgs', None) is None # by epoch is False checkpoint_hook = CheckpointHook(interval=2, by_epoch=False) checkpoint_hook.before_run(runner) - checkpoint_hook.after_train_iter(runner) + checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) assert (runner.iter + 1) % 2 == 0 assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth' @@ -129,5 +130,5 @@ class TestCheckpointHook: checkpoint_hook = CheckpointHook( interval=2, by_epoch=False, max_keep_ckpts=1) checkpoint_hook.before_run(runner) - checkpoint_hook.after_train_iter(runner) + checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) assert not os.path.exists(f'{tempo_dir}/iter_8.pth') diff --git a/tests/test_hook/test_empty_cache_hook.py b/tests/test_hook/test_empty_cache_hook.py index dc8ce8fc..e909fc5d 100644 --- a/tests/test_hook/test_empty_cache_hook.py +++ b/tests/test_hook/test_empty_cache_hook.py @@ -7,8 +7,8 @@ from mmengine.hooks import EmptyCacheHook class TestEmptyCacheHook: def test_emtpy_cache_hook(self): - Hook = EmptyCacheHook(True, True, True) - Runner = Mock() - Hook._after_iter(Runner) - Hook._before_epoch(Runner) - Hook._after_epoch(Runner) + hook = EmptyCacheHook(True, True, True) + runner = Mock() + hook._after_iter(runner, 0) + hook._before_epoch(runner) + hook._after_epoch(runner) diff --git a/tests/test_hook/test_hook.py b/tests/test_hook/test_hook.py index d55f8a66..db80ed4a 100644 --- a/tests/test_hook/test_hook.py +++ b/tests/test_hook/test_hook.py @@ -136,11 +136,9 @@ class TestHook: def test_every_n_inner_iters(self): hook = Hook() - runner = Mock() for i in range(100): - runner.inner_iter = i - return_val = hook.every_n_inner_iters(runner, 3) + return_val = hook.every_n_inner_iters(i, 3) if (i + 1) % 3 == 0: assert return_val else: @@ -162,15 +160,15 @@ class TestHook: runner = Mock() # last inner iter - runner.inner_iter = 1 + batch_idx = 1 runner.cur_dataloader.__len__ = Mock(return_value=2) runner.cur_dataloader.__len__ = Mock(return_value=2) - return_val = hook.end_of_epoch(runner) + return_val = hook.end_of_epoch(runner, batch_idx) assert return_val # not the last inner iter - runner.inner_iter = 0 - return_val = hook.end_of_epoch(runner) + batch_idx = 0 + return_val = hook.end_of_epoch(runner, batch_idx) assert not return_val def test_is_last_train_epoch(self): diff --git a/tests/test_hook/test_iter_timer_hook.py b/tests/test_hook/test_iter_timer_hook.py index 8b20d4f1..44e09bc3 100644 --- a/tests/test_hook/test_iter_timer_hook.py +++ b/tests/test_hook/test_iter_timer_hook.py @@ -7,23 +7,23 @@ from mmengine.hooks import IterTimerHook class TestIterTimerHook: def test_before_epoch(self): - Hook = IterTimerHook() - Runner = Mock() - Hook._before_epoch(Runner) - assert isinstance(Hook.t, float) + hook = IterTimerHook() + runner = Mock() + hook._before_epoch(runner) + assert isinstance(hook.t, float) def test_before_iter(self): - Hook = IterTimerHook() - Runner = Mock() - Runner.log_buffer = dict() - Hook._before_epoch(Runner) - Hook._before_iter(Runner) - Runner.message_hub.update_log.assert_called() + hook = IterTimerHook() + runner = Mock() + runner.log_buffer = dict() + hook._before_epoch(runner) + hook._before_iter(runner, 0) + runner.message_hub.update_log.assert_called() def test_after_iter(self): - Hook = IterTimerHook() - Runner = Mock() - Runner.log_buffer = dict() - Hook._before_epoch(Runner) - Hook._after_iter(Runner) - Runner.message_hub.update_log.assert_called() + hook = IterTimerHook() + runner = Mock() + runner.log_buffer = dict() + hook._before_epoch(runner) + hook._after_iter(runner, 0) + runner.message_hub.update_log.assert_called() diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index d92ac9bc..6fac1a93 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -85,14 +85,15 @@ class TestLoggerHook: # Test LoggerHook by iter. runner = MagicMock() runner.iter = 10 + batch_idx = 5 logger_hook = LoggerHook(by_epoch=False) logger_hook._log_train = MagicMock() - logger_hook.after_train_iter(runner) + logger_hook.after_train_iter(runner, batch_idx=batch_idx) # `cur_iter=10+1`, which cannot be exact division by # `logger_hook.interval` logger_hook._log_train.assert_not_called() runner.iter = 9 - logger_hook.after_train_iter(runner) + logger_hook.after_train_iter(runner, batch_idx=batch_idx) logger_hook._log_train.assert_called() # Test LoggerHook by epoch. @@ -100,19 +101,19 @@ class TestLoggerHook: logger_hook._log_train = MagicMock() # Only `runner.inner_iter` will work. runner.iter = 9 - runner.inner_iter = 10 - logger_hook.after_train_iter(runner) + batch_idx = 10 + logger_hook.after_train_iter(runner, batch_idx=batch_idx) logger_hook._log_train.assert_not_called() - runner.inner_iter = 9 - logger_hook.after_train_iter(runner) + batch_idx = 9 + logger_hook.after_train_iter(runner, batch_idx=batch_idx) logger_hook._log_train.assert_called() # Test end of the epoch. logger_hook = LoggerHook(by_epoch=True, ignore_last=False) logger_hook._log_train = MagicMock() runner.cur_dataloader = [0] * 5 - runner.inner_iter = 4 - logger_hook.after_train_iter(runner) + batch_idx = 4 + logger_hook.after_train_iter(runner, batch_idx=batch_idx) logger_hook._log_train.assert_called() # Test print exp_name @@ -120,7 +121,7 @@ class TestLoggerHook: logger_hook = LoggerHook() runner.logger = MagicMock() logger_hook._log_train = MagicMock() - logger_hook.after_train_iter(runner) + logger_hook.after_train_iter(runner, batch_idx=batch_idx) runner.logger.info.assert_called_with( f'Exp name: {runner.meta["exp_name"]}') @@ -137,6 +138,7 @@ class TestLoggerHook: runner.meta = dict(exp_name='retinanet') # Prepare LoggerHook logger_hook = LoggerHook(by_epoch=by_epoch) + logger_hook._inner_iter = 1 logger_hook.writer = MagicMock() logger_hook.time_sec_tot = 1000 logger_hook.start_iter = 0 @@ -220,6 +222,7 @@ class TestLoggerHook: def test_get_window_size(self): runner = self._setup_runner() logger_hook = LoggerHook() + logger_hook._inner_iter = 1 # Test get window size by name. assert logger_hook._get_window_size(runner, 'epoch') == 2 assert logger_hook._get_window_size(runner, 'global') == 11 @@ -313,6 +316,7 @@ class TestLoggerHook: def test_get_iter(self): runner = self._setup_runner() logger_hook = LoggerHook() + logger_hook._inner_iter = 1 # Get global iter when `inner_iter=False` iter = logger_hook._get_iter(runner) assert iter == 11 @@ -338,7 +342,6 @@ class TestLoggerHook: runner = MagicMock() runner.epoch = 1 runner.cur_dataloader = [0] * 5 - runner.inner_iter = 1 runner.iter = 10 runner.train_loop.max_iters = 50 logger = logging.getLogger() diff --git a/tests/test_hook/test_naive_visualization_hook.py b/tests/test_hook/test_naive_visualization_hook.py index beb053a4..81977d44 100644 --- a/tests/test_hook/test_naive_visualization_hook.py +++ b/tests/test_hook/test_naive_visualization_hook.py @@ -11,9 +11,10 @@ class TestNaiveVisualizationHook: def test_after_train_iter(self): naive_visualization_hook = NaiveVisualizationHook() - Runner = Mock(iter=1) - Runner.writer.add_image = Mock() + runner = Mock(iter=1) + runner.writer.add_image = Mock() inputs = torch.randn(1, 3, 15, 15) + batch_idx = 10 # test with normalize, resize, pad gt_datasamples = [ BaseDataSample( @@ -28,7 +29,7 @@ class TestNaiveVisualizationHook: ] pred_datasamples = [BaseDataSample()] data_batch = (inputs, gt_datasamples) - naive_visualization_hook.after_test_iter(Runner, data_batch, + naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with resize, pad gt_datasamples = [ @@ -42,7 +43,7 @@ class TestNaiveVisualizationHook: ] pred_datasamples = [BaseDataSample()] data_batch = (inputs, gt_datasamples) - naive_visualization_hook.after_test_iter(Runner, data_batch, + naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only resize gt_datasamples = [ @@ -55,7 +56,7 @@ class TestNaiveVisualizationHook: ] pred_datasamples = [BaseDataSample()] data_batch = (inputs, gt_datasamples) - naive_visualization_hook.after_test_iter(Runner, data_batch, + naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only pad @@ -69,7 +70,7 @@ class TestNaiveVisualizationHook: ] pred_datasamples = [BaseDataSample()] data_batch = (inputs, gt_datasamples) - naive_visualization_hook.after_test_iter(Runner, data_batch, + naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test no transform @@ -80,5 +81,5 @@ class TestNaiveVisualizationHook: ] pred_datasamples = [BaseDataSample()] data_batch = (inputs, gt_datasamples) - naive_visualization_hook.after_test_iter(Runner, data_batch, + naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) diff --git a/tests/test_hook/test_optimizer_hook.py b/tests/test_hook/test_optimizer_hook.py index 8ab12814..5d04ca3f 100644 --- a/tests/test_hook/test_optimizer_hook.py +++ b/tests/test_hook/test_optimizer_hook.py @@ -73,7 +73,7 @@ class TestOptimizerHook: wraps=optimizer_hook.detect_anomalous_parameters) optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads) - optimizer_hook.after_train_iter(dummy_runner) + optimizer_hook.after_train_iter(dummy_runner, 0) # assert the parameters of conv2 and conv3 are not in the # computational graph which is with x1.sum() as root. assert 'conv2.weight' in dummy_runner.logger.msg @@ -89,7 +89,7 @@ class TestOptimizerHook: dummy_runner.outputs['loss'] = model(x)[1].sum() dummy_runner.logger.msg = '' - optimizer_hook.after_train_iter(dummy_runner) + optimizer_hook.after_train_iter(dummy_runner, 0) # assert the parameters of conv3 are not in the computational graph assert 'conv3.weight' in dummy_runner.logger.msg assert 'conv3.bias' in dummy_runner.logger.msg @@ -107,7 +107,7 @@ class TestOptimizerHook: dummy_runner.outputs['loss'].backward = Mock( wraps=dummy_runner.outputs['loss'].backward) - optimizer_hook.after_train_iter(dummy_runner) + optimizer_hook.after_train_iter(dummy_runner, 0) dummy_runner.optimizer.step.assert_called() dummy_runner.outputs['loss'].backward.assert_called() diff --git a/tests/test_hook/test_param_scheduler_hook.py b/tests/test_hook/test_param_scheduler_hook.py index 8b569d51..3c047962 100644 --- a/tests/test_hook/test_param_scheduler_hook.py +++ b/tests/test_hook/test_param_scheduler_hook.py @@ -7,21 +7,21 @@ from mmengine.hooks import ParamSchedulerHook class TestParamSchedulerHook: def test_after_iter(self): - Hook = ParamSchedulerHook() - Runner = Mock() + hook = ParamSchedulerHook() + runner = Mock() scheduler = Mock() scheduler.step = Mock() scheduler.by_epoch = False - Runner.param_schedulers = [scheduler] - Hook.after_train_iter(Runner) + runner.param_schedulers = [scheduler] + hook.after_train_iter(runner, 0) scheduler.step.assert_called() def test_after_epoch(self): - Hook = ParamSchedulerHook() - Runner = Mock() + hook = ParamSchedulerHook() + runner = Mock() scheduler = Mock() scheduler.step = Mock() scheduler.by_epoch = True - Runner.param_schedulers = [scheduler] - Hook.after_train_epoch(Runner) + runner.param_schedulers = [scheduler] + hook.after_train_epoch(runner) scheduler.step.assert_called() diff --git a/tests/test_hook/test_sync_buffers_hook.py b/tests/test_hook/test_sync_buffers_hook.py index 1c0b6295..c7c64287 100644 --- a/tests/test_hook/test_sync_buffers_hook.py +++ b/tests/test_hook/test_sync_buffers_hook.py @@ -7,7 +7,7 @@ from mmengine.hooks import SyncBuffersHook class TestSyncBuffersHook: def test_sync_buffers_hook(self): - Runner = Mock() - Runner.model = Mock() - Hook = SyncBuffersHook() - Hook._after_epoch(Runner) + runner = Mock() + runner.model = Mock() + hook = SyncBuffersHook() + hook._after_epoch(runner) diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 020c2dfd..77b341bd 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -720,8 +720,8 @@ class TestRunner(TestCase): epoch_targets = [i for i in range(3)] iter_results = [] iter_targets = [i for i in range(4 * 3)] - inner_iter_results = [] - inner_iter_targets = [i for i in range(4)] * 3 # train and val + batch_idx_results = [] + batch_idx_targets = [i for i in range(4)] * 3 # train and val @HOOKS.register_module() class TestEpochHook(Hook): @@ -729,9 +729,9 @@ class TestRunner(TestCase): def before_train_epoch(self, runner): epoch_results.append(runner.epoch) - def before_train_iter(self, runner, data_batch=None): + def before_train_iter(self, runner, batch_idx, data_batch=None): iter_results.append(runner.iter) - inner_iter_results.append(runner.inner_iter) + batch_idx_results.append(batch_idx) self.epoch_based_cfg.custom_hooks = [ dict(type='TestEpochHook', priority=50) @@ -746,7 +746,7 @@ class TestRunner(TestCase): self.assertEqual(result, target) for result, target, in zip(iter_results, iter_targets): self.assertEqual(result, target) - for result, target, in zip(inner_iter_results, inner_iter_targets): + for result, target, in zip(batch_idx_results, batch_idx_targets): self.assertEqual(result, target) time.sleep(1) @@ -754,9 +754,9 @@ class TestRunner(TestCase): # 3. test iter and epoch counter of IterBasedTrainLoop epoch_results = [] iter_results = [] - inner_iter_results = [] + batch_idx_results = [] iter_targets = [i for i in range(12)] - inner_iter_targets = [0, 1, 2, 3] * 3 + batch_idx_targets = [i for i in range(12)] @HOOKS.register_module() class TestIterHook(Hook): @@ -764,9 +764,9 @@ class TestRunner(TestCase): def before_train_epoch(self, runner): epoch_results.append(runner.epoch) - def before_train_iter(self, runner, data_batch=None): + def before_train_iter(self, runner, batch_idx, data_batch=None): iter_results.append(runner.iter) - inner_iter_results.append(runner.inner_iter) + batch_idx_results.append(batch_idx) self.iter_based_cfg.custom_hooks = [ dict(type='TestIterHook', priority=50) @@ -781,7 +781,7 @@ class TestRunner(TestCase): self.assertEqual(epoch_results[0], 0) for result, target, in zip(iter_results, iter_targets): self.assertEqual(result, target) - for result, target, in zip(inner_iter_results, inner_iter_targets): + for result, target, in zip(batch_idx_results, batch_idx_targets): self.assertEqual(result, target) def test_val(self): @@ -1056,7 +1056,6 @@ class TestRunner(TestCase): runner.resume(path) self.assertEqual(runner.epoch, 0) self.assertEqual(runner.iter, 12) - self.assertEqual(runner.inner_iter, 0) self.assertTrue(runner._has_loaded) self.assertIsInstance(runner.optimizer, SGD) self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)