[Refactor] Add batch_idx to hook input. (#140)

* [Refactor] Add batch_idx to hook input.

* update
pull/139/head^2
RangiLyu 2022-03-29 11:40:38 +08:00 committed by GitHub
parent 563b4bad16
commit 9a61b389e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 186 additions and 128 deletions

View File

@ -20,31 +20,31 @@ class CheckpointHook(Hook):
Args:
interval (int): The saving period. If ``by_epoch=True``, interval
indicates epochs, otherwise it indicates iterations.
Default: -1, which means "never".
Defaults to -1, which means "never".
by_epoch (bool): Saving checkpoints by epoch or by iteration.
Default: True.
save_optimizer (bool): Whether to save optimizer state_dict in the
checkpoint. It is usually used for resuming experiments.
Default: True.
Defaults to True.
save_param_scheduler (bool): Whether to save param_scheduler state_dict
in the checkpoint. It is usually used for resuming experiments.
Default: True.
Defaults to True.
out_dir (str, optional | Path): The root directory to save checkpoints.
If not specified, ``runner.work_dir`` will be used by default. If
specified, the ``out_dir`` will be the concatenation of ``out_dir``
and the last level directory of ``runner.work_dir``. For example,
if the input ``our_dir`` is ``./tmp`` and ``runner.work_dir`` is
``./work_dir/cur_exp``, then the ckpt will be saved in
``./tmp/cur_exp``. Deafule to None.
``./tmp/cur_exp``. Defaults to None.
max_keep_ckpts (int): The maximum checkpoints to keep.
In some cases we want only the latest few checkpoints and would
like to delete old ones to save the disk space.
Default: -1, which means unlimited.
Defaults to -1, which means unlimited.
save_last (bool): Whether to force the last checkpoint to be
saved regardless of interval. Default: True.
saved regardless of interval. Defaults to True.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
Defaults to None.
"""
out_dir: str
@ -177,12 +177,14 @@ class CheckpointHook(Hook):
def after_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs=Optional[dict]) -> None:
"""Save the checkpoint and synchronize buffers after each iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
outputs (dict, optional): Outputs from model.

View File

@ -36,6 +36,7 @@ class EmptyCacheHook(Hook):
def _after_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Union[dict,
Sequence[BaseDataSample]]] = None,
@ -44,6 +45,7 @@ class EmptyCacheHook(Hook):
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
outputs (dict or sequence, optional): Outputs from model.

View File

@ -164,41 +164,57 @@ class Hook:
"""
self._after_epoch(runner, mode='test')
def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
def before_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each training iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
self._before_iter(runner, data_batch=data_batch, mode='train')
self._before_iter(
runner, batch_idx=batch_idx, data_batch=data_batch, mode='train')
def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
def before_val_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each validation iteration.
Args:
runner (Runner): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
self._before_iter(runner, data_batch=data_batch, mode='val')
self._before_iter(
runner, batch_idx=batch_idx, data_batch=data_batch, mode='val')
def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
def before_test_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each test iteration.
Args:
runner (Runner): The runner of the testing process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
self._before_iter(runner, data_batch=data_batch, mode='test')
self._before_iter(
runner, batch_idx=batch_idx, data_batch=data_batch, mode='test')
def after_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""All subclasses should override this method, if they need any
@ -206,16 +222,22 @@ class Hook:
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
self._after_iter(
runner, data_batch=data_batch, outputs=outputs, mode='train')
runner,
batch_idx=batch_idx,
data_batch=data_batch,
outputs=outputs,
mode='train')
def after_val_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) \
-> None:
@ -224,17 +246,23 @@ class Hook:
Args:
runner (Runner): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (dict or sequence, optional): Outputs from
model. Defaults to None.
"""
self._after_iter(
runner, data_batch=data_batch, outputs=outputs, mode='val')
runner,
batch_idx=batch_idx,
data_batch=data_batch,
outputs=outputs,
mode='val')
def after_test_iter(
self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
@ -242,13 +270,18 @@ class Hook:
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
self._after_iter(
runner, data_batch=data_batch, outputs=outputs, mode='test')
runner,
batch_idx=batch_idx,
data_batch=data_batch,
outputs=outputs,
mode='test')
def _before_epoch(self, runner, mode: str = 'train') -> None:
"""All subclasses should override this method, if they need any
@ -274,6 +307,7 @@ class Hook:
def _before_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
mode: str = 'train') -> None:
"""All subclasses should override this method, if they need any
@ -282,6 +316,7 @@ class Hook:
Args:
runner (Runner): The runner of the training, validation or testing
process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
@ -290,6 +325,7 @@ class Hook:
def _after_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Union[Sequence[BaseDataSample],
dict]] = None,
@ -300,6 +336,7 @@ class Hook:
Args:
runner (Runner): The runner of the training, validation or testing
process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
@ -321,12 +358,12 @@ class Hook:
"""
return (runner.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, runner, n: int) -> bool:
def every_n_inner_iters(self, inner_iter: int, n: int) -> bool:
"""Test whether current inner iteration can be evenly divided by n.
Args:
runner (Runner): The runner of the training, validation or testing
process.
inner_iter (int): Current inner_iter of the training, validation
or testing loop.
n (int): Whether current inner iteration can be evenly
divided by n.
@ -334,7 +371,7 @@ class Hook:
bool: Whether current inner iteration can be evenly
divided by n.
"""
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
return (inner_iter + 1) % n == 0 if n > 0 else False
def every_n_iters(self, runner, n: int) -> bool:
"""Test whether current iteration can be evenly divided by n.
@ -350,18 +387,19 @@ class Hook:
"""
return (runner.iter + 1) % n == 0 if n > 0 else False
def end_of_epoch(self, runner) -> bool:
def end_of_epoch(self, runner, batch_idx: int) -> bool:
"""Check whether the current iteration reaches the last iteration of
current dataloader.
Args:
runner (Runner): The runner of the training, validation or testing
process.
batch_idx (int): The index of the current batch in the loop.
Returns:
bool: Whether reaches the end of current epoch or not.
"""
return runner.inner_iter + 1 == len(runner.cur_dataloader)
return batch_idx + 1 == len(runner.cur_dataloader)
def is_last_train_epoch(self, runner) -> bool:
"""Test whether current epoch is the last train epoch.

View File

@ -13,7 +13,7 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
class IterTimerHook(Hook):
"""A hook that logs the time spent during iteration.
Eg. ``data_time`` for loading data and ``time`` for a model train step.
E.g. ``data_time`` for loading data and ``time`` for a model train step.
"""
priority = 'NORMAL'
@ -29,12 +29,14 @@ class IterTimerHook(Hook):
def _before_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
mode: str = 'train') -> None:
"""Logging time for loading data and update the time flag.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
@ -45,15 +47,16 @@ class IterTimerHook(Hook):
def _after_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs:
Optional[Union[dict, Sequence[BaseDataSample]]] = None,
mode: str = 'train') \
-> None:
outputs: Optional[Union[dict,
Sequence[BaseDataSample]]] = None,
mode: str = 'train') -> None:
"""Logging time for a iteration and update the time flag.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
outputs (dict or sequence, optional): Outputs from model. Defaults

View File

@ -121,6 +121,7 @@ class LoggerHook(Hook):
keep_local=True,
file_client_args=None,
):
self._inner_iter = 0
self.by_epoch = by_epoch
self.interval = interval
self.custom_keys = custom_keys if custom_keys is not None else dict()
@ -174,27 +175,31 @@ class LoggerHook(Hook):
def after_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""Record training logs.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
self._inner_iter = batch_idx
if runner.meta is not None and 'exp_name' in runner.meta:
if (self.every_n_iters(runner, self.interval_exp_name)) or (
self.by_epoch and self.end_of_epoch(runner)):
self.by_epoch and self.end_of_epoch(runner, batch_idx)):
exp_info = f'Exp name: {runner.meta["exp_name"]}'
runner.logger.info(exp_info)
if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
if self.by_epoch and self.every_n_inner_iters(batch_idx,
self.interval):
self._log_train(runner)
elif not self.by_epoch and self.every_n_iters(runner, self.interval):
self._log_train(runner)
elif self.end_of_epoch(runner) and not self.ignore_last:
elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last:
# `runner.max_iters` may not be divisible by `self.interval`. if
# `self.ignore_last==True`, the log of remaining iterations will
# be recorded (Epoch [4][1000/1007], the logs of 998-1007
@ -346,7 +351,7 @@ class LoggerHook(Hook):
'The value of windows size must equal to LoggerHook.interval'
return window_size
elif window_size == 'epoch':
return runner.inner_iter + 1
return self._inner_iter + 1
elif window_size == 'global':
return runner.iter + 1
else:
@ -505,7 +510,7 @@ class LoggerHook(Hook):
int: The current global iter or inner iter.
"""
if self.by_epoch and inner_iter:
current_iter = runner.inner_iter + 1
current_iter = self._inner_iter + 1
else:
current_iter = runner.iter + 1
return current_iter

View File

@ -40,12 +40,14 @@ class NaiveVisualizationHook(Hook):
def after_test_iter(
self,
runner,
batch_idx: int,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Show or Write the predicted results.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.

View File

@ -58,6 +58,7 @@ class OptimizerHook(Hook):
def after_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""All operations need to be finished after each training iteration.
@ -69,12 +70,13 @@ class OptimizerHook(Hook):
- Compute the gradient of model parameters.
- Clip the gradidents of each parameters. (optional)
- Clip the gradients of each parameter. (optional)
- Update model parameters with gradients.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. In order to keep this interface consistent
with other hooks, we keep ``data_batch`` here.

View File

@ -10,19 +10,21 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
@HOOKS.register_module()
class ParamSchedulerHook(Hook):
"""A hook to update some hyper-parameters in optimizer, e.g learning rate
"""A hook to update some hyper-parameters in optimizer, e.g., learning rate
and momentum."""
priority = 'LOW'
def after_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""Call step function for each scheduler after each iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. In order to keep this interface consistent
with other hooks, we keep ``data_batch`` here.

View File

@ -70,9 +70,8 @@ class EpochBasedTrainLoop(BaseLoop):
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
from dataloader.
"""
self.runner._inner_iter = idx
self.runner.call_hook('before_train_iter', data_batch=data_batch)
self.runner.call_hook(
'before_train_iter', batch_idx=idx, data_batch=data_batch)
# outputs should be a dict containing one or multiple loss tensors
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
@ -82,6 +81,7 @@ class EpochBasedTrainLoop(BaseLoop):
self.runner.call_hook(
'after_train_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=self.runner.outputs)
@ -126,9 +126,6 @@ class IterBasedTrainLoop(BaseLoop):
if (self.runner.val_loop is not None and
self.runner._iter % self.runner.val_loop.interval == 0):
self.runner.val_loop.run()
# reset inner_iter to 0 to ensure it counts as expected in
# train loop
self.runner._inner_iter = 0
self.runner.call_hook('after_train_epoch')
self.runner.call_hook('after_train')
@ -141,7 +138,10 @@ class IterBasedTrainLoop(BaseLoop):
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
from dataloader.
"""
self.runner.call_hook('before_train_iter', data_batch=data_batch)
self.runner.call_hook(
'before_train_iter',
batch_idx=self.runner._iter,
data_batch=data_batch)
# outputs should be a dict containing loss tensor
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
@ -151,10 +151,10 @@ class IterBasedTrainLoop(BaseLoop):
self.runner.call_hook(
'after_train_iter',
batch_idx=self.runner._iter,
data_batch=data_batch,
outputs=self.runner.outputs)
self.runner._iter += 1
self.runner._inner_iter += 1
@LOOPS.register_module()
@ -208,13 +208,16 @@ class ValLoop(BaseLoop):
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
from dataloader.
"""
self.runner._inner_iter = idx
self.runner.call_hook('before_val_iter', data_batch=data_batch)
self.runner.call_hook(
'before_val_iter', batch_idx=idx, data_batch=data_batch)
# outputs should be sequence of BaseDataSample
outputs = self.runner.model(data_batch)
self.evaluator.process(data_batch, outputs)
self.runner.call_hook(
'after_val_iter', data_batch=data_batch, outputs=outputs)
'after_val_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)
@LOOPS.register_module()
@ -263,10 +266,13 @@ class TestLoop(BaseLoop):
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
from dataloader.
"""
self.runner._inner_iter = idx
self.runner.call_hook('before_test_iter', data_batch=data_batch)
self.runner.call_hook(
'before_test_iter', batch_idx=idx, data_batch=data_batch)
# predictions should be sequence of BaseDataSample
predictions = self.runner.model(data_batch)
self.evaluator.process(data_batch, predictions)
self.runner.call_hook(
'after_test_iter', data_batch=data_batch, outputs=predictions)
'after_test_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=predictions)

View File

@ -229,7 +229,6 @@ class Runner:
self._epoch = 0
self._iter = 0
self._inner_iter = 0
# lazy initialization
training_related = [
@ -400,11 +399,6 @@ class Runner:
"""int: Current epoch."""
return self._iter
@property
def inner_iter(self):
"""int: Current iteration."""
return self._inner_iter
@property
def launcher(self):
"""str: Way to launcher multi processes."""

View File

@ -100,19 +100,20 @@ class TestCheckpointHook:
runner = Mock()
runner.work_dir = './tmp'
runner.iter = 9
batch_idx = 9
runner.meta = dict()
runner.model = Mock()
# by epoch is True
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
checkpoint_hook.before_run(runner)
checkpoint_hook.after_train_iter(runner)
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
assert runner.meta.get('hook_msgs', None) is None
# by epoch is False
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
checkpoint_hook.before_run(runner)
checkpoint_hook.after_train_iter(runner)
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
assert (runner.iter + 1) % 2 == 0
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
@ -129,5 +130,5 @@ class TestCheckpointHook:
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=False, max_keep_ckpts=1)
checkpoint_hook.before_run(runner)
checkpoint_hook.after_train_iter(runner)
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
assert not os.path.exists(f'{tempo_dir}/iter_8.pth')

View File

@ -7,8 +7,8 @@ from mmengine.hooks import EmptyCacheHook
class TestEmptyCacheHook:
def test_emtpy_cache_hook(self):
Hook = EmptyCacheHook(True, True, True)
Runner = Mock()
Hook._after_iter(Runner)
Hook._before_epoch(Runner)
Hook._after_epoch(Runner)
hook = EmptyCacheHook(True, True, True)
runner = Mock()
hook._after_iter(runner, 0)
hook._before_epoch(runner)
hook._after_epoch(runner)

View File

@ -136,11 +136,9 @@ class TestHook:
def test_every_n_inner_iters(self):
hook = Hook()
runner = Mock()
for i in range(100):
runner.inner_iter = i
return_val = hook.every_n_inner_iters(runner, 3)
return_val = hook.every_n_inner_iters(i, 3)
if (i + 1) % 3 == 0:
assert return_val
else:
@ -162,15 +160,15 @@ class TestHook:
runner = Mock()
# last inner iter
runner.inner_iter = 1
batch_idx = 1
runner.cur_dataloader.__len__ = Mock(return_value=2)
runner.cur_dataloader.__len__ = Mock(return_value=2)
return_val = hook.end_of_epoch(runner)
return_val = hook.end_of_epoch(runner, batch_idx)
assert return_val
# not the last inner iter
runner.inner_iter = 0
return_val = hook.end_of_epoch(runner)
batch_idx = 0
return_val = hook.end_of_epoch(runner, batch_idx)
assert not return_val
def test_is_last_train_epoch(self):

View File

@ -7,23 +7,23 @@ from mmengine.hooks import IterTimerHook
class TestIterTimerHook:
def test_before_epoch(self):
Hook = IterTimerHook()
Runner = Mock()
Hook._before_epoch(Runner)
assert isinstance(Hook.t, float)
hook = IterTimerHook()
runner = Mock()
hook._before_epoch(runner)
assert isinstance(hook.t, float)
def test_before_iter(self):
Hook = IterTimerHook()
Runner = Mock()
Runner.log_buffer = dict()
Hook._before_epoch(Runner)
Hook._before_iter(Runner)
Runner.message_hub.update_log.assert_called()
hook = IterTimerHook()
runner = Mock()
runner.log_buffer = dict()
hook._before_epoch(runner)
hook._before_iter(runner, 0)
runner.message_hub.update_log.assert_called()
def test_after_iter(self):
Hook = IterTimerHook()
Runner = Mock()
Runner.log_buffer = dict()
Hook._before_epoch(Runner)
Hook._after_iter(Runner)
Runner.message_hub.update_log.assert_called()
hook = IterTimerHook()
runner = Mock()
runner.log_buffer = dict()
hook._before_epoch(runner)
hook._after_iter(runner, 0)
runner.message_hub.update_log.assert_called()

View File

@ -85,14 +85,15 @@ class TestLoggerHook:
# Test LoggerHook by iter.
runner = MagicMock()
runner.iter = 10
batch_idx = 5
logger_hook = LoggerHook(by_epoch=False)
logger_hook._log_train = MagicMock()
logger_hook.after_train_iter(runner)
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
# `cur_iter=10+1`, which cannot be exact division by
# `logger_hook.interval`
logger_hook._log_train.assert_not_called()
runner.iter = 9
logger_hook.after_train_iter(runner)
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
logger_hook._log_train.assert_called()
# Test LoggerHook by epoch.
@ -100,19 +101,19 @@ class TestLoggerHook:
logger_hook._log_train = MagicMock()
# Only `runner.inner_iter` will work.
runner.iter = 9
runner.inner_iter = 10
logger_hook.after_train_iter(runner)
batch_idx = 10
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
logger_hook._log_train.assert_not_called()
runner.inner_iter = 9
logger_hook.after_train_iter(runner)
batch_idx = 9
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
logger_hook._log_train.assert_called()
# Test end of the epoch.
logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
logger_hook._log_train = MagicMock()
runner.cur_dataloader = [0] * 5
runner.inner_iter = 4
logger_hook.after_train_iter(runner)
batch_idx = 4
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
logger_hook._log_train.assert_called()
# Test print exp_name
@ -120,7 +121,7 @@ class TestLoggerHook:
logger_hook = LoggerHook()
runner.logger = MagicMock()
logger_hook._log_train = MagicMock()
logger_hook.after_train_iter(runner)
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
runner.logger.info.assert_called_with(
f'Exp name: {runner.meta["exp_name"]}')
@ -137,6 +138,7 @@ class TestLoggerHook:
runner.meta = dict(exp_name='retinanet')
# Prepare LoggerHook
logger_hook = LoggerHook(by_epoch=by_epoch)
logger_hook._inner_iter = 1
logger_hook.writer = MagicMock()
logger_hook.time_sec_tot = 1000
logger_hook.start_iter = 0
@ -220,6 +222,7 @@ class TestLoggerHook:
def test_get_window_size(self):
runner = self._setup_runner()
logger_hook = LoggerHook()
logger_hook._inner_iter = 1
# Test get window size by name.
assert logger_hook._get_window_size(runner, 'epoch') == 2
assert logger_hook._get_window_size(runner, 'global') == 11
@ -313,6 +316,7 @@ class TestLoggerHook:
def test_get_iter(self):
runner = self._setup_runner()
logger_hook = LoggerHook()
logger_hook._inner_iter = 1
# Get global iter when `inner_iter=False`
iter = logger_hook._get_iter(runner)
assert iter == 11
@ -338,7 +342,6 @@ class TestLoggerHook:
runner = MagicMock()
runner.epoch = 1
runner.cur_dataloader = [0] * 5
runner.inner_iter = 1
runner.iter = 10
runner.train_loop.max_iters = 50
logger = logging.getLogger()

View File

@ -11,9 +11,10 @@ class TestNaiveVisualizationHook:
def test_after_train_iter(self):
naive_visualization_hook = NaiveVisualizationHook()
Runner = Mock(iter=1)
Runner.writer.add_image = Mock()
runner = Mock(iter=1)
runner.writer.add_image = Mock()
inputs = torch.randn(1, 3, 15, 15)
batch_idx = 10
# test with normalize, resize, pad
gt_datasamples = [
BaseDataSample(
@ -28,7 +29,7 @@ class TestNaiveVisualizationHook:
]
pred_datasamples = [BaseDataSample()]
data_batch = (inputs, gt_datasamples)
naive_visualization_hook.after_test_iter(Runner, data_batch,
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
pred_datasamples)
# test with resize, pad
gt_datasamples = [
@ -42,7 +43,7 @@ class TestNaiveVisualizationHook:
]
pred_datasamples = [BaseDataSample()]
data_batch = (inputs, gt_datasamples)
naive_visualization_hook.after_test_iter(Runner, data_batch,
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
pred_datasamples)
# test with only resize
gt_datasamples = [
@ -55,7 +56,7 @@ class TestNaiveVisualizationHook:
]
pred_datasamples = [BaseDataSample()]
data_batch = (inputs, gt_datasamples)
naive_visualization_hook.after_test_iter(Runner, data_batch,
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
pred_datasamples)
# test with only pad
@ -69,7 +70,7 @@ class TestNaiveVisualizationHook:
]
pred_datasamples = [BaseDataSample()]
data_batch = (inputs, gt_datasamples)
naive_visualization_hook.after_test_iter(Runner, data_batch,
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
pred_datasamples)
# test no transform
@ -80,5 +81,5 @@ class TestNaiveVisualizationHook:
]
pred_datasamples = [BaseDataSample()]
data_batch = (inputs, gt_datasamples)
naive_visualization_hook.after_test_iter(Runner, data_batch,
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
pred_datasamples)

View File

@ -73,7 +73,7 @@ class TestOptimizerHook:
wraps=optimizer_hook.detect_anomalous_parameters)
optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
optimizer_hook.after_train_iter(dummy_runner)
optimizer_hook.after_train_iter(dummy_runner, 0)
# assert the parameters of conv2 and conv3 are not in the
# computational graph which is with x1.sum() as root.
assert 'conv2.weight' in dummy_runner.logger.msg
@ -89,7 +89,7 @@ class TestOptimizerHook:
dummy_runner.outputs['loss'] = model(x)[1].sum()
dummy_runner.logger.msg = ''
optimizer_hook.after_train_iter(dummy_runner)
optimizer_hook.after_train_iter(dummy_runner, 0)
# assert the parameters of conv3 are not in the computational graph
assert 'conv3.weight' in dummy_runner.logger.msg
assert 'conv3.bias' in dummy_runner.logger.msg
@ -107,7 +107,7 @@ class TestOptimizerHook:
dummy_runner.outputs['loss'].backward = Mock(
wraps=dummy_runner.outputs['loss'].backward)
optimizer_hook.after_train_iter(dummy_runner)
optimizer_hook.after_train_iter(dummy_runner, 0)
dummy_runner.optimizer.step.assert_called()
dummy_runner.outputs['loss'].backward.assert_called()

View File

@ -7,21 +7,21 @@ from mmengine.hooks import ParamSchedulerHook
class TestParamSchedulerHook:
def test_after_iter(self):
Hook = ParamSchedulerHook()
Runner = Mock()
hook = ParamSchedulerHook()
runner = Mock()
scheduler = Mock()
scheduler.step = Mock()
scheduler.by_epoch = False
Runner.param_schedulers = [scheduler]
Hook.after_train_iter(Runner)
runner.param_schedulers = [scheduler]
hook.after_train_iter(runner, 0)
scheduler.step.assert_called()
def test_after_epoch(self):
Hook = ParamSchedulerHook()
Runner = Mock()
hook = ParamSchedulerHook()
runner = Mock()
scheduler = Mock()
scheduler.step = Mock()
scheduler.by_epoch = True
Runner.param_schedulers = [scheduler]
Hook.after_train_epoch(Runner)
runner.param_schedulers = [scheduler]
hook.after_train_epoch(runner)
scheduler.step.assert_called()

View File

@ -7,7 +7,7 @@ from mmengine.hooks import SyncBuffersHook
class TestSyncBuffersHook:
def test_sync_buffers_hook(self):
Runner = Mock()
Runner.model = Mock()
Hook = SyncBuffersHook()
Hook._after_epoch(Runner)
runner = Mock()
runner.model = Mock()
hook = SyncBuffersHook()
hook._after_epoch(runner)

View File

@ -720,8 +720,8 @@ class TestRunner(TestCase):
epoch_targets = [i for i in range(3)]
iter_results = []
iter_targets = [i for i in range(4 * 3)]
inner_iter_results = []
inner_iter_targets = [i for i in range(4)] * 3 # train and val
batch_idx_results = []
batch_idx_targets = [i for i in range(4)] * 3 # train and val
@HOOKS.register_module()
class TestEpochHook(Hook):
@ -729,9 +729,9 @@ class TestRunner(TestCase):
def before_train_epoch(self, runner):
epoch_results.append(runner.epoch)
def before_train_iter(self, runner, data_batch=None):
def before_train_iter(self, runner, batch_idx, data_batch=None):
iter_results.append(runner.iter)
inner_iter_results.append(runner.inner_iter)
batch_idx_results.append(batch_idx)
self.epoch_based_cfg.custom_hooks = [
dict(type='TestEpochHook', priority=50)
@ -746,7 +746,7 @@ class TestRunner(TestCase):
self.assertEqual(result, target)
for result, target, in zip(iter_results, iter_targets):
self.assertEqual(result, target)
for result, target, in zip(inner_iter_results, inner_iter_targets):
for result, target, in zip(batch_idx_results, batch_idx_targets):
self.assertEqual(result, target)
time.sleep(1)
@ -754,9 +754,9 @@ class TestRunner(TestCase):
# 3. test iter and epoch counter of IterBasedTrainLoop
epoch_results = []
iter_results = []
inner_iter_results = []
batch_idx_results = []
iter_targets = [i for i in range(12)]
inner_iter_targets = [0, 1, 2, 3] * 3
batch_idx_targets = [i for i in range(12)]
@HOOKS.register_module()
class TestIterHook(Hook):
@ -764,9 +764,9 @@ class TestRunner(TestCase):
def before_train_epoch(self, runner):
epoch_results.append(runner.epoch)
def before_train_iter(self, runner, data_batch=None):
def before_train_iter(self, runner, batch_idx, data_batch=None):
iter_results.append(runner.iter)
inner_iter_results.append(runner.inner_iter)
batch_idx_results.append(batch_idx)
self.iter_based_cfg.custom_hooks = [
dict(type='TestIterHook', priority=50)
@ -781,7 +781,7 @@ class TestRunner(TestCase):
self.assertEqual(epoch_results[0], 0)
for result, target, in zip(iter_results, iter_targets):
self.assertEqual(result, target)
for result, target, in zip(inner_iter_results, inner_iter_targets):
for result, target, in zip(batch_idx_results, batch_idx_targets):
self.assertEqual(result, target)
def test_val(self):
@ -1056,7 +1056,6 @@ class TestRunner(TestCase):
runner.resume(path)
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 12)
self.assertEqual(runner.inner_iter, 0)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)