[Refactor] Add batch_idx to hook input. (#140)
* [Refactor] Add batch_idx to hook input. * updatepull/139/head^2
parent
563b4bad16
commit
9a61b389e7
|
@ -20,31 +20,31 @@ class CheckpointHook(Hook):
|
|||
Args:
|
||||
interval (int): The saving period. If ``by_epoch=True``, interval
|
||||
indicates epochs, otherwise it indicates iterations.
|
||||
Default: -1, which means "never".
|
||||
Defaults to -1, which means "never".
|
||||
by_epoch (bool): Saving checkpoints by epoch or by iteration.
|
||||
Default: True.
|
||||
save_optimizer (bool): Whether to save optimizer state_dict in the
|
||||
checkpoint. It is usually used for resuming experiments.
|
||||
Default: True.
|
||||
Defaults to True.
|
||||
save_param_scheduler (bool): Whether to save param_scheduler state_dict
|
||||
in the checkpoint. It is usually used for resuming experiments.
|
||||
Default: True.
|
||||
Defaults to True.
|
||||
out_dir (str, optional | Path): The root directory to save checkpoints.
|
||||
If not specified, ``runner.work_dir`` will be used by default. If
|
||||
specified, the ``out_dir`` will be the concatenation of ``out_dir``
|
||||
and the last level directory of ``runner.work_dir``. For example,
|
||||
if the input ``our_dir`` is ``./tmp`` and ``runner.work_dir`` is
|
||||
``./work_dir/cur_exp``, then the ckpt will be saved in
|
||||
``./tmp/cur_exp``. Deafule to None.
|
||||
``./tmp/cur_exp``. Defaults to None.
|
||||
max_keep_ckpts (int): The maximum checkpoints to keep.
|
||||
In some cases we want only the latest few checkpoints and would
|
||||
like to delete old ones to save the disk space.
|
||||
Default: -1, which means unlimited.
|
||||
Defaults to -1, which means unlimited.
|
||||
save_last (bool): Whether to force the last checkpoint to be
|
||||
saved regardless of interval. Default: True.
|
||||
saved regardless of interval. Defaults to True.
|
||||
file_client_args (dict, optional): Arguments to instantiate a
|
||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||
Default: None.
|
||||
Defaults to None.
|
||||
"""
|
||||
out_dir: str
|
||||
|
||||
|
@ -177,12 +177,14 @@ class CheckpointHook(Hook):
|
|||
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs=Optional[dict]) -> None:
|
||||
"""Save the checkpoint and synchronize buffers after each iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
|
|
|
@ -36,6 +36,7 @@ class EmptyCacheHook(Hook):
|
|||
|
||||
def _after_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Union[dict,
|
||||
Sequence[BaseDataSample]]] = None,
|
||||
|
@ -44,6 +45,7 @@ class EmptyCacheHook(Hook):
|
|||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from model.
|
||||
|
|
|
@ -164,41 +164,57 @@ class Hook:
|
|||
"""
|
||||
self._after_epoch(runner, mode='test')
|
||||
|
||||
def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
|
||||
def before_train_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each training iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
"""
|
||||
self._before_iter(runner, data_batch=data_batch, mode='train')
|
||||
self._before_iter(
|
||||
runner, batch_idx=batch_idx, data_batch=data_batch, mode='train')
|
||||
|
||||
def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
|
||||
def before_val_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each validation iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
"""
|
||||
self._before_iter(runner, data_batch=data_batch, mode='val')
|
||||
self._before_iter(
|
||||
runner, batch_idx=batch_idx, data_batch=data_batch, mode='val')
|
||||
|
||||
def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
|
||||
def before_test_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each test iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
"""
|
||||
self._before_iter(runner, data_batch=data_batch, mode='test')
|
||||
self._before_iter(
|
||||
runner, batch_idx=batch_idx, data_batch=data_batch, mode='test')
|
||||
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[dict] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -206,16 +222,22 @@ class Hook:
|
|||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
self._after_iter(
|
||||
runner, data_batch=data_batch, outputs=outputs, mode='train')
|
||||
runner,
|
||||
batch_idx=batch_idx,
|
||||
data_batch=data_batch,
|
||||
outputs=outputs,
|
||||
mode='train')
|
||||
|
||||
def after_val_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) \
|
||||
-> None:
|
||||
|
@ -224,17 +246,23 @@ class Hook:
|
|||
|
||||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from
|
||||
model. Defaults to None.
|
||||
"""
|
||||
self._after_iter(
|
||||
runner, data_batch=data_batch, outputs=outputs, mode='val')
|
||||
runner,
|
||||
batch_idx=batch_idx,
|
||||
data_batch=data_batch,
|
||||
outputs=outputs,
|
||||
mode='val')
|
||||
|
||||
def after_test_iter(
|
||||
self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -242,13 +270,18 @@ class Hook:
|
|||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
self._after_iter(
|
||||
runner, data_batch=data_batch, outputs=outputs, mode='test')
|
||||
runner,
|
||||
batch_idx=batch_idx,
|
||||
data_batch=data_batch,
|
||||
outputs=outputs,
|
||||
mode='test')
|
||||
|
||||
def _before_epoch(self, runner, mode: str = 'train') -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -274,6 +307,7 @@ class Hook:
|
|||
|
||||
def _before_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
mode: str = 'train') -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -282,6 +316,7 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
|
@ -290,6 +325,7 @@ class Hook:
|
|||
|
||||
def _after_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Union[Sequence[BaseDataSample],
|
||||
dict]] = None,
|
||||
|
@ -300,6 +336,7 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
|
@ -321,12 +358,12 @@ class Hook:
|
|||
"""
|
||||
return (runner.epoch + 1) % n == 0 if n > 0 else False
|
||||
|
||||
def every_n_inner_iters(self, runner, n: int) -> bool:
|
||||
def every_n_inner_iters(self, inner_iter: int, n: int) -> bool:
|
||||
"""Test whether current inner iteration can be evenly divided by n.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
inner_iter (int): Current inner_iter of the training, validation
|
||||
or testing loop.
|
||||
n (int): Whether current inner iteration can be evenly
|
||||
divided by n.
|
||||
|
||||
|
@ -334,7 +371,7 @@ class Hook:
|
|||
bool: Whether current inner iteration can be evenly
|
||||
divided by n.
|
||||
"""
|
||||
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
|
||||
return (inner_iter + 1) % n == 0 if n > 0 else False
|
||||
|
||||
def every_n_iters(self, runner, n: int) -> bool:
|
||||
"""Test whether current iteration can be evenly divided by n.
|
||||
|
@ -350,18 +387,19 @@ class Hook:
|
|||
"""
|
||||
return (runner.iter + 1) % n == 0 if n > 0 else False
|
||||
|
||||
def end_of_epoch(self, runner) -> bool:
|
||||
def end_of_epoch(self, runner, batch_idx: int) -> bool:
|
||||
"""Check whether the current iteration reaches the last iteration of
|
||||
current dataloader.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
|
||||
Returns:
|
||||
bool: Whether reaches the end of current epoch or not.
|
||||
"""
|
||||
return runner.inner_iter + 1 == len(runner.cur_dataloader)
|
||||
return batch_idx + 1 == len(runner.cur_dataloader)
|
||||
|
||||
def is_last_train_epoch(self, runner) -> bool:
|
||||
"""Test whether current epoch is the last train epoch.
|
||||
|
|
|
@ -13,7 +13,7 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
|
|||
class IterTimerHook(Hook):
|
||||
"""A hook that logs the time spent during iteration.
|
||||
|
||||
Eg. ``data_time`` for loading data and ``time`` for a model train step.
|
||||
E.g. ``data_time`` for loading data and ``time`` for a model train step.
|
||||
"""
|
||||
|
||||
priority = 'NORMAL'
|
||||
|
@ -29,12 +29,14 @@ class IterTimerHook(Hook):
|
|||
|
||||
def _before_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
mode: str = 'train') -> None:
|
||||
"""Logging time for loading data and update the time flag.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
|
@ -45,15 +47,16 @@ class IterTimerHook(Hook):
|
|||
|
||||
def _after_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs:
|
||||
Optional[Union[dict, Sequence[BaseDataSample]]] = None,
|
||||
mode: str = 'train') \
|
||||
-> None:
|
||||
outputs: Optional[Union[dict,
|
||||
Sequence[BaseDataSample]]] = None,
|
||||
mode: str = 'train') -> None:
|
||||
"""Logging time for a iteration and update the time flag.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from model. Defaults
|
||||
|
|
|
@ -121,6 +121,7 @@ class LoggerHook(Hook):
|
|||
keep_local=True,
|
||||
file_client_args=None,
|
||||
):
|
||||
self._inner_iter = 0
|
||||
self.by_epoch = by_epoch
|
||||
self.interval = interval
|
||||
self.custom_keys = custom_keys if custom_keys is not None else dict()
|
||||
|
@ -174,27 +175,31 @@ class LoggerHook(Hook):
|
|||
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[dict] = None) -> None:
|
||||
"""Record training logs.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[BaseDataSample], optional): Data from
|
||||
dataloader. Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
self._inner_iter = batch_idx
|
||||
if runner.meta is not None and 'exp_name' in runner.meta:
|
||||
if (self.every_n_iters(runner, self.interval_exp_name)) or (
|
||||
self.by_epoch and self.end_of_epoch(runner)):
|
||||
self.by_epoch and self.end_of_epoch(runner, batch_idx)):
|
||||
exp_info = f'Exp name: {runner.meta["exp_name"]}'
|
||||
runner.logger.info(exp_info)
|
||||
if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
|
||||
if self.by_epoch and self.every_n_inner_iters(batch_idx,
|
||||
self.interval):
|
||||
self._log_train(runner)
|
||||
elif not self.by_epoch and self.every_n_iters(runner, self.interval):
|
||||
self._log_train(runner)
|
||||
elif self.end_of_epoch(runner) and not self.ignore_last:
|
||||
elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last:
|
||||
# `runner.max_iters` may not be divisible by `self.interval`. if
|
||||
# `self.ignore_last==True`, the log of remaining iterations will
|
||||
# be recorded (Epoch [4][1000/1007], the logs of 998-1007
|
||||
|
@ -346,7 +351,7 @@ class LoggerHook(Hook):
|
|||
'The value of windows size must equal to LoggerHook.interval'
|
||||
return window_size
|
||||
elif window_size == 'epoch':
|
||||
return runner.inner_iter + 1
|
||||
return self._inner_iter + 1
|
||||
elif window_size == 'global':
|
||||
return runner.iter + 1
|
||||
else:
|
||||
|
@ -505,7 +510,7 @@ class LoggerHook(Hook):
|
|||
int: The current global iter or inner iter.
|
||||
"""
|
||||
if self.by_epoch and inner_iter:
|
||||
current_iter = runner.inner_iter + 1
|
||||
current_iter = self._inner_iter + 1
|
||||
else:
|
||||
current_iter = runner.iter + 1
|
||||
return current_iter
|
||||
|
|
|
@ -40,12 +40,14 @@ class NaiveVisualizationHook(Hook):
|
|||
def after_test_iter(
|
||||
self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""Show or Write the predicted results.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
|
|
|
@ -58,6 +58,7 @@ class OptimizerHook(Hook):
|
|||
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[dict] = None) -> None:
|
||||
"""All operations need to be finished after each training iteration.
|
||||
|
@ -69,12 +70,13 @@ class OptimizerHook(Hook):
|
|||
|
||||
- Compute the gradient of model parameters.
|
||||
|
||||
- Clip the gradidents of each parameters. (optional)
|
||||
- Clip the gradients of each parameter. (optional)
|
||||
|
||||
- Update model parameters with gradients.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. In order to keep this interface consistent
|
||||
with other hooks, we keep ``data_batch`` here.
|
||||
|
|
|
@ -10,19 +10,21 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
|
|||
|
||||
@HOOKS.register_module()
|
||||
class ParamSchedulerHook(Hook):
|
||||
"""A hook to update some hyper-parameters in optimizer, e.g learning rate
|
||||
"""A hook to update some hyper-parameters in optimizer, e.g., learning rate
|
||||
and momentum."""
|
||||
|
||||
priority = 'LOW'
|
||||
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[dict] = None) -> None:
|
||||
"""Call step function for each scheduler after each iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. In order to keep this interface consistent
|
||||
with other hooks, we keep ``data_batch`` here.
|
||||
|
|
|
@ -70,9 +70,8 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
|
||||
from dataloader.
|
||||
"""
|
||||
self.runner._inner_iter = idx
|
||||
|
||||
self.runner.call_hook('before_train_iter', data_batch=data_batch)
|
||||
self.runner.call_hook(
|
||||
'before_train_iter', batch_idx=idx, data_batch=data_batch)
|
||||
# outputs should be a dict containing one or multiple loss tensors
|
||||
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
||||
|
||||
|
@ -82,6 +81,7 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||
|
||||
self.runner.call_hook(
|
||||
'after_train_iter',
|
||||
batch_idx=idx,
|
||||
data_batch=data_batch,
|
||||
outputs=self.runner.outputs)
|
||||
|
||||
|
@ -126,9 +126,6 @@ class IterBasedTrainLoop(BaseLoop):
|
|||
if (self.runner.val_loop is not None and
|
||||
self.runner._iter % self.runner.val_loop.interval == 0):
|
||||
self.runner.val_loop.run()
|
||||
# reset inner_iter to 0 to ensure it counts as expected in
|
||||
# train loop
|
||||
self.runner._inner_iter = 0
|
||||
|
||||
self.runner.call_hook('after_train_epoch')
|
||||
self.runner.call_hook('after_train')
|
||||
|
@ -141,7 +138,10 @@ class IterBasedTrainLoop(BaseLoop):
|
|||
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
|
||||
from dataloader.
|
||||
"""
|
||||
self.runner.call_hook('before_train_iter', data_batch=data_batch)
|
||||
self.runner.call_hook(
|
||||
'before_train_iter',
|
||||
batch_idx=self.runner._iter,
|
||||
data_batch=data_batch)
|
||||
# outputs should be a dict containing loss tensor
|
||||
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
||||
|
||||
|
@ -151,10 +151,10 @@ class IterBasedTrainLoop(BaseLoop):
|
|||
|
||||
self.runner.call_hook(
|
||||
'after_train_iter',
|
||||
batch_idx=self.runner._iter,
|
||||
data_batch=data_batch,
|
||||
outputs=self.runner.outputs)
|
||||
self.runner._iter += 1
|
||||
self.runner._inner_iter += 1
|
||||
|
||||
|
||||
@LOOPS.register_module()
|
||||
|
@ -208,13 +208,16 @@ class ValLoop(BaseLoop):
|
|||
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
|
||||
from dataloader.
|
||||
"""
|
||||
self.runner._inner_iter = idx
|
||||
self.runner.call_hook('before_val_iter', data_batch=data_batch)
|
||||
self.runner.call_hook(
|
||||
'before_val_iter', batch_idx=idx, data_batch=data_batch)
|
||||
# outputs should be sequence of BaseDataSample
|
||||
outputs = self.runner.model(data_batch)
|
||||
self.evaluator.process(data_batch, outputs)
|
||||
self.runner.call_hook(
|
||||
'after_val_iter', data_batch=data_batch, outputs=outputs)
|
||||
'after_val_iter',
|
||||
batch_idx=idx,
|
||||
data_batch=data_batch,
|
||||
outputs=outputs)
|
||||
|
||||
|
||||
@LOOPS.register_module()
|
||||
|
@ -263,10 +266,13 @@ class TestLoop(BaseLoop):
|
|||
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
|
||||
from dataloader.
|
||||
"""
|
||||
self.runner._inner_iter = idx
|
||||
self.runner.call_hook('before_test_iter', data_batch=data_batch)
|
||||
self.runner.call_hook(
|
||||
'before_test_iter', batch_idx=idx, data_batch=data_batch)
|
||||
# predictions should be sequence of BaseDataSample
|
||||
predictions = self.runner.model(data_batch)
|
||||
self.evaluator.process(data_batch, predictions)
|
||||
self.runner.call_hook(
|
||||
'after_test_iter', data_batch=data_batch, outputs=predictions)
|
||||
'after_test_iter',
|
||||
batch_idx=idx,
|
||||
data_batch=data_batch,
|
||||
outputs=predictions)
|
||||
|
|
|
@ -229,7 +229,6 @@ class Runner:
|
|||
|
||||
self._epoch = 0
|
||||
self._iter = 0
|
||||
self._inner_iter = 0
|
||||
|
||||
# lazy initialization
|
||||
training_related = [
|
||||
|
@ -400,11 +399,6 @@ class Runner:
|
|||
"""int: Current epoch."""
|
||||
return self._iter
|
||||
|
||||
@property
|
||||
def inner_iter(self):
|
||||
"""int: Current iteration."""
|
||||
return self._inner_iter
|
||||
|
||||
@property
|
||||
def launcher(self):
|
||||
"""str: Way to launcher multi processes."""
|
||||
|
|
|
@ -100,19 +100,20 @@ class TestCheckpointHook:
|
|||
runner = Mock()
|
||||
runner.work_dir = './tmp'
|
||||
runner.iter = 9
|
||||
batch_idx = 9
|
||||
runner.meta = dict()
|
||||
runner.model = Mock()
|
||||
|
||||
# by epoch is True
|
||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
||||
checkpoint_hook.before_run(runner)
|
||||
checkpoint_hook.after_train_iter(runner)
|
||||
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
assert runner.meta.get('hook_msgs', None) is None
|
||||
|
||||
# by epoch is False
|
||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
|
||||
checkpoint_hook.before_run(runner)
|
||||
checkpoint_hook.after_train_iter(runner)
|
||||
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
assert (runner.iter + 1) % 2 == 0
|
||||
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
|
||||
|
||||
|
@ -129,5 +130,5 @@ class TestCheckpointHook:
|
|||
checkpoint_hook = CheckpointHook(
|
||||
interval=2, by_epoch=False, max_keep_ckpts=1)
|
||||
checkpoint_hook.before_run(runner)
|
||||
checkpoint_hook.after_train_iter(runner)
|
||||
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
assert not os.path.exists(f'{tempo_dir}/iter_8.pth')
|
||||
|
|
|
@ -7,8 +7,8 @@ from mmengine.hooks import EmptyCacheHook
|
|||
class TestEmptyCacheHook:
|
||||
|
||||
def test_emtpy_cache_hook(self):
|
||||
Hook = EmptyCacheHook(True, True, True)
|
||||
Runner = Mock()
|
||||
Hook._after_iter(Runner)
|
||||
Hook._before_epoch(Runner)
|
||||
Hook._after_epoch(Runner)
|
||||
hook = EmptyCacheHook(True, True, True)
|
||||
runner = Mock()
|
||||
hook._after_iter(runner, 0)
|
||||
hook._before_epoch(runner)
|
||||
hook._after_epoch(runner)
|
||||
|
|
|
@ -136,11 +136,9 @@ class TestHook:
|
|||
|
||||
def test_every_n_inner_iters(self):
|
||||
hook = Hook()
|
||||
runner = Mock()
|
||||
|
||||
for i in range(100):
|
||||
runner.inner_iter = i
|
||||
return_val = hook.every_n_inner_iters(runner, 3)
|
||||
return_val = hook.every_n_inner_iters(i, 3)
|
||||
if (i + 1) % 3 == 0:
|
||||
assert return_val
|
||||
else:
|
||||
|
@ -162,15 +160,15 @@ class TestHook:
|
|||
runner = Mock()
|
||||
|
||||
# last inner iter
|
||||
runner.inner_iter = 1
|
||||
batch_idx = 1
|
||||
runner.cur_dataloader.__len__ = Mock(return_value=2)
|
||||
runner.cur_dataloader.__len__ = Mock(return_value=2)
|
||||
return_val = hook.end_of_epoch(runner)
|
||||
return_val = hook.end_of_epoch(runner, batch_idx)
|
||||
assert return_val
|
||||
|
||||
# not the last inner iter
|
||||
runner.inner_iter = 0
|
||||
return_val = hook.end_of_epoch(runner)
|
||||
batch_idx = 0
|
||||
return_val = hook.end_of_epoch(runner, batch_idx)
|
||||
assert not return_val
|
||||
|
||||
def test_is_last_train_epoch(self):
|
||||
|
|
|
@ -7,23 +7,23 @@ from mmengine.hooks import IterTimerHook
|
|||
class TestIterTimerHook:
|
||||
|
||||
def test_before_epoch(self):
|
||||
Hook = IterTimerHook()
|
||||
Runner = Mock()
|
||||
Hook._before_epoch(Runner)
|
||||
assert isinstance(Hook.t, float)
|
||||
hook = IterTimerHook()
|
||||
runner = Mock()
|
||||
hook._before_epoch(runner)
|
||||
assert isinstance(hook.t, float)
|
||||
|
||||
def test_before_iter(self):
|
||||
Hook = IterTimerHook()
|
||||
Runner = Mock()
|
||||
Runner.log_buffer = dict()
|
||||
Hook._before_epoch(Runner)
|
||||
Hook._before_iter(Runner)
|
||||
Runner.message_hub.update_log.assert_called()
|
||||
hook = IterTimerHook()
|
||||
runner = Mock()
|
||||
runner.log_buffer = dict()
|
||||
hook._before_epoch(runner)
|
||||
hook._before_iter(runner, 0)
|
||||
runner.message_hub.update_log.assert_called()
|
||||
|
||||
def test_after_iter(self):
|
||||
Hook = IterTimerHook()
|
||||
Runner = Mock()
|
||||
Runner.log_buffer = dict()
|
||||
Hook._before_epoch(Runner)
|
||||
Hook._after_iter(Runner)
|
||||
Runner.message_hub.update_log.assert_called()
|
||||
hook = IterTimerHook()
|
||||
runner = Mock()
|
||||
runner.log_buffer = dict()
|
||||
hook._before_epoch(runner)
|
||||
hook._after_iter(runner, 0)
|
||||
runner.message_hub.update_log.assert_called()
|
||||
|
|
|
@ -85,14 +85,15 @@ class TestLoggerHook:
|
|||
# Test LoggerHook by iter.
|
||||
runner = MagicMock()
|
||||
runner.iter = 10
|
||||
batch_idx = 5
|
||||
logger_hook = LoggerHook(by_epoch=False)
|
||||
logger_hook._log_train = MagicMock()
|
||||
logger_hook.after_train_iter(runner)
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
# `cur_iter=10+1`, which cannot be exact division by
|
||||
# `logger_hook.interval`
|
||||
logger_hook._log_train.assert_not_called()
|
||||
runner.iter = 9
|
||||
logger_hook.after_train_iter(runner)
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
logger_hook._log_train.assert_called()
|
||||
|
||||
# Test LoggerHook by epoch.
|
||||
|
@ -100,19 +101,19 @@ class TestLoggerHook:
|
|||
logger_hook._log_train = MagicMock()
|
||||
# Only `runner.inner_iter` will work.
|
||||
runner.iter = 9
|
||||
runner.inner_iter = 10
|
||||
logger_hook.after_train_iter(runner)
|
||||
batch_idx = 10
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
logger_hook._log_train.assert_not_called()
|
||||
runner.inner_iter = 9
|
||||
logger_hook.after_train_iter(runner)
|
||||
batch_idx = 9
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
logger_hook._log_train.assert_called()
|
||||
|
||||
# Test end of the epoch.
|
||||
logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
|
||||
logger_hook._log_train = MagicMock()
|
||||
runner.cur_dataloader = [0] * 5
|
||||
runner.inner_iter = 4
|
||||
logger_hook.after_train_iter(runner)
|
||||
batch_idx = 4
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
logger_hook._log_train.assert_called()
|
||||
|
||||
# Test print exp_name
|
||||
|
@ -120,7 +121,7 @@ class TestLoggerHook:
|
|||
logger_hook = LoggerHook()
|
||||
runner.logger = MagicMock()
|
||||
logger_hook._log_train = MagicMock()
|
||||
logger_hook.after_train_iter(runner)
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
runner.logger.info.assert_called_with(
|
||||
f'Exp name: {runner.meta["exp_name"]}')
|
||||
|
||||
|
@ -137,6 +138,7 @@ class TestLoggerHook:
|
|||
runner.meta = dict(exp_name='retinanet')
|
||||
# Prepare LoggerHook
|
||||
logger_hook = LoggerHook(by_epoch=by_epoch)
|
||||
logger_hook._inner_iter = 1
|
||||
logger_hook.writer = MagicMock()
|
||||
logger_hook.time_sec_tot = 1000
|
||||
logger_hook.start_iter = 0
|
||||
|
@ -220,6 +222,7 @@ class TestLoggerHook:
|
|||
def test_get_window_size(self):
|
||||
runner = self._setup_runner()
|
||||
logger_hook = LoggerHook()
|
||||
logger_hook._inner_iter = 1
|
||||
# Test get window size by name.
|
||||
assert logger_hook._get_window_size(runner, 'epoch') == 2
|
||||
assert logger_hook._get_window_size(runner, 'global') == 11
|
||||
|
@ -313,6 +316,7 @@ class TestLoggerHook:
|
|||
def test_get_iter(self):
|
||||
runner = self._setup_runner()
|
||||
logger_hook = LoggerHook()
|
||||
logger_hook._inner_iter = 1
|
||||
# Get global iter when `inner_iter=False`
|
||||
iter = logger_hook._get_iter(runner)
|
||||
assert iter == 11
|
||||
|
@ -338,7 +342,6 @@ class TestLoggerHook:
|
|||
runner = MagicMock()
|
||||
runner.epoch = 1
|
||||
runner.cur_dataloader = [0] * 5
|
||||
runner.inner_iter = 1
|
||||
runner.iter = 10
|
||||
runner.train_loop.max_iters = 50
|
||||
logger = logging.getLogger()
|
||||
|
|
|
@ -11,9 +11,10 @@ class TestNaiveVisualizationHook:
|
|||
|
||||
def test_after_train_iter(self):
|
||||
naive_visualization_hook = NaiveVisualizationHook()
|
||||
Runner = Mock(iter=1)
|
||||
Runner.writer.add_image = Mock()
|
||||
runner = Mock(iter=1)
|
||||
runner.writer.add_image = Mock()
|
||||
inputs = torch.randn(1, 3, 15, 15)
|
||||
batch_idx = 10
|
||||
# test with normalize, resize, pad
|
||||
gt_datasamples = [
|
||||
BaseDataSample(
|
||||
|
@ -28,7 +29,7 @@ class TestNaiveVisualizationHook:
|
|||
]
|
||||
pred_datasamples = [BaseDataSample()]
|
||||
data_batch = (inputs, gt_datasamples)
|
||||
naive_visualization_hook.after_test_iter(Runner, data_batch,
|
||||
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||
pred_datasamples)
|
||||
# test with resize, pad
|
||||
gt_datasamples = [
|
||||
|
@ -42,7 +43,7 @@ class TestNaiveVisualizationHook:
|
|||
]
|
||||
pred_datasamples = [BaseDataSample()]
|
||||
data_batch = (inputs, gt_datasamples)
|
||||
naive_visualization_hook.after_test_iter(Runner, data_batch,
|
||||
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||
pred_datasamples)
|
||||
# test with only resize
|
||||
gt_datasamples = [
|
||||
|
@ -55,7 +56,7 @@ class TestNaiveVisualizationHook:
|
|||
]
|
||||
pred_datasamples = [BaseDataSample()]
|
||||
data_batch = (inputs, gt_datasamples)
|
||||
naive_visualization_hook.after_test_iter(Runner, data_batch,
|
||||
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||
pred_datasamples)
|
||||
|
||||
# test with only pad
|
||||
|
@ -69,7 +70,7 @@ class TestNaiveVisualizationHook:
|
|||
]
|
||||
pred_datasamples = [BaseDataSample()]
|
||||
data_batch = (inputs, gt_datasamples)
|
||||
naive_visualization_hook.after_test_iter(Runner, data_batch,
|
||||
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||
pred_datasamples)
|
||||
|
||||
# test no transform
|
||||
|
@ -80,5 +81,5 @@ class TestNaiveVisualizationHook:
|
|||
]
|
||||
pred_datasamples = [BaseDataSample()]
|
||||
data_batch = (inputs, gt_datasamples)
|
||||
naive_visualization_hook.after_test_iter(Runner, data_batch,
|
||||
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||
pred_datasamples)
|
||||
|
|
|
@ -73,7 +73,7 @@ class TestOptimizerHook:
|
|||
wraps=optimizer_hook.detect_anomalous_parameters)
|
||||
optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
|
||||
|
||||
optimizer_hook.after_train_iter(dummy_runner)
|
||||
optimizer_hook.after_train_iter(dummy_runner, 0)
|
||||
# assert the parameters of conv2 and conv3 are not in the
|
||||
# computational graph which is with x1.sum() as root.
|
||||
assert 'conv2.weight' in dummy_runner.logger.msg
|
||||
|
@ -89,7 +89,7 @@ class TestOptimizerHook:
|
|||
|
||||
dummy_runner.outputs['loss'] = model(x)[1].sum()
|
||||
dummy_runner.logger.msg = ''
|
||||
optimizer_hook.after_train_iter(dummy_runner)
|
||||
optimizer_hook.after_train_iter(dummy_runner, 0)
|
||||
# assert the parameters of conv3 are not in the computational graph
|
||||
assert 'conv3.weight' in dummy_runner.logger.msg
|
||||
assert 'conv3.bias' in dummy_runner.logger.msg
|
||||
|
@ -107,7 +107,7 @@ class TestOptimizerHook:
|
|||
dummy_runner.outputs['loss'].backward = Mock(
|
||||
wraps=dummy_runner.outputs['loss'].backward)
|
||||
|
||||
optimizer_hook.after_train_iter(dummy_runner)
|
||||
optimizer_hook.after_train_iter(dummy_runner, 0)
|
||||
|
||||
dummy_runner.optimizer.step.assert_called()
|
||||
dummy_runner.outputs['loss'].backward.assert_called()
|
||||
|
|
|
@ -7,21 +7,21 @@ from mmengine.hooks import ParamSchedulerHook
|
|||
class TestParamSchedulerHook:
|
||||
|
||||
def test_after_iter(self):
|
||||
Hook = ParamSchedulerHook()
|
||||
Runner = Mock()
|
||||
hook = ParamSchedulerHook()
|
||||
runner = Mock()
|
||||
scheduler = Mock()
|
||||
scheduler.step = Mock()
|
||||
scheduler.by_epoch = False
|
||||
Runner.param_schedulers = [scheduler]
|
||||
Hook.after_train_iter(Runner)
|
||||
runner.param_schedulers = [scheduler]
|
||||
hook.after_train_iter(runner, 0)
|
||||
scheduler.step.assert_called()
|
||||
|
||||
def test_after_epoch(self):
|
||||
Hook = ParamSchedulerHook()
|
||||
Runner = Mock()
|
||||
hook = ParamSchedulerHook()
|
||||
runner = Mock()
|
||||
scheduler = Mock()
|
||||
scheduler.step = Mock()
|
||||
scheduler.by_epoch = True
|
||||
Runner.param_schedulers = [scheduler]
|
||||
Hook.after_train_epoch(Runner)
|
||||
runner.param_schedulers = [scheduler]
|
||||
hook.after_train_epoch(runner)
|
||||
scheduler.step.assert_called()
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmengine.hooks import SyncBuffersHook
|
|||
class TestSyncBuffersHook:
|
||||
|
||||
def test_sync_buffers_hook(self):
|
||||
Runner = Mock()
|
||||
Runner.model = Mock()
|
||||
Hook = SyncBuffersHook()
|
||||
Hook._after_epoch(Runner)
|
||||
runner = Mock()
|
||||
runner.model = Mock()
|
||||
hook = SyncBuffersHook()
|
||||
hook._after_epoch(runner)
|
||||
|
|
|
@ -720,8 +720,8 @@ class TestRunner(TestCase):
|
|||
epoch_targets = [i for i in range(3)]
|
||||
iter_results = []
|
||||
iter_targets = [i for i in range(4 * 3)]
|
||||
inner_iter_results = []
|
||||
inner_iter_targets = [i for i in range(4)] * 3 # train and val
|
||||
batch_idx_results = []
|
||||
batch_idx_targets = [i for i in range(4)] * 3 # train and val
|
||||
|
||||
@HOOKS.register_module()
|
||||
class TestEpochHook(Hook):
|
||||
|
@ -729,9 +729,9 @@ class TestRunner(TestCase):
|
|||
def before_train_epoch(self, runner):
|
||||
epoch_results.append(runner.epoch)
|
||||
|
||||
def before_train_iter(self, runner, data_batch=None):
|
||||
def before_train_iter(self, runner, batch_idx, data_batch=None):
|
||||
iter_results.append(runner.iter)
|
||||
inner_iter_results.append(runner.inner_iter)
|
||||
batch_idx_results.append(batch_idx)
|
||||
|
||||
self.epoch_based_cfg.custom_hooks = [
|
||||
dict(type='TestEpochHook', priority=50)
|
||||
|
@ -746,7 +746,7 @@ class TestRunner(TestCase):
|
|||
self.assertEqual(result, target)
|
||||
for result, target, in zip(iter_results, iter_targets):
|
||||
self.assertEqual(result, target)
|
||||
for result, target, in zip(inner_iter_results, inner_iter_targets):
|
||||
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
||||
self.assertEqual(result, target)
|
||||
|
||||
time.sleep(1)
|
||||
|
@ -754,9 +754,9 @@ class TestRunner(TestCase):
|
|||
# 3. test iter and epoch counter of IterBasedTrainLoop
|
||||
epoch_results = []
|
||||
iter_results = []
|
||||
inner_iter_results = []
|
||||
batch_idx_results = []
|
||||
iter_targets = [i for i in range(12)]
|
||||
inner_iter_targets = [0, 1, 2, 3] * 3
|
||||
batch_idx_targets = [i for i in range(12)]
|
||||
|
||||
@HOOKS.register_module()
|
||||
class TestIterHook(Hook):
|
||||
|
@ -764,9 +764,9 @@ class TestRunner(TestCase):
|
|||
def before_train_epoch(self, runner):
|
||||
epoch_results.append(runner.epoch)
|
||||
|
||||
def before_train_iter(self, runner, data_batch=None):
|
||||
def before_train_iter(self, runner, batch_idx, data_batch=None):
|
||||
iter_results.append(runner.iter)
|
||||
inner_iter_results.append(runner.inner_iter)
|
||||
batch_idx_results.append(batch_idx)
|
||||
|
||||
self.iter_based_cfg.custom_hooks = [
|
||||
dict(type='TestIterHook', priority=50)
|
||||
|
@ -781,7 +781,7 @@ class TestRunner(TestCase):
|
|||
self.assertEqual(epoch_results[0], 0)
|
||||
for result, target, in zip(iter_results, iter_targets):
|
||||
self.assertEqual(result, target)
|
||||
for result, target, in zip(inner_iter_results, inner_iter_targets):
|
||||
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
||||
self.assertEqual(result, target)
|
||||
|
||||
def test_val(self):
|
||||
|
@ -1056,7 +1056,6 @@ class TestRunner(TestCase):
|
|||
runner.resume(path)
|
||||
self.assertEqual(runner.epoch, 0)
|
||||
self.assertEqual(runner.iter, 12)
|
||||
self.assertEqual(runner.inner_iter, 0)
|
||||
self.assertTrue(runner._has_loaded)
|
||||
self.assertIsInstance(runner.optimizer, SGD)
|
||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||
|
|
Loading…
Reference in New Issue