mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Add batch_idx to hook input. (#140)
* [Refactor] Add batch_idx to hook input. * update
This commit is contained in:
parent
563b4bad16
commit
9a61b389e7
@ -20,31 +20,31 @@ class CheckpointHook(Hook):
|
|||||||
Args:
|
Args:
|
||||||
interval (int): The saving period. If ``by_epoch=True``, interval
|
interval (int): The saving period. If ``by_epoch=True``, interval
|
||||||
indicates epochs, otherwise it indicates iterations.
|
indicates epochs, otherwise it indicates iterations.
|
||||||
Default: -1, which means "never".
|
Defaults to -1, which means "never".
|
||||||
by_epoch (bool): Saving checkpoints by epoch or by iteration.
|
by_epoch (bool): Saving checkpoints by epoch or by iteration.
|
||||||
Default: True.
|
Default: True.
|
||||||
save_optimizer (bool): Whether to save optimizer state_dict in the
|
save_optimizer (bool): Whether to save optimizer state_dict in the
|
||||||
checkpoint. It is usually used for resuming experiments.
|
checkpoint. It is usually used for resuming experiments.
|
||||||
Default: True.
|
Defaults to True.
|
||||||
save_param_scheduler (bool): Whether to save param_scheduler state_dict
|
save_param_scheduler (bool): Whether to save param_scheduler state_dict
|
||||||
in the checkpoint. It is usually used for resuming experiments.
|
in the checkpoint. It is usually used for resuming experiments.
|
||||||
Default: True.
|
Defaults to True.
|
||||||
out_dir (str, optional | Path): The root directory to save checkpoints.
|
out_dir (str, optional | Path): The root directory to save checkpoints.
|
||||||
If not specified, ``runner.work_dir`` will be used by default. If
|
If not specified, ``runner.work_dir`` will be used by default. If
|
||||||
specified, the ``out_dir`` will be the concatenation of ``out_dir``
|
specified, the ``out_dir`` will be the concatenation of ``out_dir``
|
||||||
and the last level directory of ``runner.work_dir``. For example,
|
and the last level directory of ``runner.work_dir``. For example,
|
||||||
if the input ``our_dir`` is ``./tmp`` and ``runner.work_dir`` is
|
if the input ``our_dir`` is ``./tmp`` and ``runner.work_dir`` is
|
||||||
``./work_dir/cur_exp``, then the ckpt will be saved in
|
``./work_dir/cur_exp``, then the ckpt will be saved in
|
||||||
``./tmp/cur_exp``. Deafule to None.
|
``./tmp/cur_exp``. Defaults to None.
|
||||||
max_keep_ckpts (int): The maximum checkpoints to keep.
|
max_keep_ckpts (int): The maximum checkpoints to keep.
|
||||||
In some cases we want only the latest few checkpoints and would
|
In some cases we want only the latest few checkpoints and would
|
||||||
like to delete old ones to save the disk space.
|
like to delete old ones to save the disk space.
|
||||||
Default: -1, which means unlimited.
|
Defaults to -1, which means unlimited.
|
||||||
save_last (bool): Whether to force the last checkpoint to be
|
save_last (bool): Whether to force the last checkpoint to be
|
||||||
saved regardless of interval. Default: True.
|
saved regardless of interval. Defaults to True.
|
||||||
file_client_args (dict, optional): Arguments to instantiate a
|
file_client_args (dict, optional): Arguments to instantiate a
|
||||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||||
Default: None.
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
out_dir: str
|
out_dir: str
|
||||||
|
|
||||||
@ -177,12 +177,14 @@ class CheckpointHook(Hook):
|
|||||||
|
|
||||||
def after_train_iter(self,
|
def after_train_iter(self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs=Optional[dict]) -> None:
|
outputs=Optional[dict]) -> None:
|
||||||
"""Save the checkpoint and synchronize buffers after each iteration.
|
"""Save the checkpoint and synchronize buffers after each iteration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the train loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||||
from dataloader. Defaults to None.
|
from dataloader. Defaults to None.
|
||||||
outputs (dict, optional): Outputs from model.
|
outputs (dict, optional): Outputs from model.
|
||||||
|
@ -36,6 +36,7 @@ class EmptyCacheHook(Hook):
|
|||||||
|
|
||||||
def _after_iter(self,
|
def _after_iter(self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs: Optional[Union[dict,
|
outputs: Optional[Union[dict,
|
||||||
Sequence[BaseDataSample]]] = None,
|
Sequence[BaseDataSample]]] = None,
|
||||||
@ -44,6 +45,7 @@ class EmptyCacheHook(Hook):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||||
from dataloader. Defaults to None.
|
from dataloader. Defaults to None.
|
||||||
outputs (dict or sequence, optional): Outputs from model.
|
outputs (dict or sequence, optional): Outputs from model.
|
||||||
|
@ -164,41 +164,57 @@ class Hook:
|
|||||||
"""
|
"""
|
||||||
self._after_epoch(runner, mode='test')
|
self._after_epoch(runner, mode='test')
|
||||||
|
|
||||||
def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
|
def before_train_iter(self,
|
||||||
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
|
data_batch: DATA_BATCH = None) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
operations before each training iteration.
|
operations before each training iteration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the train loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||||
Data from dataloader. Defaults to None.
|
Data from dataloader. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self._before_iter(runner, data_batch=data_batch, mode='train')
|
self._before_iter(
|
||||||
|
runner, batch_idx=batch_idx, data_batch=data_batch, mode='train')
|
||||||
|
|
||||||
def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
|
def before_val_iter(self,
|
||||||
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
|
data_batch: DATA_BATCH = None) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
operations before each validation iteration.
|
operations before each validation iteration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the validation process.
|
runner (Runner): The runner of the validation process.
|
||||||
|
batch_idx (int): The index of the current batch in the val loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||||
Data from dataloader. Defaults to None.
|
Data from dataloader. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self._before_iter(runner, data_batch=data_batch, mode='val')
|
self._before_iter(
|
||||||
|
runner, batch_idx=batch_idx, data_batch=data_batch, mode='val')
|
||||||
|
|
||||||
def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
|
def before_test_iter(self,
|
||||||
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
|
data_batch: DATA_BATCH = None) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
operations before each test iteration.
|
operations before each test iteration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the testing process.
|
runner (Runner): The runner of the testing process.
|
||||||
|
batch_idx (int): The index of the current batch in the test loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||||
Data from dataloader. Defaults to None.
|
Data from dataloader. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self._before_iter(runner, data_batch=data_batch, mode='test')
|
self._before_iter(
|
||||||
|
runner, batch_idx=batch_idx, data_batch=data_batch, mode='test')
|
||||||
|
|
||||||
def after_train_iter(self,
|
def after_train_iter(self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs: Optional[dict] = None) -> None:
|
outputs: Optional[dict] = None) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -206,16 +222,22 @@ class Hook:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the train loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||||
Data from dataloader. Defaults to None.
|
Data from dataloader. Defaults to None.
|
||||||
outputs (dict, optional): Outputs from model.
|
outputs (dict, optional): Outputs from model.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
self._after_iter(
|
self._after_iter(
|
||||||
runner, data_batch=data_batch, outputs=outputs, mode='train')
|
runner,
|
||||||
|
batch_idx=batch_idx,
|
||||||
|
data_batch=data_batch,
|
||||||
|
outputs=outputs,
|
||||||
|
mode='train')
|
||||||
|
|
||||||
def after_val_iter(self,
|
def after_val_iter(self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs: Optional[Sequence[BaseDataSample]] = None) \
|
outputs: Optional[Sequence[BaseDataSample]] = None) \
|
||||||
-> None:
|
-> None:
|
||||||
@ -224,17 +246,23 @@ class Hook:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the validation process.
|
runner (Runner): The runner of the validation process.
|
||||||
|
batch_idx (int): The index of the current batch in the val loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||||
Data from dataloader. Defaults to None.
|
Data from dataloader. Defaults to None.
|
||||||
outputs (dict or sequence, optional): Outputs from
|
outputs (dict or sequence, optional): Outputs from
|
||||||
model. Defaults to None.
|
model. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self._after_iter(
|
self._after_iter(
|
||||||
runner, data_batch=data_batch, outputs=outputs, mode='val')
|
runner,
|
||||||
|
batch_idx=batch_idx,
|
||||||
|
data_batch=data_batch,
|
||||||
|
outputs=outputs,
|
||||||
|
mode='val')
|
||||||
|
|
||||||
def after_test_iter(
|
def after_test_iter(
|
||||||
self,
|
self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -242,13 +270,18 @@ class Hook:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the test loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||||
Data from dataloader. Defaults to None.
|
Data from dataloader. Defaults to None.
|
||||||
outputs (dict, optional): Outputs from model.
|
outputs (dict, optional): Outputs from model.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
self._after_iter(
|
self._after_iter(
|
||||||
runner, data_batch=data_batch, outputs=outputs, mode='test')
|
runner,
|
||||||
|
batch_idx=batch_idx,
|
||||||
|
data_batch=data_batch,
|
||||||
|
outputs=outputs,
|
||||||
|
mode='test')
|
||||||
|
|
||||||
def _before_epoch(self, runner, mode: str = 'train') -> None:
|
def _before_epoch(self, runner, mode: str = 'train') -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -274,6 +307,7 @@ class Hook:
|
|||||||
|
|
||||||
def _before_iter(self,
|
def _before_iter(self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
mode: str = 'train') -> None:
|
mode: str = 'train') -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -282,6 +316,7 @@ class Hook:
|
|||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training, validation or testing
|
runner (Runner): The runner of the training, validation or testing
|
||||||
process.
|
process.
|
||||||
|
batch_idx (int): The index of the current batch in the loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||||
Data from dataloader. Defaults to None.
|
Data from dataloader. Defaults to None.
|
||||||
mode (str): Current mode of runner. Defaults to 'train'.
|
mode (str): Current mode of runner. Defaults to 'train'.
|
||||||
@ -290,6 +325,7 @@ class Hook:
|
|||||||
|
|
||||||
def _after_iter(self,
|
def _after_iter(self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs: Optional[Union[Sequence[BaseDataSample],
|
outputs: Optional[Union[Sequence[BaseDataSample],
|
||||||
dict]] = None,
|
dict]] = None,
|
||||||
@ -300,6 +336,7 @@ class Hook:
|
|||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training, validation or testing
|
runner (Runner): The runner of the training, validation or testing
|
||||||
process.
|
process.
|
||||||
|
batch_idx (int): The index of the current batch in the loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||||
Data from dataloader. Defaults to None.
|
Data from dataloader. Defaults to None.
|
||||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||||
@ -321,12 +358,12 @@ class Hook:
|
|||||||
"""
|
"""
|
||||||
return (runner.epoch + 1) % n == 0 if n > 0 else False
|
return (runner.epoch + 1) % n == 0 if n > 0 else False
|
||||||
|
|
||||||
def every_n_inner_iters(self, runner, n: int) -> bool:
|
def every_n_inner_iters(self, inner_iter: int, n: int) -> bool:
|
||||||
"""Test whether current inner iteration can be evenly divided by n.
|
"""Test whether current inner iteration can be evenly divided by n.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training, validation or testing
|
inner_iter (int): Current inner_iter of the training, validation
|
||||||
process.
|
or testing loop.
|
||||||
n (int): Whether current inner iteration can be evenly
|
n (int): Whether current inner iteration can be evenly
|
||||||
divided by n.
|
divided by n.
|
||||||
|
|
||||||
@ -334,7 +371,7 @@ class Hook:
|
|||||||
bool: Whether current inner iteration can be evenly
|
bool: Whether current inner iteration can be evenly
|
||||||
divided by n.
|
divided by n.
|
||||||
"""
|
"""
|
||||||
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
|
return (inner_iter + 1) % n == 0 if n > 0 else False
|
||||||
|
|
||||||
def every_n_iters(self, runner, n: int) -> bool:
|
def every_n_iters(self, runner, n: int) -> bool:
|
||||||
"""Test whether current iteration can be evenly divided by n.
|
"""Test whether current iteration can be evenly divided by n.
|
||||||
@ -350,18 +387,19 @@ class Hook:
|
|||||||
"""
|
"""
|
||||||
return (runner.iter + 1) % n == 0 if n > 0 else False
|
return (runner.iter + 1) % n == 0 if n > 0 else False
|
||||||
|
|
||||||
def end_of_epoch(self, runner) -> bool:
|
def end_of_epoch(self, runner, batch_idx: int) -> bool:
|
||||||
"""Check whether the current iteration reaches the last iteration of
|
"""Check whether the current iteration reaches the last iteration of
|
||||||
current dataloader.
|
current dataloader.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training, validation or testing
|
runner (Runner): The runner of the training, validation or testing
|
||||||
process.
|
process.
|
||||||
|
batch_idx (int): The index of the current batch in the loop.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Whether reaches the end of current epoch or not.
|
bool: Whether reaches the end of current epoch or not.
|
||||||
"""
|
"""
|
||||||
return runner.inner_iter + 1 == len(runner.cur_dataloader)
|
return batch_idx + 1 == len(runner.cur_dataloader)
|
||||||
|
|
||||||
def is_last_train_epoch(self, runner) -> bool:
|
def is_last_train_epoch(self, runner) -> bool:
|
||||||
"""Test whether current epoch is the last train epoch.
|
"""Test whether current epoch is the last train epoch.
|
||||||
|
@ -13,7 +13,7 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
|
|||||||
class IterTimerHook(Hook):
|
class IterTimerHook(Hook):
|
||||||
"""A hook that logs the time spent during iteration.
|
"""A hook that logs the time spent during iteration.
|
||||||
|
|
||||||
Eg. ``data_time`` for loading data and ``time`` for a model train step.
|
E.g. ``data_time`` for loading data and ``time`` for a model train step.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
priority = 'NORMAL'
|
priority = 'NORMAL'
|
||||||
@ -29,12 +29,14 @@ class IterTimerHook(Hook):
|
|||||||
|
|
||||||
def _before_iter(self,
|
def _before_iter(self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
mode: str = 'train') -> None:
|
mode: str = 'train') -> None:
|
||||||
"""Logging time for loading data and update the time flag.
|
"""Logging time for loading data and update the time flag.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||||
from dataloader. Defaults to None.
|
from dataloader. Defaults to None.
|
||||||
mode (str): Current mode of runner. Defaults to 'train'.
|
mode (str): Current mode of runner. Defaults to 'train'.
|
||||||
@ -45,15 +47,16 @@ class IterTimerHook(Hook):
|
|||||||
|
|
||||||
def _after_iter(self,
|
def _after_iter(self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs:
|
outputs: Optional[Union[dict,
|
||||||
Optional[Union[dict, Sequence[BaseDataSample]]] = None,
|
Sequence[BaseDataSample]]] = None,
|
||||||
mode: str = 'train') \
|
mode: str = 'train') -> None:
|
||||||
-> None:
|
|
||||||
"""Logging time for a iteration and update the time flag.
|
"""Logging time for a iteration and update the time flag.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||||
from dataloader. Defaults to None.
|
from dataloader. Defaults to None.
|
||||||
outputs (dict or sequence, optional): Outputs from model. Defaults
|
outputs (dict or sequence, optional): Outputs from model. Defaults
|
||||||
|
@ -121,6 +121,7 @@ class LoggerHook(Hook):
|
|||||||
keep_local=True,
|
keep_local=True,
|
||||||
file_client_args=None,
|
file_client_args=None,
|
||||||
):
|
):
|
||||||
|
self._inner_iter = 0
|
||||||
self.by_epoch = by_epoch
|
self.by_epoch = by_epoch
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.custom_keys = custom_keys if custom_keys is not None else dict()
|
self.custom_keys = custom_keys if custom_keys is not None else dict()
|
||||||
@ -174,27 +175,31 @@ class LoggerHook(Hook):
|
|||||||
|
|
||||||
def after_train_iter(self,
|
def after_train_iter(self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs: Optional[dict] = None) -> None:
|
outputs: Optional[dict] = None) -> None:
|
||||||
"""Record training logs.
|
"""Record training logs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the train loop.
|
||||||
data_batch (Sequence[BaseDataSample], optional): Data from
|
data_batch (Sequence[BaseDataSample], optional): Data from
|
||||||
dataloader. Defaults to None.
|
dataloader. Defaults to None.
|
||||||
outputs (dict, optional): Outputs from model.
|
outputs (dict, optional): Outputs from model.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
self._inner_iter = batch_idx
|
||||||
if runner.meta is not None and 'exp_name' in runner.meta:
|
if runner.meta is not None and 'exp_name' in runner.meta:
|
||||||
if (self.every_n_iters(runner, self.interval_exp_name)) or (
|
if (self.every_n_iters(runner, self.interval_exp_name)) or (
|
||||||
self.by_epoch and self.end_of_epoch(runner)):
|
self.by_epoch and self.end_of_epoch(runner, batch_idx)):
|
||||||
exp_info = f'Exp name: {runner.meta["exp_name"]}'
|
exp_info = f'Exp name: {runner.meta["exp_name"]}'
|
||||||
runner.logger.info(exp_info)
|
runner.logger.info(exp_info)
|
||||||
if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
|
if self.by_epoch and self.every_n_inner_iters(batch_idx,
|
||||||
|
self.interval):
|
||||||
self._log_train(runner)
|
self._log_train(runner)
|
||||||
elif not self.by_epoch and self.every_n_iters(runner, self.interval):
|
elif not self.by_epoch and self.every_n_iters(runner, self.interval):
|
||||||
self._log_train(runner)
|
self._log_train(runner)
|
||||||
elif self.end_of_epoch(runner) and not self.ignore_last:
|
elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last:
|
||||||
# `runner.max_iters` may not be divisible by `self.interval`. if
|
# `runner.max_iters` may not be divisible by `self.interval`. if
|
||||||
# `self.ignore_last==True`, the log of remaining iterations will
|
# `self.ignore_last==True`, the log of remaining iterations will
|
||||||
# be recorded (Epoch [4][1000/1007], the logs of 998-1007
|
# be recorded (Epoch [4][1000/1007], the logs of 998-1007
|
||||||
@ -346,7 +351,7 @@ class LoggerHook(Hook):
|
|||||||
'The value of windows size must equal to LoggerHook.interval'
|
'The value of windows size must equal to LoggerHook.interval'
|
||||||
return window_size
|
return window_size
|
||||||
elif window_size == 'epoch':
|
elif window_size == 'epoch':
|
||||||
return runner.inner_iter + 1
|
return self._inner_iter + 1
|
||||||
elif window_size == 'global':
|
elif window_size == 'global':
|
||||||
return runner.iter + 1
|
return runner.iter + 1
|
||||||
else:
|
else:
|
||||||
@ -505,7 +510,7 @@ class LoggerHook(Hook):
|
|||||||
int: The current global iter or inner iter.
|
int: The current global iter or inner iter.
|
||||||
"""
|
"""
|
||||||
if self.by_epoch and inner_iter:
|
if self.by_epoch and inner_iter:
|
||||||
current_iter = runner.inner_iter + 1
|
current_iter = self._inner_iter + 1
|
||||||
else:
|
else:
|
||||||
current_iter = runner.iter + 1
|
current_iter = runner.iter + 1
|
||||||
return current_iter
|
return current_iter
|
||||||
|
@ -40,12 +40,14 @@ class NaiveVisualizationHook(Hook):
|
|||||||
def after_test_iter(
|
def after_test_iter(
|
||||||
self,
|
self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
|
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
|
||||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||||
"""Show or Write the predicted results.
|
"""Show or Write the predicted results.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the test loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||||
from dataloader. Defaults to None.
|
from dataloader. Defaults to None.
|
||||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||||
|
@ -58,6 +58,7 @@ class OptimizerHook(Hook):
|
|||||||
|
|
||||||
def after_train_iter(self,
|
def after_train_iter(self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs: Optional[dict] = None) -> None:
|
outputs: Optional[dict] = None) -> None:
|
||||||
"""All operations need to be finished after each training iteration.
|
"""All operations need to be finished after each training iteration.
|
||||||
@ -69,12 +70,13 @@ class OptimizerHook(Hook):
|
|||||||
|
|
||||||
- Compute the gradient of model parameters.
|
- Compute the gradient of model parameters.
|
||||||
|
|
||||||
- Clip the gradidents of each parameters. (optional)
|
- Clip the gradients of each parameter. (optional)
|
||||||
|
|
||||||
- Update model parameters with gradients.
|
- Update model parameters with gradients.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the train loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||||
from dataloader. In order to keep this interface consistent
|
from dataloader. In order to keep this interface consistent
|
||||||
with other hooks, we keep ``data_batch`` here.
|
with other hooks, we keep ``data_batch`` here.
|
||||||
|
@ -10,19 +10,21 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
|
|||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module()
|
||||||
class ParamSchedulerHook(Hook):
|
class ParamSchedulerHook(Hook):
|
||||||
"""A hook to update some hyper-parameters in optimizer, e.g learning rate
|
"""A hook to update some hyper-parameters in optimizer, e.g., learning rate
|
||||||
and momentum."""
|
and momentum."""
|
||||||
|
|
||||||
priority = 'LOW'
|
priority = 'LOW'
|
||||||
|
|
||||||
def after_train_iter(self,
|
def after_train_iter(self,
|
||||||
runner,
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs: Optional[dict] = None) -> None:
|
outputs: Optional[dict] = None) -> None:
|
||||||
"""Call step function for each scheduler after each iteration.
|
"""Call step function for each scheduler after each iteration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the train loop.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||||
from dataloader. In order to keep this interface consistent
|
from dataloader. In order to keep this interface consistent
|
||||||
with other hooks, we keep ``data_batch`` here.
|
with other hooks, we keep ``data_batch`` here.
|
||||||
|
@ -70,9 +70,8 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
|
||||||
from dataloader.
|
from dataloader.
|
||||||
"""
|
"""
|
||||||
self.runner._inner_iter = idx
|
self.runner.call_hook(
|
||||||
|
'before_train_iter', batch_idx=idx, data_batch=data_batch)
|
||||||
self.runner.call_hook('before_train_iter', data_batch=data_batch)
|
|
||||||
# outputs should be a dict containing one or multiple loss tensors
|
# outputs should be a dict containing one or multiple loss tensors
|
||||||
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
||||||
|
|
||||||
@ -82,6 +81,7 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
|
|
||||||
self.runner.call_hook(
|
self.runner.call_hook(
|
||||||
'after_train_iter',
|
'after_train_iter',
|
||||||
|
batch_idx=idx,
|
||||||
data_batch=data_batch,
|
data_batch=data_batch,
|
||||||
outputs=self.runner.outputs)
|
outputs=self.runner.outputs)
|
||||||
|
|
||||||
@ -126,9 +126,6 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
if (self.runner.val_loop is not None and
|
if (self.runner.val_loop is not None and
|
||||||
self.runner._iter % self.runner.val_loop.interval == 0):
|
self.runner._iter % self.runner.val_loop.interval == 0):
|
||||||
self.runner.val_loop.run()
|
self.runner.val_loop.run()
|
||||||
# reset inner_iter to 0 to ensure it counts as expected in
|
|
||||||
# train loop
|
|
||||||
self.runner._inner_iter = 0
|
|
||||||
|
|
||||||
self.runner.call_hook('after_train_epoch')
|
self.runner.call_hook('after_train_epoch')
|
||||||
self.runner.call_hook('after_train')
|
self.runner.call_hook('after_train')
|
||||||
@ -141,7 +138,10 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
|
||||||
from dataloader.
|
from dataloader.
|
||||||
"""
|
"""
|
||||||
self.runner.call_hook('before_train_iter', data_batch=data_batch)
|
self.runner.call_hook(
|
||||||
|
'before_train_iter',
|
||||||
|
batch_idx=self.runner._iter,
|
||||||
|
data_batch=data_batch)
|
||||||
# outputs should be a dict containing loss tensor
|
# outputs should be a dict containing loss tensor
|
||||||
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
||||||
|
|
||||||
@ -151,10 +151,10 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
|
|
||||||
self.runner.call_hook(
|
self.runner.call_hook(
|
||||||
'after_train_iter',
|
'after_train_iter',
|
||||||
|
batch_idx=self.runner._iter,
|
||||||
data_batch=data_batch,
|
data_batch=data_batch,
|
||||||
outputs=self.runner.outputs)
|
outputs=self.runner.outputs)
|
||||||
self.runner._iter += 1
|
self.runner._iter += 1
|
||||||
self.runner._inner_iter += 1
|
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
@LOOPS.register_module()
|
||||||
@ -208,13 +208,16 @@ class ValLoop(BaseLoop):
|
|||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
|
||||||
from dataloader.
|
from dataloader.
|
||||||
"""
|
"""
|
||||||
self.runner._inner_iter = idx
|
self.runner.call_hook(
|
||||||
self.runner.call_hook('before_val_iter', data_batch=data_batch)
|
'before_val_iter', batch_idx=idx, data_batch=data_batch)
|
||||||
# outputs should be sequence of BaseDataSample
|
# outputs should be sequence of BaseDataSample
|
||||||
outputs = self.runner.model(data_batch)
|
outputs = self.runner.model(data_batch)
|
||||||
self.evaluator.process(data_batch, outputs)
|
self.evaluator.process(data_batch, outputs)
|
||||||
self.runner.call_hook(
|
self.runner.call_hook(
|
||||||
'after_val_iter', data_batch=data_batch, outputs=outputs)
|
'after_val_iter',
|
||||||
|
batch_idx=idx,
|
||||||
|
data_batch=data_batch,
|
||||||
|
outputs=outputs)
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
@LOOPS.register_module()
|
||||||
@ -263,10 +266,13 @@ class TestLoop(BaseLoop):
|
|||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data
|
||||||
from dataloader.
|
from dataloader.
|
||||||
"""
|
"""
|
||||||
self.runner._inner_iter = idx
|
self.runner.call_hook(
|
||||||
self.runner.call_hook('before_test_iter', data_batch=data_batch)
|
'before_test_iter', batch_idx=idx, data_batch=data_batch)
|
||||||
# predictions should be sequence of BaseDataSample
|
# predictions should be sequence of BaseDataSample
|
||||||
predictions = self.runner.model(data_batch)
|
predictions = self.runner.model(data_batch)
|
||||||
self.evaluator.process(data_batch, predictions)
|
self.evaluator.process(data_batch, predictions)
|
||||||
self.runner.call_hook(
|
self.runner.call_hook(
|
||||||
'after_test_iter', data_batch=data_batch, outputs=predictions)
|
'after_test_iter',
|
||||||
|
batch_idx=idx,
|
||||||
|
data_batch=data_batch,
|
||||||
|
outputs=predictions)
|
||||||
|
@ -229,7 +229,6 @@ class Runner:
|
|||||||
|
|
||||||
self._epoch = 0
|
self._epoch = 0
|
||||||
self._iter = 0
|
self._iter = 0
|
||||||
self._inner_iter = 0
|
|
||||||
|
|
||||||
# lazy initialization
|
# lazy initialization
|
||||||
training_related = [
|
training_related = [
|
||||||
@ -400,11 +399,6 @@ class Runner:
|
|||||||
"""int: Current epoch."""
|
"""int: Current epoch."""
|
||||||
return self._iter
|
return self._iter
|
||||||
|
|
||||||
@property
|
|
||||||
def inner_iter(self):
|
|
||||||
"""int: Current iteration."""
|
|
||||||
return self._inner_iter
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def launcher(self):
|
def launcher(self):
|
||||||
"""str: Way to launcher multi processes."""
|
"""str: Way to launcher multi processes."""
|
||||||
|
@ -100,19 +100,20 @@ class TestCheckpointHook:
|
|||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.work_dir = './tmp'
|
runner.work_dir = './tmp'
|
||||||
runner.iter = 9
|
runner.iter = 9
|
||||||
|
batch_idx = 9
|
||||||
runner.meta = dict()
|
runner.meta = dict()
|
||||||
runner.model = Mock()
|
runner.model = Mock()
|
||||||
|
|
||||||
# by epoch is True
|
# by epoch is True
|
||||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_run(runner)
|
||||||
checkpoint_hook.after_train_iter(runner)
|
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
assert runner.meta.get('hook_msgs', None) is None
|
assert runner.meta.get('hook_msgs', None) is None
|
||||||
|
|
||||||
# by epoch is False
|
# by epoch is False
|
||||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
|
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_run(runner)
|
||||||
checkpoint_hook.after_train_iter(runner)
|
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
assert (runner.iter + 1) % 2 == 0
|
assert (runner.iter + 1) % 2 == 0
|
||||||
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
|
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
|
||||||
|
|
||||||
@ -129,5 +130,5 @@ class TestCheckpointHook:
|
|||||||
checkpoint_hook = CheckpointHook(
|
checkpoint_hook = CheckpointHook(
|
||||||
interval=2, by_epoch=False, max_keep_ckpts=1)
|
interval=2, by_epoch=False, max_keep_ckpts=1)
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_run(runner)
|
||||||
checkpoint_hook.after_train_iter(runner)
|
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
assert not os.path.exists(f'{tempo_dir}/iter_8.pth')
|
assert not os.path.exists(f'{tempo_dir}/iter_8.pth')
|
||||||
|
@ -7,8 +7,8 @@ from mmengine.hooks import EmptyCacheHook
|
|||||||
class TestEmptyCacheHook:
|
class TestEmptyCacheHook:
|
||||||
|
|
||||||
def test_emtpy_cache_hook(self):
|
def test_emtpy_cache_hook(self):
|
||||||
Hook = EmptyCacheHook(True, True, True)
|
hook = EmptyCacheHook(True, True, True)
|
||||||
Runner = Mock()
|
runner = Mock()
|
||||||
Hook._after_iter(Runner)
|
hook._after_iter(runner, 0)
|
||||||
Hook._before_epoch(Runner)
|
hook._before_epoch(runner)
|
||||||
Hook._after_epoch(Runner)
|
hook._after_epoch(runner)
|
||||||
|
@ -136,11 +136,9 @@ class TestHook:
|
|||||||
|
|
||||||
def test_every_n_inner_iters(self):
|
def test_every_n_inner_iters(self):
|
||||||
hook = Hook()
|
hook = Hook()
|
||||||
runner = Mock()
|
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
runner.inner_iter = i
|
return_val = hook.every_n_inner_iters(i, 3)
|
||||||
return_val = hook.every_n_inner_iters(runner, 3)
|
|
||||||
if (i + 1) % 3 == 0:
|
if (i + 1) % 3 == 0:
|
||||||
assert return_val
|
assert return_val
|
||||||
else:
|
else:
|
||||||
@ -162,15 +160,15 @@ class TestHook:
|
|||||||
runner = Mock()
|
runner = Mock()
|
||||||
|
|
||||||
# last inner iter
|
# last inner iter
|
||||||
runner.inner_iter = 1
|
batch_idx = 1
|
||||||
runner.cur_dataloader.__len__ = Mock(return_value=2)
|
runner.cur_dataloader.__len__ = Mock(return_value=2)
|
||||||
runner.cur_dataloader.__len__ = Mock(return_value=2)
|
runner.cur_dataloader.__len__ = Mock(return_value=2)
|
||||||
return_val = hook.end_of_epoch(runner)
|
return_val = hook.end_of_epoch(runner, batch_idx)
|
||||||
assert return_val
|
assert return_val
|
||||||
|
|
||||||
# not the last inner iter
|
# not the last inner iter
|
||||||
runner.inner_iter = 0
|
batch_idx = 0
|
||||||
return_val = hook.end_of_epoch(runner)
|
return_val = hook.end_of_epoch(runner, batch_idx)
|
||||||
assert not return_val
|
assert not return_val
|
||||||
|
|
||||||
def test_is_last_train_epoch(self):
|
def test_is_last_train_epoch(self):
|
||||||
|
@ -7,23 +7,23 @@ from mmengine.hooks import IterTimerHook
|
|||||||
class TestIterTimerHook:
|
class TestIterTimerHook:
|
||||||
|
|
||||||
def test_before_epoch(self):
|
def test_before_epoch(self):
|
||||||
Hook = IterTimerHook()
|
hook = IterTimerHook()
|
||||||
Runner = Mock()
|
runner = Mock()
|
||||||
Hook._before_epoch(Runner)
|
hook._before_epoch(runner)
|
||||||
assert isinstance(Hook.t, float)
|
assert isinstance(hook.t, float)
|
||||||
|
|
||||||
def test_before_iter(self):
|
def test_before_iter(self):
|
||||||
Hook = IterTimerHook()
|
hook = IterTimerHook()
|
||||||
Runner = Mock()
|
runner = Mock()
|
||||||
Runner.log_buffer = dict()
|
runner.log_buffer = dict()
|
||||||
Hook._before_epoch(Runner)
|
hook._before_epoch(runner)
|
||||||
Hook._before_iter(Runner)
|
hook._before_iter(runner, 0)
|
||||||
Runner.message_hub.update_log.assert_called()
|
runner.message_hub.update_log.assert_called()
|
||||||
|
|
||||||
def test_after_iter(self):
|
def test_after_iter(self):
|
||||||
Hook = IterTimerHook()
|
hook = IterTimerHook()
|
||||||
Runner = Mock()
|
runner = Mock()
|
||||||
Runner.log_buffer = dict()
|
runner.log_buffer = dict()
|
||||||
Hook._before_epoch(Runner)
|
hook._before_epoch(runner)
|
||||||
Hook._after_iter(Runner)
|
hook._after_iter(runner, 0)
|
||||||
Runner.message_hub.update_log.assert_called()
|
runner.message_hub.update_log.assert_called()
|
||||||
|
@ -85,14 +85,15 @@ class TestLoggerHook:
|
|||||||
# Test LoggerHook by iter.
|
# Test LoggerHook by iter.
|
||||||
runner = MagicMock()
|
runner = MagicMock()
|
||||||
runner.iter = 10
|
runner.iter = 10
|
||||||
|
batch_idx = 5
|
||||||
logger_hook = LoggerHook(by_epoch=False)
|
logger_hook = LoggerHook(by_epoch=False)
|
||||||
logger_hook._log_train = MagicMock()
|
logger_hook._log_train = MagicMock()
|
||||||
logger_hook.after_train_iter(runner)
|
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
# `cur_iter=10+1`, which cannot be exact division by
|
# `cur_iter=10+1`, which cannot be exact division by
|
||||||
# `logger_hook.interval`
|
# `logger_hook.interval`
|
||||||
logger_hook._log_train.assert_not_called()
|
logger_hook._log_train.assert_not_called()
|
||||||
runner.iter = 9
|
runner.iter = 9
|
||||||
logger_hook.after_train_iter(runner)
|
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
logger_hook._log_train.assert_called()
|
logger_hook._log_train.assert_called()
|
||||||
|
|
||||||
# Test LoggerHook by epoch.
|
# Test LoggerHook by epoch.
|
||||||
@ -100,19 +101,19 @@ class TestLoggerHook:
|
|||||||
logger_hook._log_train = MagicMock()
|
logger_hook._log_train = MagicMock()
|
||||||
# Only `runner.inner_iter` will work.
|
# Only `runner.inner_iter` will work.
|
||||||
runner.iter = 9
|
runner.iter = 9
|
||||||
runner.inner_iter = 10
|
batch_idx = 10
|
||||||
logger_hook.after_train_iter(runner)
|
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
logger_hook._log_train.assert_not_called()
|
logger_hook._log_train.assert_not_called()
|
||||||
runner.inner_iter = 9
|
batch_idx = 9
|
||||||
logger_hook.after_train_iter(runner)
|
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
logger_hook._log_train.assert_called()
|
logger_hook._log_train.assert_called()
|
||||||
|
|
||||||
# Test end of the epoch.
|
# Test end of the epoch.
|
||||||
logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
|
logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
|
||||||
logger_hook._log_train = MagicMock()
|
logger_hook._log_train = MagicMock()
|
||||||
runner.cur_dataloader = [0] * 5
|
runner.cur_dataloader = [0] * 5
|
||||||
runner.inner_iter = 4
|
batch_idx = 4
|
||||||
logger_hook.after_train_iter(runner)
|
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
logger_hook._log_train.assert_called()
|
logger_hook._log_train.assert_called()
|
||||||
|
|
||||||
# Test print exp_name
|
# Test print exp_name
|
||||||
@ -120,7 +121,7 @@ class TestLoggerHook:
|
|||||||
logger_hook = LoggerHook()
|
logger_hook = LoggerHook()
|
||||||
runner.logger = MagicMock()
|
runner.logger = MagicMock()
|
||||||
logger_hook._log_train = MagicMock()
|
logger_hook._log_train = MagicMock()
|
||||||
logger_hook.after_train_iter(runner)
|
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
runner.logger.info.assert_called_with(
|
runner.logger.info.assert_called_with(
|
||||||
f'Exp name: {runner.meta["exp_name"]}')
|
f'Exp name: {runner.meta["exp_name"]}')
|
||||||
|
|
||||||
@ -137,6 +138,7 @@ class TestLoggerHook:
|
|||||||
runner.meta = dict(exp_name='retinanet')
|
runner.meta = dict(exp_name='retinanet')
|
||||||
# Prepare LoggerHook
|
# Prepare LoggerHook
|
||||||
logger_hook = LoggerHook(by_epoch=by_epoch)
|
logger_hook = LoggerHook(by_epoch=by_epoch)
|
||||||
|
logger_hook._inner_iter = 1
|
||||||
logger_hook.writer = MagicMock()
|
logger_hook.writer = MagicMock()
|
||||||
logger_hook.time_sec_tot = 1000
|
logger_hook.time_sec_tot = 1000
|
||||||
logger_hook.start_iter = 0
|
logger_hook.start_iter = 0
|
||||||
@ -220,6 +222,7 @@ class TestLoggerHook:
|
|||||||
def test_get_window_size(self):
|
def test_get_window_size(self):
|
||||||
runner = self._setup_runner()
|
runner = self._setup_runner()
|
||||||
logger_hook = LoggerHook()
|
logger_hook = LoggerHook()
|
||||||
|
logger_hook._inner_iter = 1
|
||||||
# Test get window size by name.
|
# Test get window size by name.
|
||||||
assert logger_hook._get_window_size(runner, 'epoch') == 2
|
assert logger_hook._get_window_size(runner, 'epoch') == 2
|
||||||
assert logger_hook._get_window_size(runner, 'global') == 11
|
assert logger_hook._get_window_size(runner, 'global') == 11
|
||||||
@ -313,6 +316,7 @@ class TestLoggerHook:
|
|||||||
def test_get_iter(self):
|
def test_get_iter(self):
|
||||||
runner = self._setup_runner()
|
runner = self._setup_runner()
|
||||||
logger_hook = LoggerHook()
|
logger_hook = LoggerHook()
|
||||||
|
logger_hook._inner_iter = 1
|
||||||
# Get global iter when `inner_iter=False`
|
# Get global iter when `inner_iter=False`
|
||||||
iter = logger_hook._get_iter(runner)
|
iter = logger_hook._get_iter(runner)
|
||||||
assert iter == 11
|
assert iter == 11
|
||||||
@ -338,7 +342,6 @@ class TestLoggerHook:
|
|||||||
runner = MagicMock()
|
runner = MagicMock()
|
||||||
runner.epoch = 1
|
runner.epoch = 1
|
||||||
runner.cur_dataloader = [0] * 5
|
runner.cur_dataloader = [0] * 5
|
||||||
runner.inner_iter = 1
|
|
||||||
runner.iter = 10
|
runner.iter = 10
|
||||||
runner.train_loop.max_iters = 50
|
runner.train_loop.max_iters = 50
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
@ -11,9 +11,10 @@ class TestNaiveVisualizationHook:
|
|||||||
|
|
||||||
def test_after_train_iter(self):
|
def test_after_train_iter(self):
|
||||||
naive_visualization_hook = NaiveVisualizationHook()
|
naive_visualization_hook = NaiveVisualizationHook()
|
||||||
Runner = Mock(iter=1)
|
runner = Mock(iter=1)
|
||||||
Runner.writer.add_image = Mock()
|
runner.writer.add_image = Mock()
|
||||||
inputs = torch.randn(1, 3, 15, 15)
|
inputs = torch.randn(1, 3, 15, 15)
|
||||||
|
batch_idx = 10
|
||||||
# test with normalize, resize, pad
|
# test with normalize, resize, pad
|
||||||
gt_datasamples = [
|
gt_datasamples = [
|
||||||
BaseDataSample(
|
BaseDataSample(
|
||||||
@ -28,7 +29,7 @@ class TestNaiveVisualizationHook:
|
|||||||
]
|
]
|
||||||
pred_datasamples = [BaseDataSample()]
|
pred_datasamples = [BaseDataSample()]
|
||||||
data_batch = (inputs, gt_datasamples)
|
data_batch = (inputs, gt_datasamples)
|
||||||
naive_visualization_hook.after_test_iter(Runner, data_batch,
|
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||||
pred_datasamples)
|
pred_datasamples)
|
||||||
# test with resize, pad
|
# test with resize, pad
|
||||||
gt_datasamples = [
|
gt_datasamples = [
|
||||||
@ -42,7 +43,7 @@ class TestNaiveVisualizationHook:
|
|||||||
]
|
]
|
||||||
pred_datasamples = [BaseDataSample()]
|
pred_datasamples = [BaseDataSample()]
|
||||||
data_batch = (inputs, gt_datasamples)
|
data_batch = (inputs, gt_datasamples)
|
||||||
naive_visualization_hook.after_test_iter(Runner, data_batch,
|
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||||
pred_datasamples)
|
pred_datasamples)
|
||||||
# test with only resize
|
# test with only resize
|
||||||
gt_datasamples = [
|
gt_datasamples = [
|
||||||
@ -55,7 +56,7 @@ class TestNaiveVisualizationHook:
|
|||||||
]
|
]
|
||||||
pred_datasamples = [BaseDataSample()]
|
pred_datasamples = [BaseDataSample()]
|
||||||
data_batch = (inputs, gt_datasamples)
|
data_batch = (inputs, gt_datasamples)
|
||||||
naive_visualization_hook.after_test_iter(Runner, data_batch,
|
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||||
pred_datasamples)
|
pred_datasamples)
|
||||||
|
|
||||||
# test with only pad
|
# test with only pad
|
||||||
@ -69,7 +70,7 @@ class TestNaiveVisualizationHook:
|
|||||||
]
|
]
|
||||||
pred_datasamples = [BaseDataSample()]
|
pred_datasamples = [BaseDataSample()]
|
||||||
data_batch = (inputs, gt_datasamples)
|
data_batch = (inputs, gt_datasamples)
|
||||||
naive_visualization_hook.after_test_iter(Runner, data_batch,
|
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||||
pred_datasamples)
|
pred_datasamples)
|
||||||
|
|
||||||
# test no transform
|
# test no transform
|
||||||
@ -80,5 +81,5 @@ class TestNaiveVisualizationHook:
|
|||||||
]
|
]
|
||||||
pred_datasamples = [BaseDataSample()]
|
pred_datasamples = [BaseDataSample()]
|
||||||
data_batch = (inputs, gt_datasamples)
|
data_batch = (inputs, gt_datasamples)
|
||||||
naive_visualization_hook.after_test_iter(Runner, data_batch,
|
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||||
pred_datasamples)
|
pred_datasamples)
|
||||||
|
@ -73,7 +73,7 @@ class TestOptimizerHook:
|
|||||||
wraps=optimizer_hook.detect_anomalous_parameters)
|
wraps=optimizer_hook.detect_anomalous_parameters)
|
||||||
optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
|
optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
|
||||||
|
|
||||||
optimizer_hook.after_train_iter(dummy_runner)
|
optimizer_hook.after_train_iter(dummy_runner, 0)
|
||||||
# assert the parameters of conv2 and conv3 are not in the
|
# assert the parameters of conv2 and conv3 are not in the
|
||||||
# computational graph which is with x1.sum() as root.
|
# computational graph which is with x1.sum() as root.
|
||||||
assert 'conv2.weight' in dummy_runner.logger.msg
|
assert 'conv2.weight' in dummy_runner.logger.msg
|
||||||
@ -89,7 +89,7 @@ class TestOptimizerHook:
|
|||||||
|
|
||||||
dummy_runner.outputs['loss'] = model(x)[1].sum()
|
dummy_runner.outputs['loss'] = model(x)[1].sum()
|
||||||
dummy_runner.logger.msg = ''
|
dummy_runner.logger.msg = ''
|
||||||
optimizer_hook.after_train_iter(dummy_runner)
|
optimizer_hook.after_train_iter(dummy_runner, 0)
|
||||||
# assert the parameters of conv3 are not in the computational graph
|
# assert the parameters of conv3 are not in the computational graph
|
||||||
assert 'conv3.weight' in dummy_runner.logger.msg
|
assert 'conv3.weight' in dummy_runner.logger.msg
|
||||||
assert 'conv3.bias' in dummy_runner.logger.msg
|
assert 'conv3.bias' in dummy_runner.logger.msg
|
||||||
@ -107,7 +107,7 @@ class TestOptimizerHook:
|
|||||||
dummy_runner.outputs['loss'].backward = Mock(
|
dummy_runner.outputs['loss'].backward = Mock(
|
||||||
wraps=dummy_runner.outputs['loss'].backward)
|
wraps=dummy_runner.outputs['loss'].backward)
|
||||||
|
|
||||||
optimizer_hook.after_train_iter(dummy_runner)
|
optimizer_hook.after_train_iter(dummy_runner, 0)
|
||||||
|
|
||||||
dummy_runner.optimizer.step.assert_called()
|
dummy_runner.optimizer.step.assert_called()
|
||||||
dummy_runner.outputs['loss'].backward.assert_called()
|
dummy_runner.outputs['loss'].backward.assert_called()
|
||||||
|
@ -7,21 +7,21 @@ from mmengine.hooks import ParamSchedulerHook
|
|||||||
class TestParamSchedulerHook:
|
class TestParamSchedulerHook:
|
||||||
|
|
||||||
def test_after_iter(self):
|
def test_after_iter(self):
|
||||||
Hook = ParamSchedulerHook()
|
hook = ParamSchedulerHook()
|
||||||
Runner = Mock()
|
runner = Mock()
|
||||||
scheduler = Mock()
|
scheduler = Mock()
|
||||||
scheduler.step = Mock()
|
scheduler.step = Mock()
|
||||||
scheduler.by_epoch = False
|
scheduler.by_epoch = False
|
||||||
Runner.param_schedulers = [scheduler]
|
runner.param_schedulers = [scheduler]
|
||||||
Hook.after_train_iter(Runner)
|
hook.after_train_iter(runner, 0)
|
||||||
scheduler.step.assert_called()
|
scheduler.step.assert_called()
|
||||||
|
|
||||||
def test_after_epoch(self):
|
def test_after_epoch(self):
|
||||||
Hook = ParamSchedulerHook()
|
hook = ParamSchedulerHook()
|
||||||
Runner = Mock()
|
runner = Mock()
|
||||||
scheduler = Mock()
|
scheduler = Mock()
|
||||||
scheduler.step = Mock()
|
scheduler.step = Mock()
|
||||||
scheduler.by_epoch = True
|
scheduler.by_epoch = True
|
||||||
Runner.param_schedulers = [scheduler]
|
runner.param_schedulers = [scheduler]
|
||||||
Hook.after_train_epoch(Runner)
|
hook.after_train_epoch(runner)
|
||||||
scheduler.step.assert_called()
|
scheduler.step.assert_called()
|
||||||
|
@ -7,7 +7,7 @@ from mmengine.hooks import SyncBuffersHook
|
|||||||
class TestSyncBuffersHook:
|
class TestSyncBuffersHook:
|
||||||
|
|
||||||
def test_sync_buffers_hook(self):
|
def test_sync_buffers_hook(self):
|
||||||
Runner = Mock()
|
runner = Mock()
|
||||||
Runner.model = Mock()
|
runner.model = Mock()
|
||||||
Hook = SyncBuffersHook()
|
hook = SyncBuffersHook()
|
||||||
Hook._after_epoch(Runner)
|
hook._after_epoch(runner)
|
||||||
|
@ -720,8 +720,8 @@ class TestRunner(TestCase):
|
|||||||
epoch_targets = [i for i in range(3)]
|
epoch_targets = [i for i in range(3)]
|
||||||
iter_results = []
|
iter_results = []
|
||||||
iter_targets = [i for i in range(4 * 3)]
|
iter_targets = [i for i in range(4 * 3)]
|
||||||
inner_iter_results = []
|
batch_idx_results = []
|
||||||
inner_iter_targets = [i for i in range(4)] * 3 # train and val
|
batch_idx_targets = [i for i in range(4)] * 3 # train and val
|
||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module()
|
||||||
class TestEpochHook(Hook):
|
class TestEpochHook(Hook):
|
||||||
@ -729,9 +729,9 @@ class TestRunner(TestCase):
|
|||||||
def before_train_epoch(self, runner):
|
def before_train_epoch(self, runner):
|
||||||
epoch_results.append(runner.epoch)
|
epoch_results.append(runner.epoch)
|
||||||
|
|
||||||
def before_train_iter(self, runner, data_batch=None):
|
def before_train_iter(self, runner, batch_idx, data_batch=None):
|
||||||
iter_results.append(runner.iter)
|
iter_results.append(runner.iter)
|
||||||
inner_iter_results.append(runner.inner_iter)
|
batch_idx_results.append(batch_idx)
|
||||||
|
|
||||||
self.epoch_based_cfg.custom_hooks = [
|
self.epoch_based_cfg.custom_hooks = [
|
||||||
dict(type='TestEpochHook', priority=50)
|
dict(type='TestEpochHook', priority=50)
|
||||||
@ -746,7 +746,7 @@ class TestRunner(TestCase):
|
|||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
for result, target, in zip(iter_results, iter_targets):
|
for result, target, in zip(iter_results, iter_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
for result, target, in zip(inner_iter_results, inner_iter_targets):
|
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
|
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
@ -754,9 +754,9 @@ class TestRunner(TestCase):
|
|||||||
# 3. test iter and epoch counter of IterBasedTrainLoop
|
# 3. test iter and epoch counter of IterBasedTrainLoop
|
||||||
epoch_results = []
|
epoch_results = []
|
||||||
iter_results = []
|
iter_results = []
|
||||||
inner_iter_results = []
|
batch_idx_results = []
|
||||||
iter_targets = [i for i in range(12)]
|
iter_targets = [i for i in range(12)]
|
||||||
inner_iter_targets = [0, 1, 2, 3] * 3
|
batch_idx_targets = [i for i in range(12)]
|
||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module()
|
||||||
class TestIterHook(Hook):
|
class TestIterHook(Hook):
|
||||||
@ -764,9 +764,9 @@ class TestRunner(TestCase):
|
|||||||
def before_train_epoch(self, runner):
|
def before_train_epoch(self, runner):
|
||||||
epoch_results.append(runner.epoch)
|
epoch_results.append(runner.epoch)
|
||||||
|
|
||||||
def before_train_iter(self, runner, data_batch=None):
|
def before_train_iter(self, runner, batch_idx, data_batch=None):
|
||||||
iter_results.append(runner.iter)
|
iter_results.append(runner.iter)
|
||||||
inner_iter_results.append(runner.inner_iter)
|
batch_idx_results.append(batch_idx)
|
||||||
|
|
||||||
self.iter_based_cfg.custom_hooks = [
|
self.iter_based_cfg.custom_hooks = [
|
||||||
dict(type='TestIterHook', priority=50)
|
dict(type='TestIterHook', priority=50)
|
||||||
@ -781,7 +781,7 @@ class TestRunner(TestCase):
|
|||||||
self.assertEqual(epoch_results[0], 0)
|
self.assertEqual(epoch_results[0], 0)
|
||||||
for result, target, in zip(iter_results, iter_targets):
|
for result, target, in zip(iter_results, iter_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
for result, target, in zip(inner_iter_results, inner_iter_targets):
|
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
|
|
||||||
def test_val(self):
|
def test_val(self):
|
||||||
@ -1056,7 +1056,6 @@ class TestRunner(TestCase):
|
|||||||
runner.resume(path)
|
runner.resume(path)
|
||||||
self.assertEqual(runner.epoch, 0)
|
self.assertEqual(runner.epoch, 0)
|
||||||
self.assertEqual(runner.iter, 12)
|
self.assertEqual(runner.iter, 12)
|
||||||
self.assertEqual(runner.inner_iter, 0)
|
|
||||||
self.assertTrue(runner._has_loaded)
|
self.assertTrue(runner._has_loaded)
|
||||||
self.assertIsInstance(runner.optimizer, SGD)
|
self.assertIsInstance(runner.optimizer, SGD)
|
||||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user