[Refactor] Refactor the interfaces of Hook and its subclasses (#117)
* Fix hook * Fix * Fix docs * Fix * Fix * Fix as comment * update * Fix hook * Fix hook * Fix hook * Fix itertimerhook * Fix iter_timer_hook * Fix * Fix * fix logger hook * Fix loggerhook * update cur_dataloader * Fix docstring * Fix docstring * Fix as comment * Fix as comment * Fix as comment * rename is_last_epoch, enhance and add after_val before_val etc. * fix typo in docstring * remove resolved TODO * refactor docstring
pull/127/head
parent
755f8b5b59
commit
a7961407e4
|
@ -119,9 +119,8 @@ class CheckpointHook(Hook):
|
|||
# save checkpoint for following cases:
|
||||
# 1. every ``self.interval`` epochs
|
||||
# 2. reach the last epoch of training
|
||||
if self.every_n_epochs(
|
||||
runner, self.interval) or (self.save_last
|
||||
and self.is_last_epoch(runner)):
|
||||
if self.every_n_epochs(runner, self.interval) or (
|
||||
self.save_last and self.is_last_train_epoch(runner)):
|
||||
runner.logger.info(f'Saving checkpoint at \
|
||||
{runner.epoch + 1} epochs')
|
||||
if self.sync_buffer:
|
||||
|
@ -187,9 +186,8 @@ class CheckpointHook(Hook):
|
|||
# save checkpoint for following cases:
|
||||
# 1. every ``self.interval`` iterations
|
||||
# 2. reach the last iteration of training
|
||||
if self.every_n_iters(
|
||||
runner, self.interval) or (self.save_last
|
||||
and self.is_last_iter(runner)):
|
||||
if self.every_n_iters(runner, self.interval) or \
|
||||
(self.save_last and self.is_last_iter(runner, mode='train')):
|
||||
runner.logger.info(f'Saving checkpoint at \
|
||||
{runner.iter + 1} iterations')
|
||||
if self.sync_buffer:
|
||||
|
|
|
@ -30,16 +30,16 @@ class EmptyCacheHook(Hook):
|
|||
before_epoch: bool = False,
|
||||
after_epoch: bool = True,
|
||||
after_iter: bool = False) -> None:
|
||||
self._before_epoch = before_epoch
|
||||
self._after_epoch = after_epoch
|
||||
self._after_iter = after_iter
|
||||
self._do_before_epoch = before_epoch
|
||||
self._do_after_epoch = after_epoch
|
||||
self._do_after_iter = after_iter
|
||||
|
||||
def after_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs:
|
||||
Optional[Union[dict, Sequence[BaseDataSample]]] = None)\
|
||||
-> None:
|
||||
def _after_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Union[dict,
|
||||
Sequence[BaseDataSample]]] = None,
|
||||
mode: str = 'train') -> None:
|
||||
"""Empty cache after an iteration.
|
||||
|
||||
Args:
|
||||
|
@ -48,24 +48,27 @@ class EmptyCacheHook(Hook):
|
|||
from dataloader. Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
if self._after_iter:
|
||||
if self._do_after_iter:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def before_epoch(self, runner) -> None:
|
||||
def _before_epoch(self, runner, mode: str = 'train') -> None:
|
||||
"""Empty cache before an epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
if self._before_epoch:
|
||||
if self._do_before_epoch:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def after_epoch(self, runner) -> None:
|
||||
def _after_epoch(self, runner, mode: str = 'train') -> None:
|
||||
"""Empty cache after an epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
if self._after_epoch:
|
||||
if self._do_after_epoch:
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
@ -16,20 +16,20 @@ class Hook:
|
|||
|
||||
def before_run(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before the training process.
|
||||
operations before the training, validation or testing process.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training/validation/testing
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_run(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after the training process.
|
||||
operations after the training, validation or testing process.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training/validation/testing
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
"""
|
||||
pass
|
||||
|
@ -54,7 +54,7 @@ class Hook:
|
|||
|
||||
def before_val(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before val.
|
||||
operations before validation.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
|
@ -63,7 +63,7 @@ class Hook:
|
|||
|
||||
def after_val(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after val.
|
||||
operations after validation.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
|
@ -72,7 +72,7 @@ class Hook:
|
|||
|
||||
def before_test(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before test.
|
||||
operations before testing.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
|
@ -81,67 +81,21 @@ class Hook:
|
|||
|
||||
def after_test(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after test.
|
||||
operations after testing.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_epoch(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_epoch(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each iter.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs:
|
||||
Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
|
||||
-> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from model. Defaults
|
||||
to None.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before saving the checkpoint.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
checkpoints (dict): Model's checkpoint.
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
checkpoint (dict): Model's checkpoint.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -150,8 +104,9 @@ class Hook:
|
|||
operations after loading the checkpoint.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
checkpoints (dict): Model's checkpoint.
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
checkpoint (dict): Model's checkpoint.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -162,25 +117,25 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
self.before_epoch(runner)
|
||||
self._before_epoch(runner, mode='train')
|
||||
|
||||
def before_val_epoch(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each validation epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner (Runner): The runner of the validation process.
|
||||
"""
|
||||
self.before_epoch(runner)
|
||||
self._before_epoch(runner, mode='val')
|
||||
|
||||
def before_test_epoch(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each test epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner (Runner): The runner of the testing process.
|
||||
"""
|
||||
self.before_epoch(runner)
|
||||
self._before_epoch(runner, mode='test')
|
||||
|
||||
def after_train_epoch(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -189,25 +144,25 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
self.after_epoch(runner)
|
||||
self._after_epoch(runner, mode='train')
|
||||
|
||||
def after_val_epoch(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each validation epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner (Runner): The runner of the validation process.
|
||||
"""
|
||||
self.after_epoch(runner)
|
||||
self._after_epoch(runner, mode='val')
|
||||
|
||||
def after_test_epoch(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each test epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner (Runner): The runner of the testing process.
|
||||
"""
|
||||
self.after_epoch(runner)
|
||||
self._after_epoch(runner, mode='test')
|
||||
|
||||
def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -218,29 +173,29 @@ class Hook:
|
|||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
"""
|
||||
self.before_iter(runner, data_batch=None)
|
||||
self._before_iter(runner, data_batch=data_batch, mode='train')
|
||||
|
||||
def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each validation iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner (Runner): The runner of the validation process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
"""
|
||||
self.before_iter(runner, data_batch=None)
|
||||
self._before_iter(runner, data_batch=data_batch, mode='val')
|
||||
|
||||
def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each test iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner (Runner): The runner of the testing process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
"""
|
||||
self.before_iter(runner, data_batch=None)
|
||||
self._before_iter(runner, data_batch=data_batch, mode='test')
|
||||
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
|
@ -256,7 +211,8 @@ class Hook:
|
|||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
self.after_iter(runner, data_batch=None, outputs=None)
|
||||
self._after_iter(
|
||||
runner, data_batch=data_batch, outputs=outputs, mode='train')
|
||||
|
||||
def after_val_iter(self,
|
||||
runner,
|
||||
|
@ -267,13 +223,14 @@ class Hook:
|
|||
operations after each validation iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner (Runner): The runner of the validation process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from
|
||||
model. Defaults to None.
|
||||
"""
|
||||
self.after_iter(runner, data_batch=None, outputs=None)
|
||||
self._after_iter(
|
||||
runner, data_batch=data_batch, outputs=outputs, mode='val')
|
||||
|
||||
def after_test_iter(
|
||||
self,
|
||||
|
@ -284,48 +241,108 @@ class Hook:
|
|||
operations after each test iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner (Runner): The runner of the testing process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
self.after_iter(runner, data_batch=None, outputs=None)
|
||||
self._after_iter(
|
||||
runner, data_batch=data_batch, outputs=outputs, mode='test')
|
||||
|
||||
def every_n_epochs(self, runner, n: int) -> bool:
|
||||
"""Test whether or not current epoch can be evenly divided by n.
|
||||
def _before_epoch(self, runner, mode: str = 'train') -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
n (int): Whether or not current epoch can be evenly divided by n.
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _after_epoch(self, runner, mode: str = 'train') -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _before_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
mode: str = 'train') -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each iter.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _after_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Union[Sequence[BaseDataSample],
|
||||
dict]] = None,
|
||||
mode: str = 'train') -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each iter.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
Defaults to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
pass
|
||||
|
||||
def every_n_epochs(self, runner, n: int) -> bool:
|
||||
"""Test whether current epoch can be evenly divided by n.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
n (int): Whether current epoch can be evenly divided by n.
|
||||
|
||||
Returns:
|
||||
bool: whether or not current epoch can be evenly divided by n.
|
||||
bool: Whether current epoch can be evenly divided by n.
|
||||
"""
|
||||
return (runner.epoch + 1) % n == 0 if n > 0 else False
|
||||
|
||||
def every_n_inner_iters(self, runner, n: int) -> bool:
|
||||
"""Test whether or not current inner iteration can be evenly divided by
|
||||
n.
|
||||
"""Test whether current inner iteration can be evenly divided by n.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
n (int): Whether or not current inner iteration can be evenly
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
n (int): Whether current inner iteration can be evenly
|
||||
divided by n.
|
||||
|
||||
Returns:
|
||||
bool: whether or not current inner iteration can be evenly
|
||||
bool: Whether current inner iteration can be evenly
|
||||
divided by n.
|
||||
"""
|
||||
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
|
||||
|
||||
def every_n_iters(self, runner, n: int) -> bool:
|
||||
"""Test whether or not current iteration can be evenly divided by n.
|
||||
"""Test whether current iteration can be evenly divided by n.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
n (int): Whether or not current iteration can be
|
||||
evenly divided by n.
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
n (int): Whether current iteration can be evenly divided by n.
|
||||
|
||||
Returns:
|
||||
bool: Return True if the current iteration can be evenly divided
|
||||
|
@ -334,35 +351,46 @@ class Hook:
|
|||
return (runner.iter + 1) % n == 0 if n > 0 else False
|
||||
|
||||
def end_of_epoch(self, runner) -> bool:
|
||||
"""Check whether the current epoch reaches the `max_epochs` or not.
|
||||
"""Check whether the current iteration reaches the last iteration of
|
||||
current dataloader.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
|
||||
Returns:
|
||||
bool: Whether reaches the end of current epoch or not.
|
||||
"""
|
||||
return runner.inner_iter + 1 == len(runner.cur_dataloader)
|
||||
|
||||
def is_last_train_epoch(self, runner) -> bool:
|
||||
"""Test whether current epoch is the last train epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
|
||||
Returns:
|
||||
bool: whether the end of current epoch or not.
|
||||
bool: Whether the current epoch is the last training epoch.
|
||||
"""
|
||||
return runner.inner_iter + 1 == len(runner.data_loader)
|
||||
return runner.epoch + 1 == runner.train_loop.max_epochs
|
||||
|
||||
def is_last_epoch(self, runner) -> bool:
|
||||
"""Test whether or not current epoch is the last epoch.
|
||||
def is_last_iter(self, runner, mode='train') -> bool:
|
||||
"""Test whether current iteration is the last iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
|
||||
Returns:
|
||||
bool: bool: Return True if the current epoch reaches the
|
||||
`max_epochs`, otherwise False.
|
||||
bool: Whether current iteration is the last iteration.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
return runner.epoch + 1 == runner._max_epochs
|
||||
|
||||
def is_last_iter(self, runner) -> bool:
|
||||
"""Test whether or not current epoch is the last iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
|
||||
Returns:
|
||||
bool: whether or not current iteration is the last iteration.
|
||||
"""
|
||||
return runner.iter + 1 == runner._max_iters
|
||||
if mode == 'train':
|
||||
return runner.iter + 1 == runner.train_loop.max_iters
|
||||
elif mode == 'val':
|
||||
return runner.iter + 1 == runner.val_loop.max_iters
|
||||
elif mode == 'test':
|
||||
return runner.iter + 1 == runner.test_loop.max_iters
|
||||
else:
|
||||
raise ValueError('mode should be train, val or test but got'
|
||||
f'{mode}')
|
||||
|
|
|
@ -18,30 +18,37 @@ class IterTimerHook(Hook):
|
|||
|
||||
priority = 'NORMAL'
|
||||
|
||||
def before_epoch(self, runner) -> None:
|
||||
def _before_epoch(self, runner, mode: str = 'train') -> None:
|
||||
"""Record the time flag before starting an epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
self.t = time.time()
|
||||
|
||||
def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
|
||||
def _before_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
mode: str = 'train') -> None:
|
||||
"""Logging time for loading data and update the time flag.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
# TODO: update for new logging system
|
||||
runner.log_buffer.update({'data_time': time.time() - self.t})
|
||||
runner.message_hub.update_log(f'{mode}/data_time',
|
||||
time.time() - self.t)
|
||||
|
||||
def after_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs:
|
||||
Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
|
||||
def _after_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs:
|
||||
Optional[Union[dict, Sequence[BaseDataSample]]] = None,
|
||||
mode: str = 'train') \
|
||||
-> None:
|
||||
"""Log the time for an iteration and update the time flag.
|
||||
|
||||
|
@ -51,7 +58,9 @@ class IterTimerHook(Hook):
|
|||
from dataloader. Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from model. Defaults
|
||||
to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
# TODO: update for new logging system
|
||||
runner.log_buffer.update({'time': time.time() - self.t})
|
||||
|
||||
runner.message_hub.update_log(f'{mode}/time', time.time() - self.t)
|
||||
self.t = time.time()
|
||||
|
|
|
@ -264,14 +264,15 @@ class LoggerHook(Hook):
|
|||
# by iter: Iter [100/100000]
|
||||
if self.by_epoch:
|
||||
log_str = f'Epoch [{cur_epoch}]' \
|
||||
f'[{cur_iter}/{len(runner.data_loader)}]\t'
|
||||
f'[{cur_iter}/{len(runner.cur_dataloader)}]\t'
|
||||
else:
|
||||
log_str = f'Iter [{cur_iter}/{runner.max_iters}]\t'
|
||||
log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t'
|
||||
log_str += f'{lr_momentum_str}, '
|
||||
# Calculate eta time.
|
||||
self.time_sec_tot += (tag['time'] * self.interval)
|
||||
time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
|
||||
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
|
||||
eta_sec = time_sec_avg * (
|
||||
runner.train_loop.max_iters - runner.iter - 1)
|
||||
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
|
||||
log_str += f'eta: {eta_str}, '
|
||||
log_str += f'time: {tag["time"]:.3f}, ' \
|
||||
|
@ -302,7 +303,7 @@ class LoggerHook(Hook):
|
|||
"""
|
||||
tag = self._collect_info(runner, 'val')
|
||||
# Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501
|
||||
eval_iter = len(runner.data_loader)
|
||||
eval_iter = len(runner.cur_dataloader)
|
||||
cur_iter = self._get_iter(runner)
|
||||
cur_epoch = self._get_epoch(runner, 'val')
|
||||
# val/test time
|
||||
|
|
|
@ -14,15 +14,15 @@ class DistSamplerSeedHook(Hook):
|
|||
|
||||
priority = 'NORMAL'
|
||||
|
||||
def before_epoch(self, runner) -> None:
|
||||
def before_train_epoch(self, runner, mode: str = 'train') -> None:
|
||||
"""Set the seed for sampler and batch_sampler.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
if hasattr(runner.data_loader.sampler, 'set_epoch'):
|
||||
if hasattr(runner.cur_dataloader.sampler, 'set_epoch'):
|
||||
# in case the data loader uses `SequentialSampler` in Pytorch
|
||||
runner.data_loader.sampler.set_epoch(runner.epoch)
|
||||
elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
|
||||
runner.cur_dataloader.sampler.set_epoch(runner.epoch)
|
||||
elif hasattr(runner.cur_dataloader.batch_sampler.sampler, 'set_epoch'):
|
||||
# batch sampler in pytorch wraps the sampler as its attributes.
|
||||
runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
|
||||
runner.cur_dataloader.batch_sampler.sampler.set_epoch(runner.epoch)
|
||||
|
|
|
@ -89,7 +89,7 @@ class SyncBuffersHook(Hook):
|
|||
def __init__(self) -> None:
|
||||
self.distributed = dist.IS_DIST
|
||||
|
||||
def after_epoch(self, runner) -> None:
|
||||
def after_train_epoch(self, runner) -> None:
|
||||
"""All-reduce model buffers at the end of each epoch.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -9,6 +9,6 @@ class TestEmptyCacheHook:
|
|||
def test_emtpy_cache_hook(self):
|
||||
Hook = EmptyCacheHook(True, True, True)
|
||||
Runner = Mock()
|
||||
Hook.after_iter(Runner)
|
||||
Hook.before_epoch(Runner)
|
||||
Hook.after_epoch(Runner)
|
||||
Hook._after_iter(Runner)
|
||||
Hook._before_epoch(Runner)
|
||||
Hook._after_epoch(Runner)
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from mmengine.hooks import Hook
|
||||
|
||||
|
||||
|
@ -19,25 +21,25 @@ class TestHook:
|
|||
def test_before_epoch(self):
|
||||
hook = Hook()
|
||||
runner = Mock()
|
||||
hook.before_epoch(runner)
|
||||
hook._before_epoch(runner)
|
||||
|
||||
def test_after_epoch(self):
|
||||
hook = Hook()
|
||||
runner = Mock()
|
||||
hook.after_epoch(runner)
|
||||
hook._after_epoch(runner)
|
||||
|
||||
def test_before_iter(self):
|
||||
hook = Hook()
|
||||
runner = Mock()
|
||||
data_batch = {}
|
||||
hook.before_iter(runner, data_batch)
|
||||
hook._before_iter(runner, data_batch)
|
||||
|
||||
def test_after_iter(self):
|
||||
hook = Hook()
|
||||
runner = Mock()
|
||||
data_batch = {}
|
||||
outputs = {}
|
||||
hook.after_iter(runner, data_batch, outputs)
|
||||
hook._after_iter(runner, data_batch, outputs)
|
||||
|
||||
def test_before_save_checkpoint(self):
|
||||
hook = Hook()
|
||||
|
@ -161,7 +163,8 @@ class TestHook:
|
|||
|
||||
# last inner iter
|
||||
runner.inner_iter = 1
|
||||
runner.data_loader.__len__ = Mock(return_value=2)
|
||||
runner.cur_dataloader.__len__ = Mock(return_value=2)
|
||||
runner.cur_dataloader.__len__ = Mock(return_value=2)
|
||||
return_val = hook.end_of_epoch(runner)
|
||||
assert return_val
|
||||
|
||||
|
@ -170,19 +173,19 @@ class TestHook:
|
|||
return_val = hook.end_of_epoch(runner)
|
||||
assert not return_val
|
||||
|
||||
def test_is_last_epoch(self):
|
||||
def test_is_last_train_epoch(self):
|
||||
hook = Hook()
|
||||
runner = Mock()
|
||||
|
||||
# last epoch
|
||||
runner.epoch = 1
|
||||
runner._max_epochs = 2
|
||||
return_val = hook.is_last_epoch(runner)
|
||||
runner.train_loop.max_epochs = 2
|
||||
return_val = hook.is_last_train_epoch(runner)
|
||||
assert return_val
|
||||
|
||||
# not the last epoch
|
||||
runner.epoch = 0
|
||||
return_val = hook.is_last_epoch(runner)
|
||||
runner.train_loop.max_epochs = 0
|
||||
return_val = hook.is_last_train_epoch(runner)
|
||||
assert not return_val
|
||||
|
||||
def test_is_last_iter(self):
|
||||
|
@ -191,11 +194,18 @@ class TestHook:
|
|||
|
||||
# last iter
|
||||
runner.iter = 1
|
||||
runner._max_iters = 2
|
||||
runner.train_loop.max_iters = 2
|
||||
return_val = hook.is_last_iter(runner)
|
||||
assert return_val
|
||||
|
||||
# not the last iter
|
||||
runner.iter = 0
|
||||
return_val = hook.is_last_iter(runner)
|
||||
runner.val_loop.max_iters = 0
|
||||
return_val = hook.is_last_iter(runner, mode='val')
|
||||
assert not return_val
|
||||
|
||||
runner.test_loop.max_iters = 0
|
||||
return_val = hook.is_last_iter(runner, mode='test')
|
||||
assert not return_val
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
hook.is_last_iter(runner, mode='error_mode')
|
||||
|
|
|
@ -9,21 +9,21 @@ class TestIterTimerHook:
|
|||
def test_before_epoch(self):
|
||||
Hook = IterTimerHook()
|
||||
Runner = Mock()
|
||||
Hook.before_epoch(Runner)
|
||||
Hook._before_epoch(Runner)
|
||||
assert isinstance(Hook.t, float)
|
||||
|
||||
def test_before_iter(self):
|
||||
Hook = IterTimerHook()
|
||||
Runner = Mock()
|
||||
Runner.log_buffer = dict()
|
||||
Hook.before_epoch(Runner)
|
||||
Hook.before_iter(Runner)
|
||||
assert 'data_time' in Runner.log_buffer
|
||||
Hook._before_epoch(Runner)
|
||||
Hook._before_iter(Runner)
|
||||
Runner.message_hub.update_log.assert_called()
|
||||
|
||||
def test_after_iter(self):
|
||||
Hook = IterTimerHook()
|
||||
Runner = Mock()
|
||||
Runner.log_buffer = dict()
|
||||
Hook.before_epoch(Runner)
|
||||
Hook.after_iter(Runner)
|
||||
assert 'time' in Runner.log_buffer
|
||||
Hook._before_epoch(Runner)
|
||||
Hook._after_iter(Runner)
|
||||
Runner.message_hub.update_log.assert_called()
|
||||
|
|
|
@ -110,7 +110,7 @@ class TestLoggerHook:
|
|||
# Test end of the epoch.
|
||||
logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
|
||||
logger_hook._log_train = MagicMock()
|
||||
runner.data_loader = [0] * 5
|
||||
runner.cur_dataloader = [0] * 5
|
||||
runner.inner_iter = 4
|
||||
logger_hook.after_train_iter(runner)
|
||||
logger_hook._log_train.assert_called()
|
||||
|
@ -155,7 +155,7 @@ class TestLoggerHook:
|
|||
out, _ = capsys.readouterr()
|
||||
time_avg = logger_hook.time_sec_tot / (
|
||||
runner.iter + 1 - logger_hook.start_iter)
|
||||
eta_second = time_avg * (runner.max_iters - runner.iter - 1)
|
||||
eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1)
|
||||
eta_str = str(datetime.timedelta(seconds=int(eta_second)))
|
||||
if by_epoch:
|
||||
if torch.cuda.is_available():
|
||||
|
@ -337,10 +337,10 @@ class TestLoggerHook:
|
|||
def _setup_runner(self):
|
||||
runner = MagicMock()
|
||||
runner.epoch = 1
|
||||
runner.data_loader = [0] * 5
|
||||
runner.cur_dataloader = [0] * 5
|
||||
runner.inner_iter = 1
|
||||
runner.iter = 10
|
||||
runner.max_iters = 50
|
||||
runner.train_loop.max_iters = 50
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
for handler in logger.handlers:
|
||||
|
|
|
@ -12,17 +12,17 @@ class TestDistSamplerSeedHook:
|
|||
# Test dataset sampler
|
||||
runner = Mock()
|
||||
runner.epoch = 1
|
||||
runner.data_loader = Mock()
|
||||
runner.data_loader.sampler = Mock()
|
||||
runner.data_loader.sampler.set_epoch = Mock()
|
||||
hook.before_epoch(runner)
|
||||
runner.data_loader.sampler.set_epoch.assert_called()
|
||||
runner.cur_dataloader = Mock()
|
||||
runner.cur_dataloader.sampler = Mock()
|
||||
runner.cur_dataloader.sampler.set_epoch = Mock()
|
||||
hook.before_train_epoch(runner)
|
||||
runner.cur_dataloader.sampler.set_epoch.assert_called()
|
||||
# Test batch sampler
|
||||
runner = Mock()
|
||||
runner.data_loader = Mock()
|
||||
runner.data_loader.sampler = Mock(spec_set=True)
|
||||
runner.data_loader.batch_sampler = Mock()
|
||||
runner.data_loader.batch_sampler.sampler = Mock()
|
||||
runner.data_loader.batch_sampler.sampler.set_epoch = Mock()
|
||||
hook.before_epoch(runner)
|
||||
runner.data_loader.batch_sampler.sampler.set_epoch.assert_called()
|
||||
runner.cur_dataloader = Mock()
|
||||
runner.cur_dataloader.sampler = Mock(spec_set=True)
|
||||
runner.cur_dataloader.batch_sampler = Mock()
|
||||
runner.cur_dataloader.batch_sampler.sampler = Mock()
|
||||
runner.cur_dataloader.batch_sampler.sampler.set_epoch = Mock()
|
||||
hook.before_train_epoch(runner)
|
||||
runner.cur_dataloader.batch_sampler.sampler.set_epoch.assert_called()
|
||||
|
|
|
@ -10,4 +10,4 @@ class TestSyncBuffersHook:
|
|||
Runner = Mock()
|
||||
Runner.model = Mock()
|
||||
Hook = SyncBuffersHook()
|
||||
Hook.after_epoch(Runner)
|
||||
Hook._after_epoch(Runner)
|
||||
|
|
Loading…
Reference in New Issue