[Enhance] Add RuntimeInfoHook to update runtime information. (#254)

* [Enhance] Add RuntimeInfoHook to update runtime information.

* move lr to runtime info

* docstring

* resolve comments

* update ut and doc
pull/265/head
RangiLyu 2022-05-26 14:35:37 +08:00 committed by GitHub
parent 4cbbbc0c31
commit 4705e1fe3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 269 additions and 71 deletions

View File

@ -194,6 +194,7 @@ MMEngine 提供了很多内置的钩子,将钩子分为两类,分别是默
| 名称 | 用途 | 优先级 | | 名称 | 用途 | 优先级 |
| :-----------------: | :-------------------------: | :---------------: | | :-----------------: | :-------------------------: | :---------------: |
| RuntimeInfoHook | 向 message hub 更新运行时信息 | VERY_HIGH (10) |
| OptimizerHook | 反向传播以及参数更新 | HIGH (30) | | OptimizerHook | 反向传播以及参数更新 | HIGH (30) |
| DistSamplerSeedHook | 确保分布式 Sampler 的 shuffle 生效 | NORMAL (50) | | DistSamplerSeedHook | 确保分布式 Sampler 的 shuffle 生效 | NORMAL (50) |
| SyncBuffersHook | 同步模型的 buffer | NORMAL (50) | | SyncBuffersHook | 同步模型的 buffer | NORMAL (50) |
@ -219,12 +220,13 @@ MMEngine 提供了很多内置的钩子,将钩子分为两类,分别是默
from mmengine import Runner from mmengine import Runner
default_hooks = dict( default_hooks = dict(
optimizer=dict(type='OptimizerHook'), runtime_info=dict(type='RuntimeInfoHook'),
timer=dict(type='IterTimerHook', optimizer=dict(type='OptimizerHook', grad_clip=None),
timer=dict(type='IterTimerHook'),
sampler_seed=dict(type='DistSamplerSeedHook'), sampler_seed=dict(type='DistSamplerSeedHook'),
logger=dict(type='TextLoggerHook'), logger=dict(type='LoggerHook'),
param_scheduler=dict(type='ParamSchedulerHook')), param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=1) checkpoint=dict(type='CheckpointHook', interval=1),
) )
custom_hooks = [ custom_hooks = [
@ -381,6 +383,11 @@ config = dict(type='EmptyCacheHook', before_epoch=False, after_epoch=True, after
config = dict(type='SyncBuffersHook') config = dict(type='SyncBuffersHook')
``` ```
### RuntimeInfoHook
`RuntimeInfoHook` 会在执行器的不同钩子位点将当前的运行时信息(如 epoch、iter、max_epochs、max_iters、lr、metrics 等)更新至 message hub 中,
以便其他无法访问执行器的模块能够获取到这些信息。
## 添加自定义钩子 ## 添加自定义钩子
如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。 如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。

View File

@ -8,11 +8,12 @@ from .logger_hook import LoggerHook
from .naive_visualization_hook import NaiveVisualizationHook from .naive_visualization_hook import NaiveVisualizationHook
from .optimizer_hook import OptimizerHook from .optimizer_hook import OptimizerHook
from .param_scheduler_hook import ParamSchedulerHook from .param_scheduler_hook import ParamSchedulerHook
from .runtime_info_hook import RuntimeInfoHook
from .sampler_seed_hook import DistSamplerSeedHook from .sampler_seed_hook import DistSamplerSeedHook
from .sync_buffer_hook import SyncBuffersHook from .sync_buffer_hook import SyncBuffersHook
__all__ = [ __all__ = [
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook', 'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
'LoggerHook', 'NaiveVisualizationHook', 'EMAHook' 'LoggerHook', 'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook'
] ]

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import itertools import itertools
from typing import Optional from typing import Dict, Optional
from mmengine.model import is_model_wrapper from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS, MODELS from mmengine.registry import HOOKS, MODELS
@ -23,6 +23,8 @@ class EMAHook(Hook):
Defaults to 'ExponentialMovingAverage' Defaults to 'ExponentialMovingAverage'
""" """
priority = 'NORMAL'
def __init__(self, ema_type: str = 'ExponentialMovingAverage', **kwargs): def __init__(self, ema_type: str = 'ExponentialMovingAverage', **kwargs):
self.ema_cfg = dict(type=ema_type, **kwargs) self.ema_cfg = dict(type=ema_type, **kwargs)
@ -48,7 +50,9 @@ class EMAHook(Hook):
validation.""" validation."""
self._swap_ema_parameters() self._swap_ema_parameters()
def after_val_epoch(self, runner) -> None: def after_val_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
"""We recover source model's parameter from ema model after """We recover source model's parameter from ema model after
validation.""" validation."""
self._swap_ema_parameters() self._swap_ema_parameters()
@ -58,7 +62,9 @@ class EMAHook(Hook):
test.""" test."""
self._swap_ema_parameters() self._swap_ema_parameters()
def after_test_epoch(self, runner) -> None: def after_test_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
"""We recover source model's parameter from ema model after test.""" """We recover source model's parameter from ema model after test."""
self._swap_ema_parameters() self._swap_ema_parameters()

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union from typing import Dict, Optional, Sequence, Union
from mmengine.data import BaseDataElement from mmengine.data import BaseDataElement
@ -146,21 +146,31 @@ class Hook:
""" """
self._after_epoch(runner, mode='train') self._after_epoch(runner, mode='train')
def after_val_epoch(self, runner) -> None: def after_val_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations after each validation epoch. operations after each validation epoch.
Args: Args:
runner (Runner): The runner of the validation process. runner (Runner): The runner of the validation process.
metrics (Dict[str, float], optional): Evaluation results of all
metrics on validation dataset. The keys are the names of the
metrics, and the values are corresponding results.
""" """
self._after_epoch(runner, mode='val') self._after_epoch(runner, mode='val')
def after_test_epoch(self, runner) -> None: def after_test_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations after each test epoch. operations after each test epoch.
Args: Args:
runner (Runner): The runner of the testing process. runner (Runner): The runner of the testing process.
metrics (Dict[str, float], optional): Evaluation results of all
metrics on test dataset. The keys are the names of the
metrics, and the values are corresponding results.
""" """
self._after_epoch(runner, mode='test') self._after_epoch(runner, mode='test')

View File

@ -2,7 +2,7 @@
import os import os
import os.path as osp import os.path as osp
from pathlib import Path from pathlib import Path
from typing import Optional, Sequence, Union from typing import Dict, Optional, Sequence, Union
from mmengine.data import BaseDataElement from mmengine.data import BaseDataElement
from mmengine.fileio import FileClient from mmengine.fileio import FileClient
@ -188,11 +188,17 @@ class LoggerHook(Hook):
runner, batch_idx, 'test') runner, batch_idx, 'test')
runner.logger.info(log_str) runner.logger.info(log_str)
def after_val_epoch(self, runner) -> None: def after_val_epoch(self,
"""Record logs after validation epoch. runner,
metrics: Optional[Dict[str, float]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each validation epoch.
Args: Args:
runner (Runner): The runner of the validation process. runner (Runner): The runner of the validation process.
metrics (Dict[str, float], optional): Evaluation results of all
metrics on validation dataset. The keys are the names of the
metrics, and the values are corresponding results.
""" """
tag, log_str = runner.log_processor.get_log_after_epoch( tag, log_str = runner.log_processor.get_log_after_epoch(
runner, len(runner.val_dataloader), 'val') runner, len(runner.val_dataloader), 'val')
@ -200,11 +206,17 @@ class LoggerHook(Hook):
runner.visualizer.add_scalars( runner.visualizer.add_scalars(
tag, step=runner.iter, file_path=self.json_log_path) tag, step=runner.iter, file_path=self.json_log_path)
def after_test_epoch(self, runner) -> None: def after_test_epoch(self,
"""Record logs after testing epoch. runner,
metrics: Optional[Dict[str, float]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each test epoch.
Args: Args:
runner (Runner): The runner of the testing process. runner (Runner): The runner of the testing process.
metrics (Dict[str, float], optional): Evaluation results of all
metrics on test dataset. The keys are the names of the
metrics, and the values are corresponding results.
""" """
_, log_str = runner.log_processor.get_log_after_epoch( _, log_str = runner.log_processor.get_log_after_epoch(
runner, len(runner.test_dataloader), 'test') runner, len(runner.test_dataloader), 'test')

View File

@ -84,9 +84,6 @@ class OptimizerHook(Hook):
we keep ``outputs`` here. Defaults to None. we keep ``outputs`` here. Defaults to None.
""" """
runner.optimizer.zero_grad() runner.optimizer.zero_grad()
runner.message_hub.update_scalar(
'train/lr', runner.optimizer.param_groups[0]['lr'])
if self.detect_anomalous_params: if self.detect_anomalous_params:
self.detect_anomalous_parameters(runner.outputs['loss'], runner) self.detect_anomalous_parameters(runner.outputs['loss'], runner)
runner.outputs['loss'].backward() runner.outputs['loss'].backward()

View File

@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]
@HOOKS.register_module()
class RuntimeInfoHook(Hook):
    """Push runtime information into the message hub.

    Keeps ``epoch``, ``iter``, ``max_epochs`` and ``max_iters`` — plus the
    current learning rate, training log variables and evaluation metrics —
    available through the message hub, so that components which cannot
    access the runner can still read the training state.
    """

    priority = 'VERY_HIGH'

    def before_run(self, runner) -> None:
        """Seed the message hub with the initial training state."""
        hub = runner.message_hub
        for name, value in (('epoch', runner.epoch),
                            ('iter', runner.iter),
                            ('max_epochs', runner.max_epochs),
                            ('max_iters', runner.max_iters)):
            hub.update_info(name, value)

    def before_train(self, runner) -> None:
        """Refresh epoch/iter, e.g. after resuming from a checkpoint."""
        hub = runner.message_hub
        hub.update_info('epoch', runner.epoch)
        hub.update_info('iter', runner.iter)

    def before_train_epoch(self, runner) -> None:
        """Publish the current epoch number before each training epoch."""
        runner.message_hub.update_info('epoch', runner.epoch)

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None) -> None:
        """Publish the current iteration and learning rate.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
        """
        hub = runner.message_hub
        hub.update_info('iter', runner.iter)
        # Only the first param group's lr is exposed to the message hub.
        hub.update_scalar('train/lr', runner.optimizer.param_groups[0]['lr'])

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Copy ``log_vars`` from the model outputs into the message hub."""
        if outputs is None:
            return
        for key, value in outputs['log_vars'].items():
            runner.message_hub.update_scalar(f'train/{key}', value)

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """Record validation metrics under the ``val/`` prefix.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of
                the metrics, and the values are corresponding results.
        """
        if metrics is None:
            return
        for key, value in metrics.items():
            runner.message_hub.update_scalar(f'val/{key}', value)

    def after_test_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """Record test metrics under the ``test/`` prefix.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on test dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if metrics is None:
            return
        for key, value in metrics.items():
            runner.message_hub.update_scalar(f'test/{key}', value)

View File

@ -80,8 +80,6 @@ class EpochBasedTrainLoop(BaseLoop):
self.runner.call_hook('after_train_epoch') self.runner.call_hook('after_train_epoch')
self._epoch += 1 self._epoch += 1
# To allow components that cannot access runner to get current epoch.
self.runner.message_hub.update_info('epoch', self._epoch)
def run_iter(self, idx, data_batch: Sequence[dict]) -> None: def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
"""Iterate one min-batch. """Iterate one min-batch.
@ -94,10 +92,6 @@ class EpochBasedTrainLoop(BaseLoop):
# outputs should be a dict containing one or multiple loss tensors # outputs should be a dict containing one or multiple loss tensors
self.runner.outputs = self.runner.model(data_batch, return_loss=True) self.runner.outputs = self.runner.model(data_batch, return_loss=True)
# TODO, should move to LoggerHook
for key, value in self.runner.outputs['log_vars'].items():
self.runner.message_hub.update_scalar(f'train/{key}', value)
self.runner.call_hook( self.runner.call_hook(
'after_train_iter', 'after_train_iter',
batch_idx=idx, batch_idx=idx,
@ -105,9 +99,6 @@ class EpochBasedTrainLoop(BaseLoop):
outputs=self.runner.outputs) outputs=self.runner.outputs)
self._iter += 1 self._iter += 1
# To allow components that cannot access runner to get current
# iteration.
self.runner.message_hub.update_info('iter', self._iter)
@LOOPS.register_module() @LOOPS.register_module()
@ -188,19 +179,12 @@ class IterBasedTrainLoop(BaseLoop):
# outputs should be a dict containing loss tensor # outputs should be a dict containing loss tensor
self.runner.outputs = self.runner.model(data_batch, return_loss=True) self.runner.outputs = self.runner.model(data_batch, return_loss=True)
# TODO
for key, value in self.runner.outputs['log_vars'].items():
self.runner.message_hub.update_scalar(f'train/{key}', value)
self.runner.call_hook( self.runner.call_hook(
'after_train_iter', 'after_train_iter',
batch_idx=self._iter, batch_idx=self._iter,
data_batch=data_batch, data_batch=data_batch,
outputs=self.runner.outputs) outputs=self.runner.outputs)
self._iter += 1 self._iter += 1
# To allow components that cannot access runner to get current
# iteration.
self.runner.message_hub.update_info('iter', self._iter)
@LOOPS.register_module() @LOOPS.register_module()
@ -247,10 +231,8 @@ class ValLoop(BaseLoop):
# compute metrics # compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
for key, value in metrics.items():
self.runner.message_hub.update_scalar(f'val/{key}', value)
self.runner.call_hook('after_val_epoch') self.runner.call_hook('after_val_epoch', metrics=metrics)
self.runner.call_hook('after_val') self.runner.call_hook('after_val')
@torch.no_grad() @torch.no_grad()
@ -312,10 +294,8 @@ class TestLoop(BaseLoop):
# compute metrics # compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
for key, value in metrics.items():
self.runner.message_hub.update_scalar(f'test/{key}', value)
self.runner.call_hook('after_test_epoch') self.runner.call_hook('after_test_epoch', metrics=metrics)
self.runner.call_hook('after_test') self.runner.call_hook('after_test')
@torch.no_grad() @torch.no_grad()

View File

@ -1353,6 +1353,8 @@ class Runner:
+----------------------+-------------------------+ +----------------------+-------------------------+
| Hooks | Priority | | Hooks | Priority |
+======================+=========================+ +======================+=========================+
| RuntimeInfoHook | VERY_HIGH (10) |
+----------------------+-------------------------+
| OptimizerHook | HIGH (30) | | OptimizerHook | HIGH (30) |
+----------------------+-------------------------+ +----------------------+-------------------------+
| IterTimerHook | NORMAL (40) | | IterTimerHook | NORMAL (40) |
@ -1370,6 +1372,7 @@ class Runner:
default:: default::
default_hooks = dict( default_hooks = dict(
runtime_info=dict(type='RuntimeInfoHook'),
optimizer=dict(type='OptimizerHook', grad_clip=None), optimizer=dict(type='OptimizerHook', grad_clip=None),
timer=dict(type='IterTimerHook'), timer=dict(type='IterTimerHook'),
sampler_seed=dict(type='DistSamplerSeedHook'), sampler_seed=dict(type='DistSamplerSeedHook'),
@ -1392,6 +1395,7 @@ class Runner:
to be registered. to be registered.
""" """
default_hooks: dict = dict( default_hooks: dict = dict(
runtime_info=dict(type='RuntimeInfoHook'),
optimizer=dict(type='OptimizerHook', grad_clip=None), optimizer=dict(type='OptimizerHook', grad_clip=None),
timer=dict(type='IterTimerHook'), timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook'), logger=dict(type='LoggerHook'),

View File

@ -74,12 +74,12 @@ class TestHook:
def test_after_val_epoch(self): def test_after_val_epoch(self):
hook = Hook() hook = Hook()
runner = Mock() runner = Mock()
hook.after_val_epoch(runner) hook.after_val_epoch(runner, {})
def test_after_test_epoch(self): def test_after_test_epoch(self):
hook = Hook() hook = Hook()
runner = Mock() runner = Mock()
hook.after_test_epoch(runner) hook.after_test_epoch(runner, {})
def test_before_train_iter(self): def test_before_train_iter(self):
hook = Hook() hook = Hook()

View File

@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
from mmengine.hooks import RuntimeInfoHook
from mmengine.logging import MessageHub
class TestRuntimeInfoHook(TestCase):
    """Unit tests for ``RuntimeInfoHook`` using a mocked runner.

    Each test binds a fresh, uniquely named ``MessageHub`` instance to the
    mock runner so that state does not leak between test cases.
    """

    def test_before_run(self):
        hub = MessageHub.get_instance('runtime_info_hook_test_before_run')
        runner = Mock()
        runner.message_hub = hub
        runner.epoch = 3
        runner.iter = 30
        runner.max_epochs = 4
        runner.max_iters = 40
        RuntimeInfoHook().before_run(runner)
        # All four pieces of training state must land in the hub.
        self.assertEqual(hub.get_info('epoch'), 3)
        self.assertEqual(hub.get_info('iter'), 30)
        self.assertEqual(hub.get_info('max_epochs'), 4)
        self.assertEqual(hub.get_info('max_iters'), 40)

    def test_before_train(self):
        hub = MessageHub.get_instance('runtime_info_hook_test_before_train')
        runner = Mock()
        runner.message_hub = hub
        runner.epoch = 7
        runner.iter = 71
        RuntimeInfoHook().before_train(runner)
        self.assertEqual(hub.get_info('epoch'), 7)
        self.assertEqual(hub.get_info('iter'), 71)

    def test_before_train_epoch(self):
        hub = MessageHub.get_instance(
            'runtime_info_hook_test_before_train_epoch')
        runner = Mock()
        runner.message_hub = hub
        runner.epoch = 9
        RuntimeInfoHook().before_train_epoch(runner)
        self.assertEqual(hub.get_info('epoch'), 9)

    def test_before_train_iter(self):
        hub = MessageHub.get_instance(
            'runtime_info_hook_test_before_train_iter')
        runner = Mock()
        runner.message_hub = hub
        runner.iter = 9
        runner.optimizer.param_groups = [{'lr': 0.01}]
        RuntimeInfoHook().before_train_iter(runner, batch_idx=2,
                                            data_batch=None)
        self.assertEqual(hub.get_info('iter'), 9)
        # The first param group's lr should be logged as a scalar.
        self.assertEqual(hub.get_scalar('train/lr').current(), 0.01)

    def test_after_train_iter(self):
        hub = MessageHub.get_instance(
            'runtime_info_hook_test_after_train_iter')
        runner = Mock()
        runner.message_hub = hub
        outputs = {'log_vars': {'loss_cls': 1.111}}
        RuntimeInfoHook().after_train_iter(
            runner, batch_idx=2, data_batch=None, outputs=outputs)
        self.assertEqual(hub.get_scalar('train/loss_cls').current(), 1.111)

    def test_after_val_epoch(self):
        hub = MessageHub.get_instance(
            'runtime_info_hook_test_after_val_epoch')
        runner = Mock()
        runner.message_hub = hub
        RuntimeInfoHook().after_val_epoch(runner, metrics={'acc': 0.8})
        self.assertEqual(hub.get_scalar('val/acc').current(), 0.8)

    def test_after_test_epoch(self):
        hub = MessageHub.get_instance(
            'runtime_info_hook_test_after_test_epoch')
        runner = Mock()
        runner.message_hub = hub
        RuntimeInfoHook().after_test_epoch(runner, metrics={'acc': 0.8})
        self.assertEqual(hub.get_scalar('test/acc').current(), 0.8)

View File

@ -218,11 +218,14 @@ class TestRunner(TestCase):
test_cfg=dict(), test_cfg=dict(),
custom_hooks=[], custom_hooks=[],
default_hooks=dict( default_hooks=dict(
timer=dict(type='IterTimerHook'), runtime_info=dict(type='RuntimeInfoHook'),
checkpoint=dict(type='CheckpointHook', interval=1),
logger=dict(type='LoggerHook'),
optimizer=dict(type='OptimizerHook', grad_clip=None), optimizer=dict(type='OptimizerHook', grad_clip=None),
param_scheduler=dict(type='ParamSchedulerHook')), timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook'),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(
type='CheckpointHook', interval=1, by_epoch=True),
sampler_seed=dict(type='DistSamplerSeedHook')),
launcher='none', launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')), env_cfg=dict(dist_cfg=dict(backend='nccl')),
) )
@ -235,11 +238,13 @@ class TestRunner(TestCase):
num_workers=0) num_workers=0)
self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12) self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
self.iter_based_cfg.default_hooks = dict( self.iter_based_cfg.default_hooks = dict(
timer=dict(type='IterTimerHook'), runtime_info=dict(type='RuntimeInfoHook'),
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
logger=dict(type='LoggerHook'),
optimizer=dict(type='OptimizerHook', grad_clip=None), optimizer=dict(type='OptimizerHook', grad_clip=None),
param_scheduler=dict(type='ParamSchedulerHook')) timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook'),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
sampler_seed=dict(type='DistSamplerSeedHook'))
def tearDown(self): def tearDown(self):
shutil.rmtree(self.temp_dir) shutil.rmtree(self.temp_dir)
@ -868,9 +873,6 @@ class TestRunner(TestCase):
assert isinstance(runner.train_loop, EpochBasedTrainLoop) assert isinstance(runner.train_loop, EpochBasedTrainLoop)
assert runner.iter == runner.message_hub.get_info('iter')
assert runner.epoch == runner.message_hub.get_info('epoch')
for result, target, in zip(epoch_results, epoch_targets): for result, target, in zip(epoch_results, epoch_targets):
self.assertEqual(result, target) self.assertEqual(result, target)
for result, target, in zip(iter_results, iter_targets): for result, target, in zip(iter_results, iter_targets):
@ -903,7 +905,6 @@ class TestRunner(TestCase):
runner.train() runner.train()
assert isinstance(runner.train_loop, IterBasedTrainLoop) assert isinstance(runner.train_loop, IterBasedTrainLoop)
assert runner.iter == runner.message_hub.get_info('iter')
self.assertEqual(len(epoch_results), 1) self.assertEqual(len(epoch_results), 1)
self.assertEqual(epoch_results[0], 0) self.assertEqual(epoch_results[0], 0)
@ -1035,37 +1036,37 @@ class TestRunner(TestCase):
runner = Runner.from_cfg(cfg) runner = Runner.from_cfg(cfg)
runner._hooks = [] runner._hooks = []
# register five hooks by default # register 7 hooks by default
runner.register_default_hooks() runner.register_default_hooks()
self.assertEqual(len(runner._hooks), 6) self.assertEqual(len(runner._hooks), 7)
# the third registered hook should be `DistSamplerSeedHook` # the third registered hook should be `DistSamplerSeedHook`
self.assertTrue(isinstance(runner._hooks[2], DistSamplerSeedHook)) self.assertTrue(isinstance(runner._hooks[3], DistSamplerSeedHook))
# the fifth registered hook should be `ParamSchedulerHook` # the fifth registered hook should be `ParamSchedulerHook`
self.assertTrue(isinstance(runner._hooks[4], ParamSchedulerHook)) self.assertTrue(isinstance(runner._hooks[5], ParamSchedulerHook))
runner._hooks = [] runner._hooks = []
# remove `ParamSchedulerHook` from default hooks # remove `ParamSchedulerHook` from default hooks
runner.register_default_hooks(hooks=dict(timer=None)) runner.register_default_hooks(hooks=dict(timer=None))
self.assertEqual(len(runner._hooks), 5) self.assertEqual(len(runner._hooks), 6)
# `ParamSchedulerHook` was popped so the forth is `CheckpointHook` # `ParamSchedulerHook` was popped so the fifth is `CheckpointHook`
self.assertTrue(isinstance(runner._hooks[4], CheckpointHook)) self.assertTrue(isinstance(runner._hooks[5], CheckpointHook))
# add a new default hook # add a new default hook
runner._hooks = [] runner._hooks = []
runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook'))) runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook')))
self.assertEqual(len(runner._hooks), 7) self.assertEqual(len(runner._hooks), 8)
self.assertTrue(isinstance(runner._hooks[6], ToyHook)) self.assertTrue(isinstance(runner._hooks[7], ToyHook))
def test_custom_hooks(self): def test_custom_hooks(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_custom_hooks' cfg.experiment_name = 'test_custom_hooks'
runner = Runner.from_cfg(cfg) runner = Runner.from_cfg(cfg)
self.assertEqual(len(runner._hooks), 6) self.assertEqual(len(runner._hooks), 7)
custom_hooks = [dict(type='ToyHook')] custom_hooks = [dict(type='ToyHook')]
runner.register_custom_hooks(custom_hooks) runner.register_custom_hooks(custom_hooks)
self.assertEqual(len(runner._hooks), 7) self.assertEqual(len(runner._hooks), 8)
self.assertTrue(isinstance(runner._hooks[6], ToyHook)) self.assertTrue(isinstance(runner._hooks[7], ToyHook))
def test_register_hooks(self): def test_register_hooks(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
@ -1075,9 +1076,9 @@ class TestRunner(TestCase):
runner._hooks = [] runner._hooks = []
custom_hooks = [dict(type='ToyHook')] custom_hooks = [dict(type='ToyHook')]
runner.register_hooks(custom_hooks=custom_hooks) runner.register_hooks(custom_hooks=custom_hooks)
# six default hooks + custom hook (ToyHook) # 7 default hooks + custom hook (ToyHook)
self.assertEqual(len(runner._hooks), 7) self.assertEqual(len(runner._hooks), 8)
self.assertTrue(isinstance(runner._hooks[6], ToyHook)) self.assertTrue(isinstance(runner._hooks[7], ToyHook))
def test_custom_loop(self): def test_custom_loop(self):
# test custom loop with additional hook # test custom loop with additional hook