[Enhance] Add RuntimeInfoHook to update runtime information. (#254)
* [Enhance] Add RuntimeInfoHook to update runtime information. * move lr to runtime info * docstring * resolve comments * update ut and docpull/265/head
parent
4cbbbc0c31
commit
4705e1fe3d
|
@ -194,6 +194,7 @@ MMEngine 提供了很多内置的钩子,将钩子分为两类,分别是默
|
|||
|
||||
| 名称 | 用途 | 优先级 |
|
||||
| :-----------------: | :-------------------------: | :---------------: |
|
||||
| RuntimeInfoHook | 向 message hub 更新运行时信息 | VERY_HIGH (10) |
|
||||
| OptimizerHook | 反向传播以及参数更新 | HIGH (30) |
|
||||
| DistSamplerSeedHook | 确保分布式 Sampler 的 shuffle 生效 | NORMAL (50) |
|
||||
| SyncBuffersHook | 同步模型的 buffer | NORMAL (50) |
|
||||
|
@ -219,12 +220,13 @@ MMEngine 提供了很多内置的钩子,将钩子分为两类,分别是默
|
|||
from mmengine import Runner
|
||||
|
||||
default_hooks = dict(
|
||||
optimizer=dict(type='OptimizerHook'),
|
||||
timer=dict(type='IterTimerHook',
|
||||
runtime_info=dict(type='RuntimeInfoHook'),
|
||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
logger=dict(type='TextLoggerHook'),
|
||||
param_scheduler=dict(type='ParamSchedulerHook')),
|
||||
checkpoint=dict(type='CheckpointHook', interval=1)
|
||||
logger=dict(type='LoggerHook'),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', interval=1),
|
||||
)
|
||||
|
||||
custom_hooks = [
|
||||
|
@ -381,6 +383,11 @@ config = dict(type='EmptyCacheHook', before_epoch=False, after_epoch=True, after
|
|||
config = dict(type='SyncBuffersHook')
|
||||
```
|
||||
|
||||
### RuntimeInfoHook
|
||||
|
||||
`RuntimeInfoHook` 会在执行器的不同钩子位点将当前的运行时信息(如 epoch、iter、max_epochs、max_iters、lr、metrics等)更新至 message hub 中,
|
||||
以便其他无法访问执行器的模块能够获取到这些信息。
|
||||
|
||||
## 添加自定义钩子
|
||||
|
||||
如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。
|
||||
|
|
|
@ -8,11 +8,12 @@ from .logger_hook import LoggerHook
|
|||
from .naive_visualization_hook import NaiveVisualizationHook
|
||||
from .optimizer_hook import OptimizerHook
|
||||
from .param_scheduler_hook import ParamSchedulerHook
|
||||
from .runtime_info_hook import RuntimeInfoHook
|
||||
from .sampler_seed_hook import DistSamplerSeedHook
|
||||
from .sync_buffer_hook import SyncBuffersHook
|
||||
|
||||
__all__ = [
|
||||
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
|
||||
'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
|
||||
'LoggerHook', 'NaiveVisualizationHook', 'EMAHook'
|
||||
'LoggerHook', 'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook'
|
||||
]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import itertools
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from mmengine.model import is_model_wrapper
|
||||
from mmengine.registry import HOOKS, MODELS
|
||||
|
@ -23,6 +23,8 @@ class EMAHook(Hook):
|
|||
Defaults to 'ExponentialMovingAverage'
|
||||
"""
|
||||
|
||||
priority = 'NORMAL'
|
||||
|
||||
def __init__(self, ema_type: str = 'ExponentialMovingAverage', **kwargs):
|
||||
self.ema_cfg = dict(type=ema_type, **kwargs)
|
||||
|
||||
|
@ -48,7 +50,9 @@ class EMAHook(Hook):
|
|||
validation."""
|
||||
self._swap_ema_parameters()
|
||||
|
||||
def after_val_epoch(self, runner) -> None:
|
||||
def after_val_epoch(self,
|
||||
runner,
|
||||
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||
"""We recover source model's parameter from ema model after
|
||||
validation."""
|
||||
self._swap_ema_parameters()
|
||||
|
@ -58,7 +62,9 @@ class EMAHook(Hook):
|
|||
test."""
|
||||
self._swap_ema_parameters()
|
||||
|
||||
def after_test_epoch(self, runner) -> None:
|
||||
def after_test_epoch(self,
|
||||
runner,
|
||||
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||
"""We recover source model's parameter from ema model after test."""
|
||||
self._swap_ema_parameters()
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence, Union
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
|
||||
|
@ -146,21 +146,31 @@ class Hook:
|
|||
"""
|
||||
self._after_epoch(runner, mode='train')
|
||||
|
||||
def after_val_epoch(self, runner) -> None:
|
||||
def after_val_epoch(self,
|
||||
runner,
|
||||
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each validation epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
metrics (Dict[str, float], optional): Evaluation results of all
|
||||
metrics on validation dataset. The keys are the names of the
|
||||
metrics, and the values are corresponding results.
|
||||
"""
|
||||
self._after_epoch(runner, mode='val')
|
||||
|
||||
def after_test_epoch(self, runner) -> None:
|
||||
def after_test_epoch(self,
|
||||
runner,
|
||||
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each test epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
metrics (Dict[str, float], optional): Evaluation results of all
|
||||
metrics on test dataset. The keys are the names of the
|
||||
metrics, and the values are corresponding results.
|
||||
"""
|
||||
self._after_epoch(runner, mode='test')
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence, Union
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.fileio import FileClient
|
||||
|
@ -188,11 +188,17 @@ class LoggerHook(Hook):
|
|||
runner, batch_idx, 'test')
|
||||
runner.logger.info(log_str)
|
||||
|
||||
def after_val_epoch(self, runner) -> None:
|
||||
"""Record logs after validation epoch.
|
||||
def after_val_epoch(self,
|
||||
runner,
|
||||
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each validation epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
metrics (Dict[str, float], optional): Evaluation results of all
|
||||
metrics on validation dataset. The keys are the names of the
|
||||
metrics, and the values are corresponding results.
|
||||
"""
|
||||
tag, log_str = runner.log_processor.get_log_after_epoch(
|
||||
runner, len(runner.val_dataloader), 'val')
|
||||
|
@ -200,11 +206,17 @@ class LoggerHook(Hook):
|
|||
runner.visualizer.add_scalars(
|
||||
tag, step=runner.iter, file_path=self.json_log_path)
|
||||
|
||||
def after_test_epoch(self, runner) -> None:
|
||||
"""Record logs after testing epoch.
|
||||
def after_test_epoch(self,
|
||||
runner,
|
||||
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each test epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
metrics (Dict[str, float], optional): Evaluation results of all
|
||||
metrics on test dataset. The keys are the names of the
|
||||
metrics, and the values are corresponding results.
|
||||
"""
|
||||
_, log_str = runner.log_processor.get_log_after_epoch(
|
||||
runner, len(runner.test_dataloader), 'test')
|
||||
|
|
|
@ -84,9 +84,6 @@ class OptimizerHook(Hook):
|
|||
we keep ``outputs`` here. Defaults to None.
|
||||
"""
|
||||
runner.optimizer.zero_grad()
|
||||
runner.message_hub.update_scalar(
|
||||
'train/lr', runner.optimizer.param_groups[0]['lr'])
|
||||
|
||||
if self.detect_anomalous_params:
|
||||
self.detect_anomalous_parameters(runner.outputs['loss'], runner)
|
||||
runner.outputs['loss'].backward()
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
from mmengine.registry import HOOKS
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class RuntimeInfoHook(Hook):
|
||||
"""A hook that updates runtime information into message hub.
|
||||
|
||||
E.g. ``epoch``, ``iter``, ``max_epochs``, and ``max_iters`` for the
|
||||
training state. Components that cannot access the runner can get runtime
|
||||
information through the message hub.
|
||||
"""
|
||||
|
||||
priority = 'VERY_HIGH'
|
||||
|
||||
def before_run(self, runner) -> None:
|
||||
"""Initialize runtime information."""
|
||||
runner.message_hub.update_info('epoch', runner.epoch)
|
||||
runner.message_hub.update_info('iter', runner.iter)
|
||||
runner.message_hub.update_info('max_epochs', runner.max_epochs)
|
||||
runner.message_hub.update_info('max_iters', runner.max_iters)
|
||||
|
||||
def before_train(self, runner) -> None:
|
||||
"""Update resumed training state."""
|
||||
runner.message_hub.update_info('epoch', runner.epoch)
|
||||
runner.message_hub.update_info('iter', runner.iter)
|
||||
|
||||
def before_train_epoch(self, runner) -> None:
|
||||
"""Update current epoch information before every epoch."""
|
||||
runner.message_hub.update_info('epoch', runner.epoch)
|
||||
|
||||
def before_train_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None) -> None:
|
||||
"""Update current iter and learning rate information before every
|
||||
iteration."""
|
||||
runner.message_hub.update_info('iter', runner.iter)
|
||||
runner.message_hub.update_scalar(
|
||||
'train/lr', runner.optimizer.param_groups[0]['lr'])
|
||||
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[dict] = None) -> None:
|
||||
"""Update ``log_vars`` in model outputs every iteration."""
|
||||
if outputs is not None:
|
||||
for key, value in outputs['log_vars'].items():
|
||||
runner.message_hub.update_scalar(f'train/{key}', value)
|
||||
|
||||
def after_val_epoch(self,
|
||||
runner,
|
||||
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each validation epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
metrics (Dict[str, float], optional): Evaluation results of all
|
||||
metrics on validation dataset. The keys are the names of the
|
||||
metrics, and the values are corresponding results.
|
||||
"""
|
||||
if metrics is not None:
|
||||
for key, value in metrics.items():
|
||||
runner.message_hub.update_scalar(f'val/{key}', value)
|
||||
|
||||
def after_test_epoch(self,
|
||||
runner,
|
||||
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each test epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
metrics (Dict[str, float], optional): Evaluation results of all
|
||||
metrics on test dataset. The keys are the names of the
|
||||
metrics, and the values are corresponding results.
|
||||
"""
|
||||
if metrics is not None:
|
||||
for key, value in metrics.items():
|
||||
runner.message_hub.update_scalar(f'test/{key}', value)
|
|
@ -80,8 +80,6 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||
|
||||
self.runner.call_hook('after_train_epoch')
|
||||
self._epoch += 1
|
||||
# To allow components that cannot access runner to get current epoch.
|
||||
self.runner.message_hub.update_info('epoch', self._epoch)
|
||||
|
||||
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
|
||||
"""Iterate one min-batch.
|
||||
|
@ -94,10 +92,6 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||
# outputs should be a dict containing one or multiple loss tensors
|
||||
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
||||
|
||||
# TODO, should move to LoggerHook
|
||||
for key, value in self.runner.outputs['log_vars'].items():
|
||||
self.runner.message_hub.update_scalar(f'train/{key}', value)
|
||||
|
||||
self.runner.call_hook(
|
||||
'after_train_iter',
|
||||
batch_idx=idx,
|
||||
|
@ -105,9 +99,6 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||
outputs=self.runner.outputs)
|
||||
|
||||
self._iter += 1
|
||||
# To allow components that cannot access runner to get current
|
||||
# iteration.
|
||||
self.runner.message_hub.update_info('iter', self._iter)
|
||||
|
||||
|
||||
@LOOPS.register_module()
|
||||
|
@ -188,19 +179,12 @@ class IterBasedTrainLoop(BaseLoop):
|
|||
# outputs should be a dict containing loss tensor
|
||||
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
||||
|
||||
# TODO
|
||||
for key, value in self.runner.outputs['log_vars'].items():
|
||||
self.runner.message_hub.update_scalar(f'train/{key}', value)
|
||||
|
||||
self.runner.call_hook(
|
||||
'after_train_iter',
|
||||
batch_idx=self._iter,
|
||||
data_batch=data_batch,
|
||||
outputs=self.runner.outputs)
|
||||
self._iter += 1
|
||||
# To allow components that cannot access runner to get current
|
||||
# iteration.
|
||||
self.runner.message_hub.update_info('iter', self._iter)
|
||||
|
||||
|
||||
@LOOPS.register_module()
|
||||
|
@ -247,10 +231,8 @@ class ValLoop(BaseLoop):
|
|||
|
||||
# compute metrics
|
||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||
for key, value in metrics.items():
|
||||
self.runner.message_hub.update_scalar(f'val/{key}', value)
|
||||
|
||||
self.runner.call_hook('after_val_epoch')
|
||||
self.runner.call_hook('after_val_epoch', metrics=metrics)
|
||||
self.runner.call_hook('after_val')
|
||||
|
||||
@torch.no_grad()
|
||||
|
@ -312,10 +294,8 @@ class TestLoop(BaseLoop):
|
|||
|
||||
# compute metrics
|
||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||
for key, value in metrics.items():
|
||||
self.runner.message_hub.update_scalar(f'test/{key}', value)
|
||||
|
||||
self.runner.call_hook('after_test_epoch')
|
||||
self.runner.call_hook('after_test_epoch', metrics=metrics)
|
||||
self.runner.call_hook('after_test')
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -1353,6 +1353,8 @@ class Runner:
|
|||
+----------------------+-------------------------+
|
||||
| Hooks | Priority |
|
||||
+======================+=========================+
|
||||
| RuntimeInfoHook | VERY_HIGH (10) |
|
||||
+----------------------+-------------------------+
|
||||
| OptimizerHook | HIGH (30) |
|
||||
+----------------------+-------------------------+
|
||||
| IterTimerHook | NORMAL (40) |
|
||||
|
@ -1370,6 +1372,7 @@ class Runner:
|
|||
default::
|
||||
|
||||
default_hooks = dict(
|
||||
runtime_info=dict(type='RuntimeInfoHook'),
|
||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
|
@ -1392,6 +1395,7 @@ class Runner:
|
|||
to be registered.
|
||||
"""
|
||||
default_hooks: dict = dict(
|
||||
runtime_info=dict(type='RuntimeInfoHook'),
|
||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
logger=dict(type='LoggerHook'),
|
||||
|
|
|
@ -74,12 +74,12 @@ class TestHook:
|
|||
def test_after_val_epoch(self):
|
||||
hook = Hook()
|
||||
runner = Mock()
|
||||
hook.after_val_epoch(runner)
|
||||
hook.after_val_epoch(runner, {})
|
||||
|
||||
def test_after_test_epoch(self):
|
||||
hook = Hook()
|
||||
runner = Mock()
|
||||
hook.after_test_epoch(runner)
|
||||
hook.after_test_epoch(runner, {})
|
||||
|
||||
def test_before_train_iter(self):
|
||||
hook = Hook()
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
from unittest.mock import Mock
|
||||
|
||||
from mmengine.hooks import RuntimeInfoHook
|
||||
from mmengine.logging import MessageHub
|
||||
|
||||
|
||||
class TestRuntimeInfoHook(TestCase):
|
||||
|
||||
def test_before_run(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_before_run')
|
||||
runner = Mock()
|
||||
runner.epoch = 3
|
||||
runner.iter = 30
|
||||
runner.max_epochs = 4
|
||||
runner.max_iters = 40
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
hook.before_run(runner)
|
||||
self.assertEqual(message_hub.get_info('epoch'), 3)
|
||||
self.assertEqual(message_hub.get_info('iter'), 30)
|
||||
self.assertEqual(message_hub.get_info('max_epochs'), 4)
|
||||
self.assertEqual(message_hub.get_info('max_iters'), 40)
|
||||
|
||||
def test_before_train(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_before_train')
|
||||
runner = Mock()
|
||||
runner.epoch = 7
|
||||
runner.iter = 71
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
hook.before_train(runner)
|
||||
self.assertEqual(message_hub.get_info('epoch'), 7)
|
||||
self.assertEqual(message_hub.get_info('iter'), 71)
|
||||
|
||||
def test_before_train_epoch(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_before_train_epoch')
|
||||
runner = Mock()
|
||||
runner.epoch = 9
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
hook.before_train_epoch(runner)
|
||||
self.assertEqual(message_hub.get_info('epoch'), 9)
|
||||
|
||||
def test_before_train_iter(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_before_train_iter')
|
||||
runner = Mock()
|
||||
runner.iter = 9
|
||||
runner.optimizer.param_groups = [{'lr': 0.01}]
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
|
||||
self.assertEqual(message_hub.get_info('iter'), 9)
|
||||
self.assertEqual(message_hub.get_scalar('train/lr').current(), 0.01)
|
||||
|
||||
def test_after_train_iter(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_after_train_iter')
|
||||
runner = Mock()
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
hook.after_train_iter(
|
||||
runner,
|
||||
batch_idx=2,
|
||||
data_batch=None,
|
||||
outputs={'log_vars': {
|
||||
'loss_cls': 1.111
|
||||
}})
|
||||
self.assertEqual(
|
||||
message_hub.get_scalar('train/loss_cls').current(), 1.111)
|
||||
|
||||
def test_after_val_epoch(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_after_val_epoch')
|
||||
runner = Mock()
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
hook.after_val_epoch(runner, metrics={'acc': 0.8})
|
||||
self.assertEqual(message_hub.get_scalar('val/acc').current(), 0.8)
|
||||
|
||||
def test_after_test_epoch(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_after_test_epoch')
|
||||
runner = Mock()
|
||||
runner.message_hub = message_hub
|
||||
hook = RuntimeInfoHook()
|
||||
hook.after_test_epoch(runner, metrics={'acc': 0.8})
|
||||
self.assertEqual(message_hub.get_scalar('test/acc').current(), 0.8)
|
|
@ -218,11 +218,14 @@ class TestRunner(TestCase):
|
|||
test_cfg=dict(),
|
||||
custom_hooks=[],
|
||||
default_hooks=dict(
|
||||
timer=dict(type='IterTimerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', interval=1),
|
||||
logger=dict(type='LoggerHook'),
|
||||
runtime_info=dict(type='RuntimeInfoHook'),
|
||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||
param_scheduler=dict(type='ParamSchedulerHook')),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
logger=dict(type='LoggerHook'),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(
|
||||
type='CheckpointHook', interval=1, by_epoch=True),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook')),
|
||||
launcher='none',
|
||||
env_cfg=dict(dist_cfg=dict(backend='nccl')),
|
||||
)
|
||||
|
@ -235,11 +238,13 @@ class TestRunner(TestCase):
|
|||
num_workers=0)
|
||||
self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
|
||||
self.iter_based_cfg.default_hooks = dict(
|
||||
timer=dict(type='IterTimerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
|
||||
logger=dict(type='LoggerHook'),
|
||||
runtime_info=dict(type='RuntimeInfoHook'),
|
||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'))
|
||||
timer=dict(type='IterTimerHook'),
|
||||
logger=dict(type='LoggerHook'),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'))
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
@ -868,9 +873,6 @@ class TestRunner(TestCase):
|
|||
|
||||
assert isinstance(runner.train_loop, EpochBasedTrainLoop)
|
||||
|
||||
assert runner.iter == runner.message_hub.get_info('iter')
|
||||
assert runner.epoch == runner.message_hub.get_info('epoch')
|
||||
|
||||
for result, target, in zip(epoch_results, epoch_targets):
|
||||
self.assertEqual(result, target)
|
||||
for result, target, in zip(iter_results, iter_targets):
|
||||
|
@ -903,7 +905,6 @@ class TestRunner(TestCase):
|
|||
runner.train()
|
||||
|
||||
assert isinstance(runner.train_loop, IterBasedTrainLoop)
|
||||
assert runner.iter == runner.message_hub.get_info('iter')
|
||||
|
||||
self.assertEqual(len(epoch_results), 1)
|
||||
self.assertEqual(epoch_results[0], 0)
|
||||
|
@ -1035,37 +1036,37 @@ class TestRunner(TestCase):
|
|||
runner = Runner.from_cfg(cfg)
|
||||
runner._hooks = []
|
||||
|
||||
# register five hooks by default
|
||||
# register 7 hooks by default
|
||||
runner.register_default_hooks()
|
||||
self.assertEqual(len(runner._hooks), 6)
|
||||
self.assertEqual(len(runner._hooks), 7)
|
||||
# the third registered hook should be `DistSamplerSeedHook`
|
||||
self.assertTrue(isinstance(runner._hooks[2], DistSamplerSeedHook))
|
||||
self.assertTrue(isinstance(runner._hooks[3], DistSamplerSeedHook))
|
||||
# the fifth registered hook should be `ParamSchedulerHook`
|
||||
self.assertTrue(isinstance(runner._hooks[4], ParamSchedulerHook))
|
||||
self.assertTrue(isinstance(runner._hooks[5], ParamSchedulerHook))
|
||||
|
||||
runner._hooks = []
|
||||
# remove `ParamSchedulerHook` from default hooks
|
||||
runner.register_default_hooks(hooks=dict(timer=None))
|
||||
self.assertEqual(len(runner._hooks), 5)
|
||||
# `ParamSchedulerHook` was popped so the forth is `CheckpointHook`
|
||||
self.assertTrue(isinstance(runner._hooks[4], CheckpointHook))
|
||||
self.assertEqual(len(runner._hooks), 6)
|
||||
# `ParamSchedulerHook` was popped so the fifth is `CheckpointHook`
|
||||
self.assertTrue(isinstance(runner._hooks[5], CheckpointHook))
|
||||
|
||||
# add a new default hook
|
||||
runner._hooks = []
|
||||
runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook')))
|
||||
self.assertEqual(len(runner._hooks), 7)
|
||||
self.assertTrue(isinstance(runner._hooks[6], ToyHook))
|
||||
self.assertEqual(len(runner._hooks), 8)
|
||||
self.assertTrue(isinstance(runner._hooks[7], ToyHook))
|
||||
|
||||
def test_custom_hooks(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_custom_hooks'
|
||||
runner = Runner.from_cfg(cfg)
|
||||
|
||||
self.assertEqual(len(runner._hooks), 6)
|
||||
self.assertEqual(len(runner._hooks), 7)
|
||||
custom_hooks = [dict(type='ToyHook')]
|
||||
runner.register_custom_hooks(custom_hooks)
|
||||
self.assertEqual(len(runner._hooks), 7)
|
||||
self.assertTrue(isinstance(runner._hooks[6], ToyHook))
|
||||
self.assertEqual(len(runner._hooks), 8)
|
||||
self.assertTrue(isinstance(runner._hooks[7], ToyHook))
|
||||
|
||||
def test_register_hooks(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
|
@ -1075,9 +1076,9 @@ class TestRunner(TestCase):
|
|||
runner._hooks = []
|
||||
custom_hooks = [dict(type='ToyHook')]
|
||||
runner.register_hooks(custom_hooks=custom_hooks)
|
||||
# six default hooks + custom hook (ToyHook)
|
||||
self.assertEqual(len(runner._hooks), 7)
|
||||
self.assertTrue(isinstance(runner._hooks[6], ToyHook))
|
||||
# 7 default hooks + custom hook (ToyHook)
|
||||
self.assertEqual(len(runner._hooks), 8)
|
||||
self.assertTrue(isinstance(runner._hooks[7], ToyHook))
|
||||
|
||||
def test_custom_loop(self):
|
||||
# test custom loop with additional hook
|
||||
|
|
Loading…
Reference in New Issue