[Refactor] Refactor the interfaces of Hook and its subclasses (#117)

* Fix hook

* Fix

* Fix docs

* Fix

* Fix

* Fix as comment

* update

* Fix hook

* Fix hook

* Fix hook

* Fix itertimerhook

* Fix iter_timer_hook

* Fix

* Fix

* fix logger hook

* Fix loggerhook

* update cur_dataloader

* Fix docstring

* Fix docstring

* Fix as comment

* Fix as comment

* Fix as comment

* rename is_last_epoch, enhance and add after_val, before_val, etc.

* fix typo in docstring

* remove resolved TODO

* refactor docstring
Mashiro 2022-03-13 16:48:09 +08:00 committed by GitHub
parent 755f8b5b59
commit a7961407e4
13 changed files with 241 additions and 192 deletions
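
For orientation, a minimal sketch (not part of this diff) of how a custom hook targets the refactored interface: mode-agnostic logic now lives in the underscored methods, and the mode-specific entry points (before_train_iter, after_val_iter, etc.) forward to them with the matching mode string. The subclass name below is hypothetical.

import torch
from mmengine.hooks import Hook

class GpuMemoryHook(Hook):  # hypothetical subclass, for illustration only
    """Log peak GPU memory after every iteration, whatever the mode."""

    def _after_iter(self, runner, data_batch=None, outputs=None,
                    mode: str = 'train') -> None:
        # ``mode`` is 'train', 'val' or 'test'; the base class forwards it
        # from after_train_iter / after_val_iter / after_test_iter.
        if torch.cuda.is_available():
            mem_mb = torch.cuda.max_memory_allocated() / 1024 ** 2
            runner.logger.info(f'[{mode}] peak GPU memory: {mem_mb:.0f} MB')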

View File

@ -119,9 +119,8 @@ class CheckpointHook(Hook):
# save checkpoint for following cases:
# 1. every ``self.interval`` epochs
# 2. reach the last epoch of training
if self.every_n_epochs(
runner, self.interval) or (self.save_last
and self.is_last_epoch(runner)):
if self.every_n_epochs(runner, self.interval) or (
self.save_last and self.is_last_train_epoch(runner)):
runner.logger.info(f'Saving checkpoint at \
{runner.epoch + 1} epochs')
if self.sync_buffer:
@ -187,9 +186,8 @@ class CheckpointHook(Hook):
# save checkpoint for following cases:
# 1. every ``self.interval`` iterations
# 2. reach the last iteration of training
if self.every_n_iters(
runner, self.interval) or (self.save_last
and self.is_last_iter(runner)):
if self.every_n_iters(runner, self.interval) or \
(self.save_last and self.is_last_iter(runner, mode='train')):
runner.logger.info(f'Saving checkpoint at \
{runner.iter + 1} iterations')
if self.sync_buffer:
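
Both hunks reduce to the same pattern: save every ``self.interval`` steps, plus once at the very end when save_last is set. A standalone restatement of the epoch-based condition with hypothetical numbers:

interval, save_last = 2, True
epoch, max_epochs = 9, 10                  # zero-based epoch counter

every_n = (epoch + 1) % interval == 0 if interval > 0 else False
is_last = epoch + 1 == max_epochs          # what is_last_train_epoch tests
assert every_n or (save_last and is_last)  # epoch 9 triggers a save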

View File

@ -30,16 +30,16 @@ class EmptyCacheHook(Hook):
before_epoch: bool = False,
after_epoch: bool = True,
after_iter: bool = False) -> None:
self._before_epoch = before_epoch
self._after_epoch = after_epoch
self._after_iter = after_iter
self._do_before_epoch = before_epoch
self._do_after_epoch = after_epoch
self._do_after_iter = after_iter
def after_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs:
Optional[Union[dict, Sequence[BaseDataSample]]] = None)\
-> None:
def _after_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Union[dict,
Sequence[BaseDataSample]]] = None,
mode: str = 'train') -> None:
"""Empty cache after an iteration.
Args:
@ -48,24 +48,27 @@ class EmptyCacheHook(Hook):
from dataloader. Defaults to None.
outputs (dict or sequence, optional): Outputs from model.
Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._after_iter:
if self._do_after_iter:
torch.cuda.empty_cache()
def before_epoch(self, runner) -> None:
def _before_epoch(self, runner, mode: str = 'train') -> None:
"""Empty cache before an epoch.
Args:
runner (Runner): The runner of the training process.
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._before_epoch:
if self._do_before_epoch:
torch.cuda.empty_cache()
def after_epoch(self, runner) -> None:
def _after_epoch(self, runner, mode: str = 'train') -> None:
"""Empty cache after an epoch.
Args:
runner (Runner): The runner of the training process.
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._after_epoch:
if self._do_after_epoch:
torch.cuda.empty_cache()
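
The constructor arguments are unchanged; only the internal flag names and the hook entry points moved. A usage sketch (assuming a CUDA build of PyTorch):

from mmengine.hooks import EmptyCacheHook

# free cached GPU memory after every epoch, but not per iteration
hook = EmptyCacheHook(before_epoch=False, after_epoch=True, after_iter=False)
# The runner calls hook._after_epoch(runner, mode=...) at the end of each
# train/val/test epoch, which runs torch.cuda.empty_cache() because
# after_epoch=True.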

View File

@ -16,20 +16,20 @@ class Hook:
def before_run(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before the training process.
operations before the training, validation or testing process.
Args:
runner (Runner): The runner of the training/validation/testing
runner (Runner): The runner of the training, validation or testing
process.
"""
pass
def after_run(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after the training process.
operations after the training, validation or testing process.
Args:
runner (Runner): The runner of the training/validation/testing
runner (Runner): The runner of the training, validation or testing
process.
"""
pass
@ -54,7 +54,7 @@ class Hook:
def before_val(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before val.
operations before validation.
Args:
runner (Runner): The runner of the validation process.
@ -63,7 +63,7 @@ class Hook:
def after_val(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after val.
operations after validation.
Args:
runner (Runner): The runner of the validation process.
@ -72,7 +72,7 @@ class Hook:
def before_test(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before test.
operations before testing.
Args:
runner (Runner): The runner of the testing process.
@ -81,67 +81,21 @@ class Hook:
def after_test(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after test.
operations after testing.
Args:
runner (Runner): The runner of the testing process.
"""
pass
def before_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before each epoch.
Args:
runner (Runner): The runner of the training process.
"""
pass
def after_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after each epoch.
Args:
runner (Runner): The runner of the training process.
"""
pass
def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each iter.
Args:
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
pass
def after_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs:
Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
-> None:
"""All subclasses should override this method, if they need any
operations after each epoch.
Args:
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (dict or sequence, optional): Outputs from model. Defaults
to None.
"""
pass
def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
"""All subclasses should override this method, if they need any
operations before saving the checkpoint.
Args:
runner (Runner): The runner of the training process.
checkpoints (dict): Model's checkpoint.
runner (Runner): The runner of the training, validation or testing
process.
checkpoint (dict): Model's checkpoint.
"""
pass
@ -150,8 +104,9 @@ class Hook:
operations after loading the checkpoint.
Args:
runner (Runner): The runner of the training process.
checkpoints (dict): Model's checkpoint.
runner (Runner): The runner of the training, validation or testing
process.
checkpoint (dict): Model's checkpoint.
"""
pass
@ -162,25 +117,25 @@ class Hook:
Args:
runner (Runner): The runner of the training process.
"""
self.before_epoch(runner)
self._before_epoch(runner, mode='train')
def before_val_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before each validation epoch.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the validation process.
"""
self.before_epoch(runner)
self._before_epoch(runner, mode='val')
def before_test_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before each test epoch.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the testing process.
"""
self.before_epoch(runner)
self._before_epoch(runner, mode='test')
def after_train_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
@ -189,25 +144,25 @@ class Hook:
Args:
runner (Runner): The runner of the training process.
"""
self.after_epoch(runner)
self._after_epoch(runner, mode='train')
def after_val_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after each validation epoch.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the validation process.
"""
self.after_epoch(runner)
self._after_epoch(runner, mode='val')
def after_test_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after each test epoch.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the testing process.
"""
self.after_epoch(runner)
self._after_epoch(runner, mode='test')
def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
@ -218,29 +173,29 @@ class Hook:
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
self.before_iter(runner, data_batch=None)
self._before_iter(runner, data_batch=data_batch, mode='train')
def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each validation iteration.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the validation process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
self.before_iter(runner, data_batch=None)
self._before_iter(runner, data_batch=data_batch, mode='val')
def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each test iteration.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the testing process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
self.before_iter(runner, data_batch=None)
self._before_iter(runner, data_batch=data_batch, mode='test')
def after_train_iter(self,
runner,
@ -256,7 +211,8 @@ class Hook:
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
self.after_iter(runner, data_batch=None, outputs=None)
self._after_iter(
runner, data_batch=data_batch, outputs=outputs, mode='train')
def after_val_iter(self,
runner,
@ -267,13 +223,14 @@ class Hook:
operations after each validation iteration.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the validation process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (dict or sequence, optional): Outputs from
model. Defaults to None.
"""
self.after_iter(runner, data_batch=None, outputs=None)
self._after_iter(
runner, data_batch=data_batch, outputs=outputs, mode='val')
def after_test_iter(
self,
@ -284,48 +241,108 @@ class Hook:
operations after each test iteration.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the testing process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
self.after_iter(runner, data_batch=None, outputs=None)
self._after_iter(
runner, data_batch=data_batch, outputs=outputs, mode='test')
def every_n_epochs(self, runner, n: int) -> bool:
"""Test whether or not current epoch can be evenly divided by n.
def _before_epoch(self, runner, mode: str = 'train') -> None:
"""All subclasses should override this method, if they need any
operations before each epoch.
Args:
runner (Runner): The runner of the training process.
n (int): Whether or not current epoch can be evenly divided by n.
runner (Runner): The runner of the training, validation or testing
process.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass
def _after_epoch(self, runner, mode: str = 'train') -> None:
"""All subclasses should override this method, if they need any
operations after each epoch.
Args:
runner (Runner): The runner of the training, validation or testing
process.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass
def _before_iter(self,
runner,
data_batch: DATA_BATCH = None,
mode: str = 'train') -> None:
"""All subclasses should override this method, if they need any
operations before each iter.
Args:
runner (Runner): The runner of the training, validation or testing
process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass
def _after_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Union[Sequence[BaseDataSample],
dict]] = None,
mode: str = 'train') -> None:
"""All subclasses should override this method, if they need any
operations after each iteration.
Args:
runner (Runner): The runner of the training, validation or testing
process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (dict or Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass
def every_n_epochs(self, runner, n: int) -> bool:
"""Test whether current epoch can be evenly divided by n.
Args:
runner (Runner): The runner of the training, validation or testing
process.
n (int): The divisor. The check passes when ``epoch + 1`` is divisible by n.
Returns:
bool: whether or not current epoch can be evenly divided by n.
bool: Whether current epoch can be evenly divided by n.
"""
return (runner.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, runner, n: int) -> bool:
"""Test whether or not current inner iteration can be evenly divided by
n.
"""Test whether current inner iteration can be evenly divided by n.
Args:
runner (Runner): The runner of the training process.
n (int): Whether or not current inner iteration can be evenly
runner (Runner): The runner of the training, validation or testing
process.
n (int): The divisor. The check passes when ``inner_iter + 1`` is
divisible by n.
Returns:
bool: whether or not current inner iteration can be evenly
bool: Whether current inner iteration can be evenly
divided by n.
"""
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
def every_n_iters(self, runner, n: int) -> bool:
"""Test whether or not current iteration can be evenly divided by n.
"""Test whether current iteration can be evenly divided by n.
Args:
runner (Runner): The runner of the training process.
n (int): Whether or not current iteration can be
evenly divided by n.
runner (Runner): The runner of the training, validation or testing
process.
n (int): The divisor. The check passes when ``iter + 1`` is divisible by n.
Returns:
bool: Return True if the current iteration can be evenly divided
@ -334,35 +351,46 @@ class Hook:
return (runner.iter + 1) % n == 0 if n > 0 else False
def end_of_epoch(self, runner) -> bool:
"""Check whether the current epoch reaches the `max_epochs` or not.
"""Check whether the current iteration reaches the last iteration of
the current dataloader.
Args:
runner (Runner): The runner of the training, validation or testing
process.
Returns:
bool: Whether the end of the current epoch has been reached.
"""
return runner.inner_iter + 1 == len(runner.cur_dataloader)
def is_last_train_epoch(self, runner) -> bool:
"""Test whether current epoch is the last train epoch.
Args:
runner (Runner): The runner of the training process.
Returns:
bool: whether the end of current epoch or not.
bool: Whether the current epoch is the last training epoch.
"""
return runner.inner_iter + 1 == len(runner.data_loader)
return runner.epoch + 1 == runner.train_loop.max_epochs
def is_last_epoch(self, runner) -> bool:
"""Test whether or not current epoch is the last epoch.
def is_last_iter(self, runner, mode='train') -> bool:
"""Test whether current iteration is the last iteration.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the training, validation or testing
process.
mode (str): Current mode of runner. Defaults to 'train'.
Returns:
bool: bool: Return True if the current epoch reaches the
`max_epochs`, otherwise False.
bool: Whether current iteration is the last iteration.
"""
return runner.epoch + 1 == runner._max_epochs
def is_last_iter(self, runner) -> bool:
"""Test whether or not current epoch is the last iteration.
Args:
runner (Runner): The runner of the training process.
Returns:
bool: whether or not current iteration is the last iteration.
"""
return runner.iter + 1 == runner._max_iters
if mode == 'train':
return runner.iter + 1 == runner.train_loop.max_iters
elif mode == 'val':
return runner.iter + 1 == runner.val_loop.max_iters
elif mode == 'test':
return runner.iter + 1 == runner.test_loop.max_iters
else:
raise ValueError('mode should be train, val or test but got '
f'{mode}')
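
A quick illustration of the per-mode dispatch, using a mock in place of a real Runner (mirroring the updated tests later in this commit):

from unittest.mock import Mock

from mmengine.hooks import Hook

runner = Mock()
runner.iter = 99
runner.train_loop.max_iters = 100
runner.val_loop.max_iters = 50

hook = Hook()
assert hook.is_last_iter(runner, mode='train')    # 99 + 1 == 100
assert not hook.is_last_iter(runner, mode='val')  # 99 + 1 != 50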

View File

@ -18,30 +18,37 @@ class IterTimerHook(Hook):
priority = 'NORMAL'
def before_epoch(self, runner) -> None:
def _before_epoch(self, runner, mode: str = 'train') -> None:
"""Record time flag before start a epoch.
Args:
runner (Runner): The runner of the training process.
mode (str): Current mode of runner. Defaults to 'train'.
"""
self.t = time.time()
def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
def _before_iter(self,
runner,
data_batch: DATA_BATCH = None,
mode: str = 'train') -> None:
"""Logging time for loading data and update the time flag.
Args:
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
# TODO: update for new logging system
runner.log_buffer.update({'data_time': time.time() - self.t})
runner.message_hub.update_log(f'{mode}/data_time',
time.time() - self.t)
def after_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs:
Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
def _after_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs:
Optional[Union[dict, Sequence[BaseDataSample]]] = None,
mode: str = 'train') \
-> None:
"""Logging time for a iteration and update the time flag.
@ -51,7 +58,9 @@ class IterTimerHook(Hook):
from dataloader. Defaults to None.
outputs (dict or sequence, optional): Outputs from model. Defaults
to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
# TODO: update for new logging system
runner.log_buffer.update({'time': time.time() - self.t})
runner.message_hub.update_log(f'{mode}/time', time.time() - self.t)
self.t = time.time()
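
The timings now land in the runner's message hub under mode-scoped keys ('train/data_time', 'train/time', and so on) instead of a shared log_buffer dict. A sketch of the write path with a mock runner:

import time
from unittest.mock import Mock

runner = Mock()  # stand-in for a real Runner
t = time.time()
# ... a dataloader would yield the next batch here ...
runner.message_hub.update_log('train/data_time', time.time() - t)
runner.message_hub.update_log.assert_called()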

View File

@ -264,14 +264,15 @@ class LoggerHook(Hook):
# by iter: Iter [100/100000]
if self.by_epoch:
log_str = f'Epoch [{cur_epoch}]' \
f'[{cur_iter}/{len(runner.data_loader)}]\t'
f'[{cur_iter}/{len(runner.cur_dataloader)}]\t'
else:
log_str = f'Iter [{cur_iter}/{runner.max_iters}]\t'
log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t'
log_str += f'{lr_momentum_str}, '
# Calculate eta time.
self.time_sec_tot += (tag['time'] * self.interval)
time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
eta_sec = time_sec_avg * (
runner.train_loop.max_iters - runner.iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
log_str += f'eta: {eta_str}, '
log_str += f'time: {tag["time"]:.3f}, ' \
@ -302,7 +303,7 @@ class LoggerHook(Hook):
"""
tag = self._collect_info(runner, 'val')
# Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501
eval_iter = len(runner.data_loader)
eval_iter = len(runner.cur_dataloader)
cur_iter = self._get_iter(runner)
cur_epoch = self._get_epoch(runner, 'val')
# val/test time
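
The ETA computed in the first hunk above is a running average of per-iteration time extrapolated over the remaining iterations; restated with hypothetical numbers:

import datetime

time_sec_tot = 120.0                     # accumulated iteration time
cur_iter, start_iter, max_iters = 110, 10, 1000

time_sec_avg = time_sec_tot / (cur_iter - start_iter + 1)
eta_sec = time_sec_avg * (max_iters - cur_iter - 1)
print(str(datetime.timedelta(seconds=int(eta_sec))))  # 0:17:36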

View File

@ -14,15 +14,15 @@ class DistSamplerSeedHook(Hook):
priority = 'NORMAL'
def before_epoch(self, runner) -> None:
def before_train_epoch(self, runner) -> None:
"""Set the seed for sampler and batch_sampler.
Args:
runner (Runner): The runner of the training process.
"""
if hasattr(runner.data_loader.sampler, 'set_epoch'):
if hasattr(runner.cur_dataloader.sampler, 'set_epoch'):
# in case the data loader uses `SequentialSampler` in PyTorch
runner.data_loader.sampler.set_epoch(runner.epoch)
elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
runner.cur_dataloader.sampler.set_epoch(runner.epoch)
elif hasattr(runner.cur_dataloader.batch_sampler.sampler, 'set_epoch'):
# the batch sampler in PyTorch wraps the sampler as one of its attributes.
runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
runner.cur_dataloader.batch_sampler.sampler.set_epoch(runner.epoch)
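
For context: DistributedSampler derives its shuffle seed from the epoch, so skipping set_epoch would replay the same permutation every epoch. A self-contained single-process sketch in plain PyTorch:

import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
# num_replicas/rank are hardcoded here; real DDP gets them from the launcher
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)

sampler.set_epoch(0)
print(list(sampler))  # one permutation of 0..7
sampler.set_epoch(1)
print(list(sampler))  # almost surely a different permutation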

View File

@ -89,7 +89,7 @@ class SyncBuffersHook(Hook):
def __init__(self) -> None:
self.distributed = dist.IS_DIST
def after_epoch(self, runner) -> None:
def after_train_epoch(self, runner) -> None:
"""All-reduce model buffers at the end of each epoch.
Args:

View File

@ -9,6 +9,6 @@ class TestEmptyCacheHook:
def test_empty_cache_hook(self):
Hook = EmptyCacheHook(True, True, True)
Runner = Mock()
Hook.after_iter(Runner)
Hook.before_epoch(Runner)
Hook.after_epoch(Runner)
Hook._after_iter(Runner)
Hook._before_epoch(Runner)
Hook._after_epoch(Runner)

View File

@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
import pytest
from mmengine.hooks import Hook
@ -19,25 +21,25 @@ class TestHook:
def test_before_epoch(self):
hook = Hook()
runner = Mock()
hook.before_epoch(runner)
hook._before_epoch(runner)
def test_after_epoch(self):
hook = Hook()
runner = Mock()
hook.after_epoch(runner)
hook._after_epoch(runner)
def test_before_iter(self):
hook = Hook()
runner = Mock()
data_batch = {}
hook.before_iter(runner, data_batch)
hook._before_iter(runner, data_batch)
def test_after_iter(self):
hook = Hook()
runner = Mock()
data_batch = {}
outputs = {}
hook.after_iter(runner, data_batch, outputs)
hook._after_iter(runner, data_batch, outputs)
def test_before_save_checkpoint(self):
hook = Hook()
@ -161,7 +163,8 @@ class TestHook:
# last inner iter
runner.inner_iter = 1
runner.data_loader.__len__ = Mock(return_value=2)
runner.cur_dataloader.__len__ = Mock(return_value=2)
return_val = hook.end_of_epoch(runner)
assert return_val
@ -170,19 +173,19 @@ class TestHook:
return_val = hook.end_of_epoch(runner)
assert not return_val
def test_is_last_epoch(self):
def test_is_last_train_epoch(self):
hook = Hook()
runner = Mock()
# last epoch
runner.epoch = 1
runner._max_epochs = 2
return_val = hook.is_last_epoch(runner)
runner.train_loop.max_epochs = 2
return_val = hook.is_last_train_epoch(runner)
assert return_val
# not the last epoch
runner.epoch = 0
return_val = hook.is_last_epoch(runner)
runner.train_loop.max_epochs = 0
return_val = hook.is_last_train_epoch(runner)
assert not return_val
def test_is_last_iter(self):
@ -191,11 +194,18 @@ class TestHook:
# last iter
runner.iter = 1
runner._max_iters = 2
runner.train_loop.max_iters = 2
return_val = hook.is_last_iter(runner)
assert return_val
# not the last iter
runner.iter = 0
return_val = hook.is_last_iter(runner)
runner.val_loop.max_iters = 0
return_val = hook.is_last_iter(runner, mode='val')
assert not return_val
runner.test_loop.max_iters = 0
return_val = hook.is_last_iter(runner, mode='test')
assert not return_val
with pytest.raises(ValueError):
hook.is_last_iter(runner, mode='error_mode')

View File

@ -9,21 +9,21 @@ class TestIterTimerHook:
def test_before_epoch(self):
Hook = IterTimerHook()
Runner = Mock()
Hook.before_epoch(Runner)
Hook._before_epoch(Runner)
assert isinstance(Hook.t, float)
def test_before_iter(self):
Hook = IterTimerHook()
Runner = Mock()
Runner.log_buffer = dict()
Hook.before_epoch(Runner)
Hook.before_iter(Runner)
assert 'data_time' in Runner.log_buffer
Hook._before_epoch(Runner)
Hook._before_iter(Runner)
Runner.message_hub.update_log.assert_called()
def test_after_iter(self):
Hook = IterTimerHook()
Runner = Mock()
Runner.log_buffer = dict()
Hook.before_epoch(Runner)
Hook.after_iter(Runner)
assert 'time' in Runner.log_buffer
Hook._before_epoch(Runner)
Hook._after_iter(Runner)
Runner.message_hub.update_log.assert_called()

View File

@ -110,7 +110,7 @@ class TestLoggerHook:
# Test end of the epoch.
logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
logger_hook._log_train = MagicMock()
runner.data_loader = [0] * 5
runner.cur_dataloader = [0] * 5
runner.inner_iter = 4
logger_hook.after_train_iter(runner)
logger_hook._log_train.assert_called()
@ -155,7 +155,7 @@ class TestLoggerHook:
out, _ = capsys.readouterr()
time_avg = logger_hook.time_sec_tot / (
runner.iter + 1 - logger_hook.start_iter)
eta_second = time_avg * (runner.max_iters - runner.iter - 1)
eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_second)))
if by_epoch:
if torch.cuda.is_available():
@ -337,10 +337,10 @@ class TestLoggerHook:
def _setup_runner(self):
runner = MagicMock()
runner.epoch = 1
runner.data_loader = [0] * 5
runner.cur_dataloader = [0] * 5
runner.inner_iter = 1
runner.iter = 10
runner.max_iters = 50
runner.train_loop.max_iters = 50
logger = logging.getLogger()
logger.setLevel(logging.INFO)
for handler in logger.handlers:

View File

@ -12,17 +12,17 @@ class TestDistSamplerSeedHook:
# Test dataset sampler
runner = Mock()
runner.epoch = 1
runner.data_loader = Mock()
runner.data_loader.sampler = Mock()
runner.data_loader.sampler.set_epoch = Mock()
hook.before_epoch(runner)
runner.data_loader.sampler.set_epoch.assert_called()
runner.cur_dataloader = Mock()
runner.cur_dataloader.sampler = Mock()
runner.cur_dataloader.sampler.set_epoch = Mock()
hook.before_train_epoch(runner)
runner.cur_dataloader.sampler.set_epoch.assert_called()
# Test batch sampler
runner = Mock()
runner.data_loader = Mock()
runner.data_loader.sampler = Mock(spec_set=True)
runner.data_loader.batch_sampler = Mock()
runner.data_loader.batch_sampler.sampler = Mock()
runner.data_loader.batch_sampler.sampler.set_epoch = Mock()
hook.before_epoch(runner)
runner.data_loader.batch_sampler.sampler.set_epoch.assert_called()
runner.cur_dataloader = Mock()
runner.cur_dataloader.sampler = Mock(spec_set=True)
runner.cur_dataloader.batch_sampler = Mock()
runner.cur_dataloader.batch_sampler.sampler = Mock()
runner.cur_dataloader.batch_sampler.sampler.set_epoch = Mock()
hook.before_train_epoch(runner)
runner.cur_dataloader.batch_sampler.sampler.set_epoch.assert_called()

View File

@ -10,4 +10,4 @@ class TestSyncBuffersHook:
Runner = Mock()
Runner.model = Mock()
Hook = SyncBuffersHook()
Hook.after_epoch(Runner)
Hook._after_epoch(Runner)