[Refactor] Add batch_idx to hook input. (#140)

* [Refactor] Add batch_idx to hook input.

* update
RangiLyu 2022-03-29 11:40:38 +08:00 committed by GitHub
parent 563b4bad16
commit 9a61b389e7
20 changed files with 186 additions and 128 deletions
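
The change is mechanical but touches every per-iteration hook: instead of reading ``runner.inner_iter``, hooks now receive the batch index as an explicit ``batch_idx`` argument. As a quick orientation before the per-file diffs, here is a minimal sketch of a custom hook written against the new signatures. The hook class and the logged message are illustrative only (assuming ``Hook`` and ``HOOKS`` are importable as in the unit tests further down); the method signature and the ``every_n_inner_iters(batch_idx, n)`` helper follow the diffs below.

from mmengine.hooks import Hook
from mmengine.registry import HOOKS


@HOOKS.register_module()
class PrintBatchIdxHook(Hook):
    """Toy hook: log the batch index every ``interval`` train iterations."""

    def __init__(self, interval: int = 10):
        self.interval = interval

    def after_train_iter(self, runner, batch_idx, data_batch=None,
                         outputs=None):
        # batch_idx is forwarded by the train loop; runner.inner_iter is gone.
        if self.every_n_inner_iters(batch_idx, self.interval):
            runner.logger.info(f'processed batch {batch_idx}')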

View File

@@ -20,31 +20,31 @@ class CheckpointHook(Hook):
     Args:
         interval (int): The saving period. If ``by_epoch=True``, interval
             indicates epochs, otherwise it indicates iterations.
-            Default: -1, which means "never".
+            Defaults to -1, which means "never".
         by_epoch (bool): Saving checkpoints by epoch or by iteration.
             Default: True.
         save_optimizer (bool): Whether to save optimizer state_dict in the
             checkpoint. It is usually used for resuming experiments.
-            Default: True.
+            Defaults to True.
         save_param_scheduler (bool): Whether to save param_scheduler state_dict
             in the checkpoint. It is usually used for resuming experiments.
-            Default: True.
+            Defaults to True.
         out_dir (str, optional | Path): The root directory to save checkpoints.
             If not specified, ``runner.work_dir`` will be used by default. If
             specified, the ``out_dir`` will be the concatenation of ``out_dir``
             and the last level directory of ``runner.work_dir``. For example,
             if the input ``our_dir`` is ``./tmp`` and ``runner.work_dir`` is
             ``./work_dir/cur_exp``, then the ckpt will be saved in
-            ``./tmp/cur_exp``. Deafule to None.
+            ``./tmp/cur_exp``. Defaults to None.
         max_keep_ckpts (int): The maximum checkpoints to keep.
             In some cases we want only the latest few checkpoints and would
             like to delete old ones to save the disk space.
-            Default: -1, which means unlimited.
+            Defaults to -1, which means unlimited.
         save_last (bool): Whether to force the last checkpoint to be
-            saved regardless of interval. Default: True.
+            saved regardless of interval. Defaults to True.
         file_client_args (dict, optional): Arguments to instantiate a
             FileClient. See :class:`mmcv.fileio.FileClient` for details.
-            Default: None.
+            Defaults to None.
     """
     out_dir: str
@@ -177,12 +177,14 @@ class CheckpointHook(Hook):
     def after_train_iter(self,
                          runner,
+                         batch_idx: int,
                          data_batch: DATA_BATCH = None,
                          outputs=Optional[dict]) -> None:
         """Save the checkpoint and synchronize buffers after each iteration.
         Args:
             runner (Runner): The runner of the training process.
+            batch_idx (int): The index of the current batch in the train loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                 from dataloader. Defaults to None.
             outputs (dict, optional): Outputs from model.

View File

@@ -36,6 +36,7 @@ class EmptyCacheHook(Hook):
     def _after_iter(self,
                     runner,
+                    batch_idx: int,
                     data_batch: DATA_BATCH = None,
                     outputs: Optional[Union[dict,
                                             Sequence[BaseDataSample]]] = None,
@@ -44,6 +45,7 @@ class EmptyCacheHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
+            batch_idx (int): The index of the current batch in the loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                 from dataloader. Defaults to None.
             outputs (dict or sequence, optional): Outputs from model.

View File

@@ -164,41 +164,57 @@ class Hook:
         """
         self._after_epoch(runner, mode='test')
-    def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
+    def before_train_iter(self,
+                          runner,
+                          batch_idx: int,
+                          data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each training iteration.
         Args:
             runner (Runner): The runner of the training process.
+            batch_idx (int): The index of the current batch in the train loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
         """
-        self._before_iter(runner, data_batch=data_batch, mode='train')
+        self._before_iter(
+            runner, batch_idx=batch_idx, data_batch=data_batch, mode='train')
-    def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
+    def before_val_iter(self,
+                        runner,
+                        batch_idx: int,
+                        data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each validation iteration.
         Args:
             runner (Runner): The runner of the validation process.
+            batch_idx (int): The index of the current batch in the val loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
         """
-        self._before_iter(runner, data_batch=data_batch, mode='val')
+        self._before_iter(
+            runner, batch_idx=batch_idx, data_batch=data_batch, mode='val')
-    def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
+    def before_test_iter(self,
+                         runner,
+                         batch_idx: int,
+                         data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each test iteration.
         Args:
             runner (Runner): The runner of the testing process.
+            batch_idx (int): The index of the current batch in the test loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
         """
-        self._before_iter(runner, data_batch=data_batch, mode='test')
+        self._before_iter(
+            runner, batch_idx=batch_idx, data_batch=data_batch, mode='test')
     def after_train_iter(self,
                          runner,
+                         batch_idx: int,
                          data_batch: DATA_BATCH = None,
                          outputs: Optional[dict] = None) -> None:
         """All subclasses should override this method, if they need any
@@ -206,16 +222,22 @@ class Hook:
         Args:
             runner (Runner): The runner of the training process.
+            batch_idx (int): The index of the current batch in the train loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
             outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
         self._after_iter(
-            runner, data_batch=data_batch, outputs=outputs, mode='train')
+            runner,
+            batch_idx=batch_idx,
+            data_batch=data_batch,
+            outputs=outputs,
+            mode='train')
     def after_val_iter(self,
                        runner,
+                       batch_idx: int,
                        data_batch: DATA_BATCH = None,
                        outputs: Optional[Sequence[BaseDataSample]] = None) \
             -> None:
@@ -224,17 +246,23 @@ class Hook:
         Args:
             runner (Runner): The runner of the validation process.
+            batch_idx (int): The index of the current batch in the val loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
             outputs (dict or sequence, optional): Outputs from
                 model. Defaults to None.
         """
         self._after_iter(
-            runner, data_batch=data_batch, outputs=outputs, mode='val')
+            runner,
+            batch_idx=batch_idx,
+            data_batch=data_batch,
+            outputs=outputs,
+            mode='val')
     def after_test_iter(
             self,
             runner,
+            batch_idx: int,
             data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
@@ -242,13 +270,18 @@ class Hook:
         Args:
             runner (Runner): The runner of the training process.
+            batch_idx (int): The index of the current batch in the test loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
             outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
         self._after_iter(
-            runner, data_batch=data_batch, outputs=outputs, mode='test')
+            runner,
+            batch_idx=batch_idx,
+            data_batch=data_batch,
+            outputs=outputs,
+            mode='test')
     def _before_epoch(self, runner, mode: str = 'train') -> None:
         """All subclasses should override this method, if they need any
@@ -274,6 +307,7 @@ class Hook:
     def _before_iter(self,
                      runner,
+                     batch_idx: int,
                      data_batch: DATA_BATCH = None,
                      mode: str = 'train') -> None:
         """All subclasses should override this method, if they need any
@@ -282,6 +316,7 @@ class Hook:
         Args:
             runner (Runner): The runner of the training, validation or testing
                 process.
+            batch_idx (int): The index of the current batch in the loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
             mode (str): Current mode of runner. Defaults to 'train'.
@@ -290,6 +325,7 @@ class Hook:
     def _after_iter(self,
                     runner,
+                    batch_idx: int,
                     data_batch: DATA_BATCH = None,
                     outputs: Optional[Union[Sequence[BaseDataSample],
                                             dict]] = None,
@@ -300,6 +336,7 @@ class Hook:
         Args:
             runner (Runner): The runner of the training, validation or testing
                 process.
+            batch_idx (int): The index of the current batch in the loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
             outputs (Sequence[BaseDataSample], optional): Outputs from model.
@@ -321,12 +358,12 @@ class Hook:
         """
         return (runner.epoch + 1) % n == 0 if n > 0 else False
-    def every_n_inner_iters(self, runner, n: int) -> bool:
+    def every_n_inner_iters(self, inner_iter: int, n: int) -> bool:
         """Test whether current inner iteration can be evenly divided by n.
         Args:
-            runner (Runner): The runner of the training, validation or testing
-                process.
+            inner_iter (int): Current inner_iter of the training, validation
+                or testing loop.
             n (int): Whether current inner iteration can be evenly
                 divided by n.
@@ -334,7 +371,7 @@ class Hook:
             bool: Whether current inner iteration can be evenly
                 divided by n.
         """
-        return (runner.inner_iter + 1) % n == 0 if n > 0 else False
+        return (inner_iter + 1) % n == 0 if n > 0 else False
     def every_n_iters(self, runner, n: int) -> bool:
         """Test whether current iteration can be evenly divided by n.
@@ -350,18 +387,19 @@ class Hook:
         """
         return (runner.iter + 1) % n == 0 if n > 0 else False
-    def end_of_epoch(self, runner) -> bool:
+    def end_of_epoch(self, runner, batch_idx: int) -> bool:
         """Check whether the current iteration reaches the last iteration of
         current dataloader.
         Args:
             runner (Runner): The runner of the training, validation or testing
                 process.
+            batch_idx (int): The index of the current batch in the loop.
         Returns:
             bool: Whether reaches the end of current epoch or not.
         """
-        return runner.inner_iter + 1 == len(runner.cur_dataloader)
+        return batch_idx + 1 == len(runner.cur_dataloader)
     def is_last_train_epoch(self, runner) -> bool:
         """Test whether current epoch is the last train epoch.

View File

@@ -13,7 +13,7 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
 class IterTimerHook(Hook):
     """A hook that logs the time spent during iteration.
-    Eg. ``data_time`` for loading data and ``time`` for a model train step.
+    E.g. ``data_time`` for loading data and ``time`` for a model train step.
     """
     priority = 'NORMAL'
@@ -29,12 +29,14 @@ class IterTimerHook(Hook):
     def _before_iter(self,
                      runner,
+                     batch_idx: int,
                      data_batch: DATA_BATCH = None,
                      mode: str = 'train') -> None:
         """Logging time for loading data and update the time flag.
         Args:
             runner (Runner): The runner of the training process.
+            batch_idx (int): The index of the current batch in the loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                 from dataloader. Defaults to None.
             mode (str): Current mode of runner. Defaults to 'train'.
@@ -45,15 +47,16 @@ class IterTimerHook(Hook):
     def _after_iter(self,
                     runner,
+                    batch_idx: int,
                     data_batch: DATA_BATCH = None,
-                    outputs:
-                    Optional[Union[dict, Sequence[BaseDataSample]]] = None,
-                    mode: str = 'train') \
-            -> None:
+                    outputs: Optional[Union[dict,
+                                            Sequence[BaseDataSample]]] = None,
+                    mode: str = 'train') -> None:
         """Logging time for a iteration and update the time flag.
         Args:
             runner (Runner): The runner of the training process.
+            batch_idx (int): The index of the current batch in the loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                 from dataloader. Defaults to None.
             outputs (dict or sequence, optional): Outputs from model. Defaults

View File

@@ -121,6 +121,7 @@ class LoggerHook(Hook):
                  keep_local=True,
                  file_client_args=None,
                  ):
+        self._inner_iter = 0
         self.by_epoch = by_epoch
         self.interval = interval
         self.custom_keys = custom_keys if custom_keys is not None else dict()
@@ -174,27 +175,31 @@ class LoggerHook(Hook):
     def after_train_iter(self,
                          runner,
+                         batch_idx: int,
                          data_batch: DATA_BATCH = None,
                          outputs: Optional[dict] = None) -> None:
         """Record training logs.
         Args:
             runner (Runner): The runner of the training process.
+            batch_idx (int): The index of the current batch in the train loop.
             data_batch (Sequence[BaseDataSample], optional): Data from
                 dataloader. Defaults to None.
             outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
+        self._inner_iter = batch_idx
         if runner.meta is not None and 'exp_name' in runner.meta:
             if (self.every_n_iters(runner, self.interval_exp_name)) or (
-                    self.by_epoch and self.end_of_epoch(runner)):
+                    self.by_epoch and self.end_of_epoch(runner, batch_idx)):
                 exp_info = f'Exp name: {runner.meta["exp_name"]}'
                 runner.logger.info(exp_info)
-        if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
+        if self.by_epoch and self.every_n_inner_iters(batch_idx,
+                                                      self.interval):
             self._log_train(runner)
         elif not self.by_epoch and self.every_n_iters(runner, self.interval):
             self._log_train(runner)
-        elif self.end_of_epoch(runner) and not self.ignore_last:
+        elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last:
             # `runner.max_iters` may not be divisible by `self.interval`. if
             # `self.ignore_last==True`, the log of remaining iterations will
             # be recorded (Epoch [4][1000/1007], the logs of 998-1007
@@ -346,7 +351,7 @@ class LoggerHook(Hook):
                 'The value of windows size must equal to LoggerHook.interval'
             return window_size
         elif window_size == 'epoch':
-            return runner.inner_iter + 1
+            return self._inner_iter + 1
         elif window_size == 'global':
             return runner.iter + 1
         else:
@@ -505,7 +510,7 @@ class LoggerHook(Hook):
             int: The current global iter or inner iter.
         """
         if self.by_epoch and inner_iter:
-            current_iter = runner.inner_iter + 1
+            current_iter = self._inner_iter + 1
         else:
             current_iter = runner.iter + 1
         return current_iter
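
Since ``runner.inner_iter`` no longer exists, ``LoggerHook`` caches the index it receives into its own ``_inner_iter`` and resolves the ``'epoch'`` window size and the displayed inner iteration from that cache. A hedged sketch of driving the hook in isolation, mirroring the updated unit tests further down (the default ``interval=10`` is assumed):

from unittest.mock import MagicMock

from mmengine.hooks import LoggerHook

runner = MagicMock()
runner.iter = 9
logger_hook = LoggerHook(by_epoch=True)
logger_hook._log_train = MagicMock()

logger_hook.after_train_iter(runner, batch_idx=9)
assert logger_hook._inner_iter == 9          # cached from batch_idx
logger_hook._log_train.assert_called()       # (9 + 1) % 10 == 0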

View File

@@ -40,12 +40,14 @@ class NaiveVisualizationHook(Hook):
     def after_test_iter(
             self,
             runner,
+            batch_idx: int,
             data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """Show or Write the predicted results.
         Args:
             runner (Runner): The runner of the training process.
+            batch_idx (int): The index of the current batch in the test loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                 from dataloader. Defaults to None.
             outputs (Sequence[BaseDataSample], optional): Outputs from model.

View File

@@ -58,6 +58,7 @@ class OptimizerHook(Hook):
     def after_train_iter(self,
                          runner,
+                         batch_idx: int,
                          data_batch: DATA_BATCH = None,
                          outputs: Optional[dict] = None) -> None:
         """All operations need to be finished after each training iteration.
@@ -69,12 +70,13 @@ class OptimizerHook(Hook):
         - Compute the gradient of model parameters.
-        - Clip the gradidents of each parameters. (optional)
+        - Clip the gradients of each parameter. (optional)
         - Update model parameters with gradients.
         Args:
             runner (Runner): The runner of the training process.
+            batch_idx (int): The index of the current batch in the train loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                 from dataloader. In order to keep this interface consistent
                 with other hooks, we keep ``data_batch`` here.

View File

@@ -10,19 +10,21 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
 @HOOKS.register_module()
 class ParamSchedulerHook(Hook):
-    """A hook to update some hyper-parameters in optimizer, e.g learning rate
+    """A hook to update some hyper-parameters in optimizer, e.g., learning rate
     and momentum."""
     priority = 'LOW'
     def after_train_iter(self,
                          runner,
+                         batch_idx: int,
                          data_batch: DATA_BATCH = None,
                          outputs: Optional[dict] = None) -> None:
         """Call step function for each scheduler after each iteration.
         Args:
             runner (Runner): The runner of the training process.
+            batch_idx (int): The index of the current batch in the train loop.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                 from dataloader. In order to keep this interface consistent
                 with other hooks, we keep ``data_batch`` here.

View File

@@ -70,9 +70,8 @@ class EpochBasedTrainLoop(BaseLoop):
             data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
                 from dataloader.
         """
-        self.runner._inner_iter = idx
-        self.runner.call_hook('before_train_iter', data_batch=data_batch)
+        self.runner.call_hook(
+            'before_train_iter', batch_idx=idx, data_batch=data_batch)
         # outputs should be a dict containing one or multiple loss tensors
         self.runner.outputs = self.runner.model(data_batch, return_loss=True)
@@ -82,6 +81,7 @@ class EpochBasedTrainLoop(BaseLoop):
         self.runner.call_hook(
             'after_train_iter',
+            batch_idx=idx,
             data_batch=data_batch,
             outputs=self.runner.outputs)
@@ -126,9 +126,6 @@ class IterBasedTrainLoop(BaseLoop):
            if (self.runner.val_loop is not None and
                    self.runner._iter % self.runner.val_loop.interval == 0):
                self.runner.val_loop.run()
-                # reset inner_iter to 0 to ensure it counts as expected in
-                # train loop
-                self.runner._inner_iter = 0
        self.runner.call_hook('after_train_epoch')
        self.runner.call_hook('after_train')
@@ -141,7 +138,10 @@ class IterBasedTrainLoop(BaseLoop):
             data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
                 from dataloader.
         """
-        self.runner.call_hook('before_train_iter', data_batch=data_batch)
+        self.runner.call_hook(
+            'before_train_iter',
+            batch_idx=self.runner._iter,
+            data_batch=data_batch)
         # outputs should be a dict containing loss tensor
         self.runner.outputs = self.runner.model(data_batch, return_loss=True)
@@ -151,10 +151,10 @@ class IterBasedTrainLoop(BaseLoop):
         self.runner.call_hook(
             'after_train_iter',
+            batch_idx=self.runner._iter,
             data_batch=data_batch,
             outputs=self.runner.outputs)
         self.runner._iter += 1
-        self.runner._inner_iter += 1
 @LOOPS.register_module()
@@ -208,13 +208,16 @@ class ValLoop(BaseLoop):
             data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
                 from dataloader.
         """
-        self.runner._inner_iter = idx
-        self.runner.call_hook('before_val_iter', data_batch=data_batch)
+        self.runner.call_hook(
+            'before_val_iter', batch_idx=idx, data_batch=data_batch)
         # outputs should be sequence of BaseDataSample
         outputs = self.runner.model(data_batch)
         self.evaluator.process(data_batch, outputs)
         self.runner.call_hook(
-            'after_val_iter', data_batch=data_batch, outputs=outputs)
+            'after_val_iter',
+            batch_idx=idx,
+            data_batch=data_batch,
+            outputs=outputs)
 @LOOPS.register_module()
@@ -263,10 +266,13 @@ class TestLoop(BaseLoop):
             data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
                 from dataloader.
         """
-        self.runner._inner_iter = idx
-        self.runner.call_hook('before_test_iter', data_batch=data_batch)
+        self.runner.call_hook(
+            'before_test_iter', batch_idx=idx, data_batch=data_batch)
         # predictions should be sequence of BaseDataSample
         predictions = self.runner.model(data_batch)
         self.evaluator.process(data_batch, predictions)
         self.runner.call_hook(
-            'after_test_iter', data_batch=data_batch, outputs=predictions)
+            'after_test_iter',
+            batch_idx=idx,
+            data_batch=data_batch,
+            outputs=predictions)
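
On the loop side the pattern is uniform: the dataloader enumeration index (epoch-based loops) or ``runner._iter`` (the iteration-based loop) is forwarded to ``call_hook`` as ``batch_idx``, and the loops stop maintaining ``runner._inner_iter``. A paraphrased sketch of the epoch-based pattern; only the ``call_hook`` calls are taken from this diff, the surrounding loop body is simplified for illustration:

class EpochBasedTrainLoopSketch:
    """Simplified stand-in for EpochBasedTrainLoop (illustration only)."""

    def __init__(self, runner, dataloader):
        self.runner = runner
        self.dataloader = dataloader

    def run_epoch(self) -> None:
        self.runner.call_hook('before_train_epoch')
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)
        self.runner.call_hook('after_train_epoch')

    def run_iter(self, idx, data_batch) -> None:
        # The enumeration index is what every hook now sees as batch_idx.
        self.runner.call_hook(
            'before_train_iter', batch_idx=idx, data_batch=data_batch)
        self.runner.outputs = self.runner.model(data_batch, return_loss=True)
        self.runner.call_hook(
            'after_train_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=self.runner.outputs)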

View File

@@ -229,7 +229,6 @@ class Runner:
         self._epoch = 0
         self._iter = 0
-        self._inner_iter = 0
         # lazy initialization
         training_related = [
@@ -400,11 +399,6 @@ class Runner:
         """int: Current epoch."""
         return self._iter
-    @property
-    def inner_iter(self):
-        """int: Current iteration."""
-        return self._inner_iter
     @property
     def launcher(self):
         """str: Way to launcher multi processes."""

View File

@@ -100,19 +100,20 @@ class TestCheckpointHook:
         runner = Mock()
         runner.work_dir = './tmp'
         runner.iter = 9
+        batch_idx = 9
         runner.meta = dict()
         runner.model = Mock()
         # by epoch is True
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
         checkpoint_hook.before_run(runner)
-        checkpoint_hook.after_train_iter(runner)
+        checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
         assert runner.meta.get('hook_msgs', None) is None
         # by epoch is False
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
         checkpoint_hook.before_run(runner)
-        checkpoint_hook.after_train_iter(runner)
+        checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
         assert (runner.iter + 1) % 2 == 0
         assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
@@ -129,5 +130,5 @@ class TestCheckpointHook:
         checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=False, max_keep_ckpts=1)
         checkpoint_hook.before_run(runner)
-        checkpoint_hook.after_train_iter(runner)
+        checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
         assert not os.path.exists(f'{tempo_dir}/iter_8.pth')

View File

@@ -7,8 +7,8 @@ from mmengine.hooks import EmptyCacheHook
 class TestEmptyCacheHook:
     def test_emtpy_cache_hook(self):
-        Hook = EmptyCacheHook(True, True, True)
-        Runner = Mock()
-        Hook._after_iter(Runner)
-        Hook._before_epoch(Runner)
-        Hook._after_epoch(Runner)
+        hook = EmptyCacheHook(True, True, True)
+        runner = Mock()
+        hook._after_iter(runner, 0)
+        hook._before_epoch(runner)
+        hook._after_epoch(runner)

View File

@@ -136,11 +136,9 @@ class TestHook:
     def test_every_n_inner_iters(self):
         hook = Hook()
-        runner = Mock()
         for i in range(100):
-            runner.inner_iter = i
-            return_val = hook.every_n_inner_iters(runner, 3)
+            return_val = hook.every_n_inner_iters(i, 3)
             if (i + 1) % 3 == 0:
                 assert return_val
             else:
@@ -162,15 +160,15 @@ class TestHook:
         runner = Mock()
         # last inner iter
-        runner.inner_iter = 1
+        batch_idx = 1
         runner.cur_dataloader.__len__ = Mock(return_value=2)
         runner.cur_dataloader.__len__ = Mock(return_value=2)
-        return_val = hook.end_of_epoch(runner)
+        return_val = hook.end_of_epoch(runner, batch_idx)
         assert return_val
         # not the last inner iter
-        runner.inner_iter = 0
-        return_val = hook.end_of_epoch(runner)
+        batch_idx = 0
+        return_val = hook.end_of_epoch(runner, batch_idx)
         assert not return_val
     def test_is_last_train_epoch(self):

View File

@@ -7,23 +7,23 @@ from mmengine.hooks import IterTimerHook
 class TestIterTimerHook:
     def test_before_epoch(self):
-        Hook = IterTimerHook()
-        Runner = Mock()
-        Hook._before_epoch(Runner)
-        assert isinstance(Hook.t, float)
+        hook = IterTimerHook()
+        runner = Mock()
+        hook._before_epoch(runner)
+        assert isinstance(hook.t, float)
     def test_before_iter(self):
-        Hook = IterTimerHook()
-        Runner = Mock()
-        Runner.log_buffer = dict()
-        Hook._before_epoch(Runner)
-        Hook._before_iter(Runner)
-        Runner.message_hub.update_log.assert_called()
+        hook = IterTimerHook()
+        runner = Mock()
+        runner.log_buffer = dict()
+        hook._before_epoch(runner)
+        hook._before_iter(runner, 0)
+        runner.message_hub.update_log.assert_called()
     def test_after_iter(self):
-        Hook = IterTimerHook()
-        Runner = Mock()
-        Runner.log_buffer = dict()
-        Hook._before_epoch(Runner)
-        Hook._after_iter(Runner)
-        Runner.message_hub.update_log.assert_called()
+        hook = IterTimerHook()
+        runner = Mock()
+        runner.log_buffer = dict()
+        hook._before_epoch(runner)
+        hook._after_iter(runner, 0)
+        runner.message_hub.update_log.assert_called()

View File

@@ -85,14 +85,15 @@ class TestLoggerHook:
         # Test LoggerHook by iter.
         runner = MagicMock()
         runner.iter = 10
+        batch_idx = 5
         logger_hook = LoggerHook(by_epoch=False)
         logger_hook._log_train = MagicMock()
-        logger_hook.after_train_iter(runner)
+        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
         # `cur_iter=10+1`, which cannot be exact division by
         # `logger_hook.interval`
         logger_hook._log_train.assert_not_called()
         runner.iter = 9
-        logger_hook.after_train_iter(runner)
+        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
         logger_hook._log_train.assert_called()
         # Test LoggerHook by epoch.
@@ -100,19 +101,19 @@ class TestLoggerHook:
         logger_hook._log_train = MagicMock()
         # Only `runner.inner_iter` will work.
         runner.iter = 9
-        runner.inner_iter = 10
-        logger_hook.after_train_iter(runner)
+        batch_idx = 10
+        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
         logger_hook._log_train.assert_not_called()
-        runner.inner_iter = 9
-        logger_hook.after_train_iter(runner)
+        batch_idx = 9
+        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
         logger_hook._log_train.assert_called()
         # Test end of the epoch.
         logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
         logger_hook._log_train = MagicMock()
         runner.cur_dataloader = [0] * 5
-        runner.inner_iter = 4
-        logger_hook.after_train_iter(runner)
+        batch_idx = 4
+        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
         logger_hook._log_train.assert_called()
         # Test print exp_name
@@ -120,7 +121,7 @@ class TestLoggerHook:
         logger_hook = LoggerHook()
         runner.logger = MagicMock()
         logger_hook._log_train = MagicMock()
-        logger_hook.after_train_iter(runner)
+        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
         runner.logger.info.assert_called_with(
             f'Exp name: {runner.meta["exp_name"]}')
@@ -137,6 +138,7 @@ class TestLoggerHook:
         runner.meta = dict(exp_name='retinanet')
         # Prepare LoggerHook
         logger_hook = LoggerHook(by_epoch=by_epoch)
+        logger_hook._inner_iter = 1
         logger_hook.writer = MagicMock()
         logger_hook.time_sec_tot = 1000
         logger_hook.start_iter = 0
@@ -220,6 +222,7 @@ class TestLoggerHook:
     def test_get_window_size(self):
         runner = self._setup_runner()
         logger_hook = LoggerHook()
+        logger_hook._inner_iter = 1
         # Test get window size by name.
         assert logger_hook._get_window_size(runner, 'epoch') == 2
         assert logger_hook._get_window_size(runner, 'global') == 11
@@ -313,6 +316,7 @@ class TestLoggerHook:
     def test_get_iter(self):
         runner = self._setup_runner()
         logger_hook = LoggerHook()
+        logger_hook._inner_iter = 1
         # Get global iter when `inner_iter=False`
         iter = logger_hook._get_iter(runner)
         assert iter == 11
@@ -338,7 +342,6 @@ class TestLoggerHook:
         runner = MagicMock()
         runner.epoch = 1
         runner.cur_dataloader = [0] * 5
-        runner.inner_iter = 1
         runner.iter = 10
         runner.train_loop.max_iters = 50
         logger = logging.getLogger()

View File

@@ -11,9 +11,10 @@ class TestNaiveVisualizationHook:
     def test_after_train_iter(self):
         naive_visualization_hook = NaiveVisualizationHook()
-        Runner = Mock(iter=1)
-        Runner.writer.add_image = Mock()
+        runner = Mock(iter=1)
+        runner.writer.add_image = Mock()
         inputs = torch.randn(1, 3, 15, 15)
+        batch_idx = 10
         # test with normalize, resize, pad
         gt_datasamples = [
             BaseDataSample(
@@ -28,7 +29,7 @@ class TestNaiveVisualizationHook:
         ]
         pred_datasamples = [BaseDataSample()]
         data_batch = (inputs, gt_datasamples)
-        naive_visualization_hook.after_test_iter(Runner, data_batch,
-                                                 pred_datasamples)
+        naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
+                                                 pred_datasamples)
         # test with resize, pad
         gt_datasamples = [
@@ -42,7 +43,7 @@ class TestNaiveVisualizationHook:
         ]
         pred_datasamples = [BaseDataSample()]
         data_batch = (inputs, gt_datasamples)
-        naive_visualization_hook.after_test_iter(Runner, data_batch,
-                                                 pred_datasamples)
+        naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
+                                                 pred_datasamples)
         # test with only resize
         gt_datasamples = [
@@ -55,7 +56,7 @@ class TestNaiveVisualizationHook:
         ]
         pred_datasamples = [BaseDataSample()]
         data_batch = (inputs, gt_datasamples)
-        naive_visualization_hook.after_test_iter(Runner, data_batch,
-                                                 pred_datasamples)
+        naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
+                                                 pred_datasamples)
         # test with only pad
@@ -69,7 +70,7 @@ class TestNaiveVisualizationHook:
         ]
         pred_datasamples = [BaseDataSample()]
         data_batch = (inputs, gt_datasamples)
-        naive_visualization_hook.after_test_iter(Runner, data_batch,
-                                                 pred_datasamples)
+        naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
+                                                 pred_datasamples)
         # test no transform
@@ -80,5 +81,5 @@ class TestNaiveVisualizationHook:
         ]
         pred_datasamples = [BaseDataSample()]
         data_batch = (inputs, gt_datasamples)
-        naive_visualization_hook.after_test_iter(Runner, data_batch,
-                                                 pred_datasamples)
+        naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
+                                                 pred_datasamples)

View File

@@ -73,7 +73,7 @@ class TestOptimizerHook:
             wraps=optimizer_hook.detect_anomalous_parameters)
         optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
-        optimizer_hook.after_train_iter(dummy_runner)
+        optimizer_hook.after_train_iter(dummy_runner, 0)
         # assert the parameters of conv2 and conv3 are not in the
         # computational graph which is with x1.sum() as root.
         assert 'conv2.weight' in dummy_runner.logger.msg
@@ -89,7 +89,7 @@ class TestOptimizerHook:
         dummy_runner.outputs['loss'] = model(x)[1].sum()
         dummy_runner.logger.msg = ''
-        optimizer_hook.after_train_iter(dummy_runner)
+        optimizer_hook.after_train_iter(dummy_runner, 0)
         # assert the parameters of conv3 are not in the computational graph
         assert 'conv3.weight' in dummy_runner.logger.msg
         assert 'conv3.bias' in dummy_runner.logger.msg
@@ -107,7 +107,7 @@ class TestOptimizerHook:
         dummy_runner.outputs['loss'].backward = Mock(
             wraps=dummy_runner.outputs['loss'].backward)
-        optimizer_hook.after_train_iter(dummy_runner)
+        optimizer_hook.after_train_iter(dummy_runner, 0)
         dummy_runner.optimizer.step.assert_called()
         dummy_runner.outputs['loss'].backward.assert_called()

View File

@@ -7,21 +7,21 @@ from mmengine.hooks import ParamSchedulerHook
 class TestParamSchedulerHook:
     def test_after_iter(self):
-        Hook = ParamSchedulerHook()
-        Runner = Mock()
+        hook = ParamSchedulerHook()
+        runner = Mock()
         scheduler = Mock()
         scheduler.step = Mock()
         scheduler.by_epoch = False
-        Runner.param_schedulers = [scheduler]
-        Hook.after_train_iter(Runner)
+        runner.param_schedulers = [scheduler]
+        hook.after_train_iter(runner, 0)
         scheduler.step.assert_called()
     def test_after_epoch(self):
-        Hook = ParamSchedulerHook()
-        Runner = Mock()
+        hook = ParamSchedulerHook()
+        runner = Mock()
         scheduler = Mock()
         scheduler.step = Mock()
         scheduler.by_epoch = True
-        Runner.param_schedulers = [scheduler]
-        Hook.after_train_epoch(Runner)
+        runner.param_schedulers = [scheduler]
+        hook.after_train_epoch(runner)
         scheduler.step.assert_called()

View File

@@ -7,7 +7,7 @@ from mmengine.hooks import SyncBuffersHook
 class TestSyncBuffersHook:
     def test_sync_buffers_hook(self):
-        Runner = Mock()
-        Runner.model = Mock()
-        Hook = SyncBuffersHook()
-        Hook._after_epoch(Runner)
+        runner = Mock()
+        runner.model = Mock()
+        hook = SyncBuffersHook()
+        hook._after_epoch(runner)

View File

@@ -720,8 +720,8 @@ class TestRunner(TestCase):
         epoch_targets = [i for i in range(3)]
         iter_results = []
         iter_targets = [i for i in range(4 * 3)]
-        inner_iter_results = []
-        inner_iter_targets = [i for i in range(4)] * 3  # train and val
+        batch_idx_results = []
+        batch_idx_targets = [i for i in range(4)] * 3  # train and val
         @HOOKS.register_module()
         class TestEpochHook(Hook):
@@ -729,9 +729,9 @@ class TestRunner(TestCase):
             def before_train_epoch(self, runner):
                 epoch_results.append(runner.epoch)
-            def before_train_iter(self, runner, data_batch=None):
+            def before_train_iter(self, runner, batch_idx, data_batch=None):
                 iter_results.append(runner.iter)
-                inner_iter_results.append(runner.inner_iter)
+                batch_idx_results.append(batch_idx)
         self.epoch_based_cfg.custom_hooks = [
             dict(type='TestEpochHook', priority=50)
@@ -746,7 +746,7 @@ class TestRunner(TestCase):
             self.assertEqual(result, target)
         for result, target, in zip(iter_results, iter_targets):
             self.assertEqual(result, target)
-        for result, target, in zip(inner_iter_results, inner_iter_targets):
+        for result, target, in zip(batch_idx_results, batch_idx_targets):
             self.assertEqual(result, target)
         time.sleep(1)
@@ -754,9 +754,9 @@ class TestRunner(TestCase):
         # 3. test iter and epoch counter of IterBasedTrainLoop
         epoch_results = []
         iter_results = []
-        inner_iter_results = []
+        batch_idx_results = []
         iter_targets = [i for i in range(12)]
-        inner_iter_targets = [0, 1, 2, 3] * 3
+        batch_idx_targets = [i for i in range(12)]
         @HOOKS.register_module()
         class TestIterHook(Hook):
@@ -764,9 +764,9 @@ class TestRunner(TestCase):
             def before_train_epoch(self, runner):
                 epoch_results.append(runner.epoch)
-            def before_train_iter(self, runner, data_batch=None):
+            def before_train_iter(self, runner, batch_idx, data_batch=None):
                 iter_results.append(runner.iter)
-                inner_iter_results.append(runner.inner_iter)
+                batch_idx_results.append(batch_idx)
         self.iter_based_cfg.custom_hooks = [
             dict(type='TestIterHook', priority=50)
@@ -781,7 +781,7 @@ class TestRunner(TestCase):
         self.assertEqual(epoch_results[0], 0)
         for result, target, in zip(iter_results, iter_targets):
             self.assertEqual(result, target)
-        for result, target, in zip(inner_iter_results, inner_iter_targets):
+        for result, target, in zip(batch_idx_results, batch_idx_targets):
             self.assertEqual(result, target)
     def test_val(self):
@@ -1056,7 +1056,6 @@ class TestRunner(TestCase):
         runner.resume(path)
         self.assertEqual(runner.epoch, 0)
         self.assertEqual(runner.iter, 12)
-        self.assertEqual(runner.inner_iter, 0)
         self.assertTrue(runner._has_loaded)
         self.assertIsInstance(runner.optimizer, SGD)
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)