[Enhance] Add RuntimeInfoHook to update runtime information. (#254)
* [Enhance] Add RuntimeInfoHook to update runtime information. * move lr to runtime info * docstring * resolve comments * update ut and doc pull/265/head
parent
4cbbbc0c31
commit
4705e1fe3d
|
@ -194,6 +194,7 @@ MMEngine 提供了很多内置的钩子,将钩子分为两类,分别是默
|
||||||
|
|
||||||
| 名称 | 用途 | 优先级 |
|
| 名称 | 用途 | 优先级 |
|
||||||
| :-----------------: | :-------------------------: | :---------------: |
|
| :-----------------: | :-------------------------: | :---------------: |
|
||||||
|
| RuntimeInfoHook | 向 message hub 更新运行时信息 | VERY_HIGH (10) |
|
||||||
| OptimizerHook | 反向传播以及参数更新 | HIGH (30) |
|
| OptimizerHook | 反向传播以及参数更新 | HIGH (30) |
|
||||||
| DistSamplerSeedHook | 确保分布式 Sampler 的 shuffle 生效 | NORMAL (50) |
|
| DistSamplerSeedHook | 确保分布式 Sampler 的 shuffle 生效 | NORMAL (50) |
|
||||||
| SyncBuffersHook | 同步模型的 buffer | NORMAL (50) |
|
| SyncBuffersHook | 同步模型的 buffer | NORMAL (50) |
|
||||||
|
@ -219,12 +220,13 @@ MMEngine 提供了很多内置的钩子,将钩子分为两类,分别是默
|
||||||
from mmengine import Runner
|
from mmengine import Runner
|
||||||
|
|
||||||
default_hooks = dict(
|
default_hooks = dict(
|
||||||
optimizer=dict(type='OptimizerHook'),
|
runtime_info=dict(type='RuntimeInfoHook'),
|
||||||
timer=dict(type='IterTimerHook',
|
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||||
|
timer=dict(type='IterTimerHook'),
|
||||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||||
logger=dict(type='TextLoggerHook'),
|
logger=dict(type='LoggerHook'),
|
||||||
param_scheduler=dict(type='ParamSchedulerHook')),
|
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||||
checkpoint=dict(type='CheckpointHook', interval=1)
|
checkpoint=dict(type='CheckpointHook', interval=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
custom_hooks = [
|
custom_hooks = [
|
||||||
|
@ -381,6 +383,11 @@ config = dict(type='EmptyCacheHook', before_epoch=False, after_epoch=True, after
|
||||||
config = dict(type='SyncBuffersHook')
|
config = dict(type='SyncBuffersHook')
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### RuntimeInfoHook
|
||||||
|
|
||||||
|
`RuntimeInfoHook` 会在执行器的不同钩子位点将当前的运行时信息(如 epoch、iter、max_epochs、max_iters、lr、metrics等)更新至 message hub 中,
|
||||||
|
以便其他无法访问执行器的模块能够获取到这些信息。
|
||||||
|
|
||||||
## 添加自定义钩子
|
## 添加自定义钩子
|
||||||
|
|
||||||
如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。
|
如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。
|
||||||
|
|
|
@ -8,11 +8,12 @@ from .logger_hook import LoggerHook
|
||||||
from .naive_visualization_hook import NaiveVisualizationHook
|
from .naive_visualization_hook import NaiveVisualizationHook
|
||||||
from .optimizer_hook import OptimizerHook
|
from .optimizer_hook import OptimizerHook
|
||||||
from .param_scheduler_hook import ParamSchedulerHook
|
from .param_scheduler_hook import ParamSchedulerHook
|
||||||
|
from .runtime_info_hook import RuntimeInfoHook
|
||||||
from .sampler_seed_hook import DistSamplerSeedHook
|
from .sampler_seed_hook import DistSamplerSeedHook
|
||||||
from .sync_buffer_hook import SyncBuffersHook
|
from .sync_buffer_hook import SyncBuffersHook
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
|
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
|
||||||
'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
|
'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
|
||||||
'LoggerHook', 'NaiveVisualizationHook', 'EMAHook'
|
'LoggerHook', 'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook'
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import itertools
|
import itertools
|
||||||
from typing import Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from mmengine.model import is_model_wrapper
|
from mmengine.model import is_model_wrapper
|
||||||
from mmengine.registry import HOOKS, MODELS
|
from mmengine.registry import HOOKS, MODELS
|
||||||
|
@ -23,6 +23,8 @@ class EMAHook(Hook):
|
||||||
Defaults to 'ExponentialMovingAverage'
|
Defaults to 'ExponentialMovingAverage'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
priority = 'NORMAL'
|
||||||
|
|
||||||
def __init__(self, ema_type: str = 'ExponentialMovingAverage', **kwargs):
|
def __init__(self, ema_type: str = 'ExponentialMovingAverage', **kwargs):
|
||||||
self.ema_cfg = dict(type=ema_type, **kwargs)
|
self.ema_cfg = dict(type=ema_type, **kwargs)
|
||||||
|
|
||||||
|
@ -48,7 +50,9 @@ class EMAHook(Hook):
|
||||||
validation."""
|
validation."""
|
||||||
self._swap_ema_parameters()
|
self._swap_ema_parameters()
|
||||||
|
|
||||||
def after_val_epoch(self, runner) -> None:
|
def after_val_epoch(self,
|
||||||
|
runner,
|
||||||
|
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||||
"""We recover source model's parameter from ema model after
|
"""We recover source model's parameter from ema model after
|
||||||
validation."""
|
validation."""
|
||||||
self._swap_ema_parameters()
|
self._swap_ema_parameters()
|
||||||
|
@ -58,7 +62,9 @@ class EMAHook(Hook):
|
||||||
test."""
|
test."""
|
||||||
self._swap_ema_parameters()
|
self._swap_ema_parameters()
|
||||||
|
|
||||||
def after_test_epoch(self, runner) -> None:
|
def after_test_epoch(self,
|
||||||
|
runner,
|
||||||
|
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||||
"""We recover source model's parameter from ema model after test."""
|
"""We recover source model's parameter from ema model after test."""
|
||||||
self._swap_ema_parameters()
|
self._swap_ema_parameters()
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Optional, Sequence, Union
|
from typing import Dict, Optional, Sequence, Union
|
||||||
|
|
||||||
from mmengine.data import BaseDataElement
|
from mmengine.data import BaseDataElement
|
||||||
|
|
||||||
|
@ -146,21 +146,31 @@ class Hook:
|
||||||
"""
|
"""
|
||||||
self._after_epoch(runner, mode='train')
|
self._after_epoch(runner, mode='train')
|
||||||
|
|
||||||
def after_val_epoch(self, runner) -> None:
|
def after_val_epoch(self,
|
||||||
|
runner,
|
||||||
|
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
operations after each validation epoch.
|
operations after each validation epoch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the validation process.
|
runner (Runner): The runner of the validation process.
|
||||||
|
metrics (Dict[str, float], optional): Evaluation results of all
|
||||||
|
metrics on validation dataset. The keys are the names of the
|
||||||
|
metrics, and the values are corresponding results.
|
||||||
"""
|
"""
|
||||||
self._after_epoch(runner, mode='val')
|
self._after_epoch(runner, mode='val')
|
||||||
|
|
||||||
def after_test_epoch(self, runner) -> None:
|
def after_test_epoch(self,
|
||||||
|
runner,
|
||||||
|
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
operations after each test epoch.
|
operations after each test epoch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the testing process.
|
runner (Runner): The runner of the testing process.
|
||||||
|
metrics (Dict[str, float], optional): Evaluation results of all
|
||||||
|
metrics on test dataset. The keys are the names of the
|
||||||
|
metrics, and the values are corresponding results.
|
||||||
"""
|
"""
|
||||||
self._after_epoch(runner, mode='test')
|
self._after_epoch(runner, mode='test')
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Sequence, Union
|
from typing import Dict, Optional, Sequence, Union
|
||||||
|
|
||||||
from mmengine.data import BaseDataElement
|
from mmengine.data import BaseDataElement
|
||||||
from mmengine.fileio import FileClient
|
from mmengine.fileio import FileClient
|
||||||
|
@ -188,11 +188,17 @@ class LoggerHook(Hook):
|
||||||
runner, batch_idx, 'test')
|
runner, batch_idx, 'test')
|
||||||
runner.logger.info(log_str)
|
runner.logger.info(log_str)
|
||||||
|
|
||||||
def after_val_epoch(self, runner) -> None:
|
def after_val_epoch(self,
|
||||||
"""Record logs after validation epoch.
|
runner,
|
||||||
|
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||||
|
"""All subclasses should override this method, if they need any
|
||||||
|
operations after each validation epoch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the validation process.
|
runner (Runner): The runner of the validation process.
|
||||||
|
metrics (Dict[str, float], optional): Evaluation results of all
|
||||||
|
metrics on validation dataset. The keys are the names of the
|
||||||
|
metrics, and the values are corresponding results.
|
||||||
"""
|
"""
|
||||||
tag, log_str = runner.log_processor.get_log_after_epoch(
|
tag, log_str = runner.log_processor.get_log_after_epoch(
|
||||||
runner, len(runner.val_dataloader), 'val')
|
runner, len(runner.val_dataloader), 'val')
|
||||||
|
@ -200,11 +206,17 @@ class LoggerHook(Hook):
|
||||||
runner.visualizer.add_scalars(
|
runner.visualizer.add_scalars(
|
||||||
tag, step=runner.iter, file_path=self.json_log_path)
|
tag, step=runner.iter, file_path=self.json_log_path)
|
||||||
|
|
||||||
def after_test_epoch(self, runner) -> None:
|
def after_test_epoch(self,
|
||||||
"""Record logs after testing epoch.
|
runner,
|
||||||
|
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||||
|
"""All subclasses should override this method, if they need any
|
||||||
|
operations after each test epoch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the testing process.
|
runner (Runner): The runner of the testing process.
|
||||||
|
metrics (Dict[str, float], optional): Evaluation results of all
|
||||||
|
metrics on test dataset. The keys are the names of the
|
||||||
|
metrics, and the values are corresponding results.
|
||||||
"""
|
"""
|
||||||
_, log_str = runner.log_processor.get_log_after_epoch(
|
_, log_str = runner.log_processor.get_log_after_epoch(
|
||||||
runner, len(runner.test_dataloader), 'test')
|
runner, len(runner.test_dataloader), 'test')
|
||||||
|
|
|
@ -84,9 +84,6 @@ class OptimizerHook(Hook):
|
||||||
we keep ``outputs`` here. Defaults to None.
|
we keep ``outputs`` here. Defaults to None.
|
||||||
"""
|
"""
|
||||||
runner.optimizer.zero_grad()
|
runner.optimizer.zero_grad()
|
||||||
runner.message_hub.update_scalar(
|
|
||||||
'train/lr', runner.optimizer.param_groups[0]['lr'])
|
|
||||||
|
|
||||||
if self.detect_anomalous_params:
|
if self.detect_anomalous_params:
|
||||||
self.detect_anomalous_parameters(runner.outputs['loss'], runner)
|
self.detect_anomalous_parameters(runner.outputs['loss'], runner)
|
||||||
runner.outputs['loss'].backward()
|
runner.outputs['loss'].backward()
|
||||||
|
|
|
@ -0,0 +1,87 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from typing import Dict, Optional, Sequence
|
||||||
|
|
||||||
|
from mmengine.registry import HOOKS
|
||||||
|
from .hook import Hook
|
||||||
|
|
||||||
|
DATA_BATCH = Optional[Sequence[dict]]
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
|
||||||
|
class RuntimeInfoHook(Hook):
|
||||||
|
"""A hook that updates runtime information into message hub.
|
||||||
|
|
||||||
|
E.g. ``epoch``, ``iter``, ``max_epochs``, and ``max_iters`` for the
|
||||||
|
training state. Components that cannot access the runner can get runtime
|
||||||
|
information through the message hub.
|
||||||
|
"""
|
||||||
|
|
||||||
|
priority = 'VERY_HIGH'
|
||||||
|
|
||||||
|
def before_run(self, runner) -> None:
|
||||||
|
"""Initialize runtime information."""
|
||||||
|
runner.message_hub.update_info('epoch', runner.epoch)
|
||||||
|
runner.message_hub.update_info('iter', runner.iter)
|
||||||
|
runner.message_hub.update_info('max_epochs', runner.max_epochs)
|
||||||
|
runner.message_hub.update_info('max_iters', runner.max_iters)
|
||||||
|
|
||||||
|
def before_train(self, runner) -> None:
|
||||||
|
"""Update resumed training state."""
|
||||||
|
runner.message_hub.update_info('epoch', runner.epoch)
|
||||||
|
runner.message_hub.update_info('iter', runner.iter)
|
||||||
|
|
||||||
|
def before_train_epoch(self, runner) -> None:
|
||||||
|
"""Update current epoch information before every epoch."""
|
||||||
|
runner.message_hub.update_info('epoch', runner.epoch)
|
||||||
|
|
||||||
|
def before_train_iter(self,
|
||||||
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
|
data_batch: DATA_BATCH = None) -> None:
|
||||||
|
"""Update current iter and learning rate information before every
|
||||||
|
iteration."""
|
||||||
|
runner.message_hub.update_info('iter', runner.iter)
|
||||||
|
runner.message_hub.update_scalar(
|
||||||
|
'train/lr', runner.optimizer.param_groups[0]['lr'])
|
||||||
|
|
||||||
|
def after_train_iter(self,
|
||||||
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
|
data_batch: DATA_BATCH = None,
|
||||||
|
outputs: Optional[dict] = None) -> None:
|
||||||
|
"""Update ``log_vars`` in model outputs every iteration."""
|
||||||
|
if outputs is not None:
|
||||||
|
for key, value in outputs['log_vars'].items():
|
||||||
|
runner.message_hub.update_scalar(f'train/{key}', value)
|
||||||
|
|
||||||
|
def after_val_epoch(self,
|
||||||
|
runner,
|
||||||
|
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||||
|
"""All subclasses should override this method, if they need any
|
||||||
|
operations after each validation epoch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the validation process.
|
||||||
|
metrics (Dict[str, float], optional): Evaluation results of all
|
||||||
|
metrics on validation dataset. The keys are the names of the
|
||||||
|
metrics, and the values are corresponding results.
|
||||||
|
"""
|
||||||
|
if metrics is not None:
|
||||||
|
for key, value in metrics.items():
|
||||||
|
runner.message_hub.update_scalar(f'val/{key}', value)
|
||||||
|
|
||||||
|
def after_test_epoch(self,
|
||||||
|
runner,
|
||||||
|
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||||
|
"""All subclasses should override this method, if they need any
|
||||||
|
operations after each test epoch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the testing process.
|
||||||
|
metrics (Dict[str, float], optional): Evaluation results of all
|
||||||
|
metrics on test dataset. The keys are the names of the
|
||||||
|
metrics, and the values are corresponding results.
|
||||||
|
"""
|
||||||
|
if metrics is not None:
|
||||||
|
for key, value in metrics.items():
|
||||||
|
runner.message_hub.update_scalar(f'test/{key}', value)
|
|
@ -80,8 +80,6 @@ class EpochBasedTrainLoop(BaseLoop):
|
||||||
|
|
||||||
self.runner.call_hook('after_train_epoch')
|
self.runner.call_hook('after_train_epoch')
|
||||||
self._epoch += 1
|
self._epoch += 1
|
||||||
# To allow components that cannot access runner to get current epoch.
|
|
||||||
self.runner.message_hub.update_info('epoch', self._epoch)
|
|
||||||
|
|
||||||
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
|
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
|
||||||
"""Iterate one min-batch.
|
"""Iterate one min-batch.
|
||||||
|
@ -94,10 +92,6 @@ class EpochBasedTrainLoop(BaseLoop):
|
||||||
# outputs should be a dict containing one or multiple loss tensors
|
# outputs should be a dict containing one or multiple loss tensors
|
||||||
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
||||||
|
|
||||||
# TODO, should move to LoggerHook
|
|
||||||
for key, value in self.runner.outputs['log_vars'].items():
|
|
||||||
self.runner.message_hub.update_scalar(f'train/{key}', value)
|
|
||||||
|
|
||||||
self.runner.call_hook(
|
self.runner.call_hook(
|
||||||
'after_train_iter',
|
'after_train_iter',
|
||||||
batch_idx=idx,
|
batch_idx=idx,
|
||||||
|
@ -105,9 +99,6 @@ class EpochBasedTrainLoop(BaseLoop):
|
||||||
outputs=self.runner.outputs)
|
outputs=self.runner.outputs)
|
||||||
|
|
||||||
self._iter += 1
|
self._iter += 1
|
||||||
# To allow components that cannot access runner to get current
|
|
||||||
# iteration.
|
|
||||||
self.runner.message_hub.update_info('iter', self._iter)
|
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
@LOOPS.register_module()
|
||||||
|
@ -188,19 +179,12 @@ class IterBasedTrainLoop(BaseLoop):
|
||||||
# outputs should be a dict containing loss tensor
|
# outputs should be a dict containing loss tensor
|
||||||
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
||||||
|
|
||||||
# TODO
|
|
||||||
for key, value in self.runner.outputs['log_vars'].items():
|
|
||||||
self.runner.message_hub.update_scalar(f'train/{key}', value)
|
|
||||||
|
|
||||||
self.runner.call_hook(
|
self.runner.call_hook(
|
||||||
'after_train_iter',
|
'after_train_iter',
|
||||||
batch_idx=self._iter,
|
batch_idx=self._iter,
|
||||||
data_batch=data_batch,
|
data_batch=data_batch,
|
||||||
outputs=self.runner.outputs)
|
outputs=self.runner.outputs)
|
||||||
self._iter += 1
|
self._iter += 1
|
||||||
# To allow components that cannot access runner to get current
|
|
||||||
# iteration.
|
|
||||||
self.runner.message_hub.update_info('iter', self._iter)
|
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
@LOOPS.register_module()
|
||||||
|
@ -247,10 +231,8 @@ class ValLoop(BaseLoop):
|
||||||
|
|
||||||
# compute metrics
|
# compute metrics
|
||||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||||
for key, value in metrics.items():
|
|
||||||
self.runner.message_hub.update_scalar(f'val/{key}', value)
|
|
||||||
|
|
||||||
self.runner.call_hook('after_val_epoch')
|
self.runner.call_hook('after_val_epoch', metrics=metrics)
|
||||||
self.runner.call_hook('after_val')
|
self.runner.call_hook('after_val')
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@ -312,10 +294,8 @@ class TestLoop(BaseLoop):
|
||||||
|
|
||||||
# compute metrics
|
# compute metrics
|
||||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||||
for key, value in metrics.items():
|
|
||||||
self.runner.message_hub.update_scalar(f'test/{key}', value)
|
|
||||||
|
|
||||||
self.runner.call_hook('after_test_epoch')
|
self.runner.call_hook('after_test_epoch', metrics=metrics)
|
||||||
self.runner.call_hook('after_test')
|
self.runner.call_hook('after_test')
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
|
@ -1353,6 +1353,8 @@ class Runner:
|
||||||
+----------------------+-------------------------+
|
+----------------------+-------------------------+
|
||||||
| Hooks | Priority |
|
| Hooks | Priority |
|
||||||
+======================+=========================+
|
+======================+=========================+
|
||||||
|
| RuntimeInfoHook | VERY_HIGH (10) |
|
||||||
|
+----------------------+-------------------------+
|
||||||
| OptimizerHook | HIGH (30) |
|
| OptimizerHook | HIGH (30) |
|
||||||
+----------------------+-------------------------+
|
+----------------------+-------------------------+
|
||||||
| IterTimerHook | NORMAL (40) |
|
| IterTimerHook | NORMAL (40) |
|
||||||
|
@ -1370,6 +1372,7 @@ class Runner:
|
||||||
default::
|
default::
|
||||||
|
|
||||||
default_hooks = dict(
|
default_hooks = dict(
|
||||||
|
runtime_info=dict(type='RuntimeInfoHook'),
|
||||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||||
timer=dict(type='IterTimerHook'),
|
timer=dict(type='IterTimerHook'),
|
||||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||||
|
@ -1392,6 +1395,7 @@ class Runner:
|
||||||
to be registered.
|
to be registered.
|
||||||
"""
|
"""
|
||||||
default_hooks: dict = dict(
|
default_hooks: dict = dict(
|
||||||
|
runtime_info=dict(type='RuntimeInfoHook'),
|
||||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||||
timer=dict(type='IterTimerHook'),
|
timer=dict(type='IterTimerHook'),
|
||||||
logger=dict(type='LoggerHook'),
|
logger=dict(type='LoggerHook'),
|
||||||
|
|
|
@ -74,12 +74,12 @@ class TestHook:
|
||||||
def test_after_val_epoch(self):
|
def test_after_val_epoch(self):
|
||||||
hook = Hook()
|
hook = Hook()
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
hook.after_val_epoch(runner)
|
hook.after_val_epoch(runner, {})
|
||||||
|
|
||||||
def test_after_test_epoch(self):
|
def test_after_test_epoch(self):
|
||||||
hook = Hook()
|
hook = Hook()
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
hook.after_test_epoch(runner)
|
hook.after_test_epoch(runner, {})
|
||||||
|
|
||||||
def test_before_train_iter(self):
|
def test_before_train_iter(self):
|
||||||
hook = Hook()
|
hook = Hook()
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from unittest import TestCase
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from mmengine.hooks import RuntimeInfoHook
|
||||||
|
from mmengine.logging import MessageHub
|
||||||
|
|
||||||
|
|
||||||
|
class TestRuntimeInfoHook(TestCase):
|
||||||
|
|
||||||
|
def test_before_run(self):
|
||||||
|
message_hub = MessageHub.get_instance(
|
||||||
|
'runtime_info_hook_test_before_run')
|
||||||
|
runner = Mock()
|
||||||
|
runner.epoch = 3
|
||||||
|
runner.iter = 30
|
||||||
|
runner.max_epochs = 4
|
||||||
|
runner.max_iters = 40
|
||||||
|
runner.message_hub = message_hub
|
||||||
|
hook = RuntimeInfoHook()
|
||||||
|
hook.before_run(runner)
|
||||||
|
self.assertEqual(message_hub.get_info('epoch'), 3)
|
||||||
|
self.assertEqual(message_hub.get_info('iter'), 30)
|
||||||
|
self.assertEqual(message_hub.get_info('max_epochs'), 4)
|
||||||
|
self.assertEqual(message_hub.get_info('max_iters'), 40)
|
||||||
|
|
||||||
|
def test_before_train(self):
|
||||||
|
message_hub = MessageHub.get_instance(
|
||||||
|
'runtime_info_hook_test_before_train')
|
||||||
|
runner = Mock()
|
||||||
|
runner.epoch = 7
|
||||||
|
runner.iter = 71
|
||||||
|
runner.message_hub = message_hub
|
||||||
|
hook = RuntimeInfoHook()
|
||||||
|
hook.before_train(runner)
|
||||||
|
self.assertEqual(message_hub.get_info('epoch'), 7)
|
||||||
|
self.assertEqual(message_hub.get_info('iter'), 71)
|
||||||
|
|
||||||
|
def test_before_train_epoch(self):
|
||||||
|
message_hub = MessageHub.get_instance(
|
||||||
|
'runtime_info_hook_test_before_train_epoch')
|
||||||
|
runner = Mock()
|
||||||
|
runner.epoch = 9
|
||||||
|
runner.message_hub = message_hub
|
||||||
|
hook = RuntimeInfoHook()
|
||||||
|
hook.before_train_epoch(runner)
|
||||||
|
self.assertEqual(message_hub.get_info('epoch'), 9)
|
||||||
|
|
||||||
|
def test_before_train_iter(self):
|
||||||
|
message_hub = MessageHub.get_instance(
|
||||||
|
'runtime_info_hook_test_before_train_iter')
|
||||||
|
runner = Mock()
|
||||||
|
runner.iter = 9
|
||||||
|
runner.optimizer.param_groups = [{'lr': 0.01}]
|
||||||
|
runner.message_hub = message_hub
|
||||||
|
hook = RuntimeInfoHook()
|
||||||
|
hook.before_train_iter(runner, batch_idx=2, data_batch=None)
|
||||||
|
self.assertEqual(message_hub.get_info('iter'), 9)
|
||||||
|
self.assertEqual(message_hub.get_scalar('train/lr').current(), 0.01)
|
||||||
|
|
||||||
|
def test_after_train_iter(self):
|
||||||
|
message_hub = MessageHub.get_instance(
|
||||||
|
'runtime_info_hook_test_after_train_iter')
|
||||||
|
runner = Mock()
|
||||||
|
runner.message_hub = message_hub
|
||||||
|
hook = RuntimeInfoHook()
|
||||||
|
hook.after_train_iter(
|
||||||
|
runner,
|
||||||
|
batch_idx=2,
|
||||||
|
data_batch=None,
|
||||||
|
outputs={'log_vars': {
|
||||||
|
'loss_cls': 1.111
|
||||||
|
}})
|
||||||
|
self.assertEqual(
|
||||||
|
message_hub.get_scalar('train/loss_cls').current(), 1.111)
|
||||||
|
|
||||||
|
def test_after_val_epoch(self):
|
||||||
|
message_hub = MessageHub.get_instance(
|
||||||
|
'runtime_info_hook_test_after_val_epoch')
|
||||||
|
runner = Mock()
|
||||||
|
runner.message_hub = message_hub
|
||||||
|
hook = RuntimeInfoHook()
|
||||||
|
hook.after_val_epoch(runner, metrics={'acc': 0.8})
|
||||||
|
self.assertEqual(message_hub.get_scalar('val/acc').current(), 0.8)
|
||||||
|
|
||||||
|
def test_after_test_epoch(self):
|
||||||
|
message_hub = MessageHub.get_instance(
|
||||||
|
'runtime_info_hook_test_after_test_epoch')
|
||||||
|
runner = Mock()
|
||||||
|
runner.message_hub = message_hub
|
||||||
|
hook = RuntimeInfoHook()
|
||||||
|
hook.after_test_epoch(runner, metrics={'acc': 0.8})
|
||||||
|
self.assertEqual(message_hub.get_scalar('test/acc').current(), 0.8)
|
|
@ -218,11 +218,14 @@ class TestRunner(TestCase):
|
||||||
test_cfg=dict(),
|
test_cfg=dict(),
|
||||||
custom_hooks=[],
|
custom_hooks=[],
|
||||||
default_hooks=dict(
|
default_hooks=dict(
|
||||||
timer=dict(type='IterTimerHook'),
|
runtime_info=dict(type='RuntimeInfoHook'),
|
||||||
checkpoint=dict(type='CheckpointHook', interval=1),
|
|
||||||
logger=dict(type='LoggerHook'),
|
|
||||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||||
param_scheduler=dict(type='ParamSchedulerHook')),
|
timer=dict(type='IterTimerHook'),
|
||||||
|
logger=dict(type='LoggerHook'),
|
||||||
|
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||||
|
checkpoint=dict(
|
||||||
|
type='CheckpointHook', interval=1, by_epoch=True),
|
||||||
|
sampler_seed=dict(type='DistSamplerSeedHook')),
|
||||||
launcher='none',
|
launcher='none',
|
||||||
env_cfg=dict(dist_cfg=dict(backend='nccl')),
|
env_cfg=dict(dist_cfg=dict(backend='nccl')),
|
||||||
)
|
)
|
||||||
|
@ -235,11 +238,13 @@ class TestRunner(TestCase):
|
||||||
num_workers=0)
|
num_workers=0)
|
||||||
self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
|
self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
|
||||||
self.iter_based_cfg.default_hooks = dict(
|
self.iter_based_cfg.default_hooks = dict(
|
||||||
timer=dict(type='IterTimerHook'),
|
runtime_info=dict(type='RuntimeInfoHook'),
|
||||||
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
|
|
||||||
logger=dict(type='LoggerHook'),
|
|
||||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||||
param_scheduler=dict(type='ParamSchedulerHook'))
|
timer=dict(type='IterTimerHook'),
|
||||||
|
logger=dict(type='LoggerHook'),
|
||||||
|
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||||
|
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
|
||||||
|
sampler_seed=dict(type='DistSamplerSeedHook'))
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
shutil.rmtree(self.temp_dir)
|
shutil.rmtree(self.temp_dir)
|
||||||
|
@ -868,9 +873,6 @@ class TestRunner(TestCase):
|
||||||
|
|
||||||
assert isinstance(runner.train_loop, EpochBasedTrainLoop)
|
assert isinstance(runner.train_loop, EpochBasedTrainLoop)
|
||||||
|
|
||||||
assert runner.iter == runner.message_hub.get_info('iter')
|
|
||||||
assert runner.epoch == runner.message_hub.get_info('epoch')
|
|
||||||
|
|
||||||
for result, target, in zip(epoch_results, epoch_targets):
|
for result, target, in zip(epoch_results, epoch_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
for result, target, in zip(iter_results, iter_targets):
|
for result, target, in zip(iter_results, iter_targets):
|
||||||
|
@ -903,7 +905,6 @@ class TestRunner(TestCase):
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
assert isinstance(runner.train_loop, IterBasedTrainLoop)
|
assert isinstance(runner.train_loop, IterBasedTrainLoop)
|
||||||
assert runner.iter == runner.message_hub.get_info('iter')
|
|
||||||
|
|
||||||
self.assertEqual(len(epoch_results), 1)
|
self.assertEqual(len(epoch_results), 1)
|
||||||
self.assertEqual(epoch_results[0], 0)
|
self.assertEqual(epoch_results[0], 0)
|
||||||
|
@ -1035,37 +1036,37 @@ class TestRunner(TestCase):
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
runner._hooks = []
|
runner._hooks = []
|
||||||
|
|
||||||
# register five hooks by default
|
# register 7 hooks by default
|
||||||
runner.register_default_hooks()
|
runner.register_default_hooks()
|
||||||
self.assertEqual(len(runner._hooks), 6)
|
self.assertEqual(len(runner._hooks), 7)
|
||||||
# the third registered hook should be `DistSamplerSeedHook`
|
# the fourth registered hook should be `DistSamplerSeedHook`
|
||||||
self.assertTrue(isinstance(runner._hooks[2], DistSamplerSeedHook))
|
self.assertTrue(isinstance(runner._hooks[3], DistSamplerSeedHook))
|
||||||
# the fifth registered hook should be `ParamSchedulerHook`
|
# the sixth registered hook should be `ParamSchedulerHook`
|
||||||
self.assertTrue(isinstance(runner._hooks[4], ParamSchedulerHook))
|
self.assertTrue(isinstance(runner._hooks[5], ParamSchedulerHook))
|
||||||
|
|
||||||
runner._hooks = []
|
runner._hooks = []
|
||||||
# remove `ParamSchedulerHook` from default hooks
|
# remove `IterTimerHook` from default hooks
|
||||||
runner.register_default_hooks(hooks=dict(timer=None))
|
runner.register_default_hooks(hooks=dict(timer=None))
|
||||||
self.assertEqual(len(runner._hooks), 5)
|
self.assertEqual(len(runner._hooks), 6)
|
||||||
# `ParamSchedulerHook` was popped so the forth is `CheckpointHook`
|
# `IterTimerHook` was popped so the sixth is `CheckpointHook`
|
||||||
self.assertTrue(isinstance(runner._hooks[4], CheckpointHook))
|
self.assertTrue(isinstance(runner._hooks[5], CheckpointHook))
|
||||||
|
|
||||||
# add a new default hook
|
# add a new default hook
|
||||||
runner._hooks = []
|
runner._hooks = []
|
||||||
runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook')))
|
runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook')))
|
||||||
self.assertEqual(len(runner._hooks), 7)
|
self.assertEqual(len(runner._hooks), 8)
|
||||||
self.assertTrue(isinstance(runner._hooks[6], ToyHook))
|
self.assertTrue(isinstance(runner._hooks[7], ToyHook))
|
||||||
|
|
||||||
def test_custom_hooks(self):
|
def test_custom_hooks(self):
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
cfg.experiment_name = 'test_custom_hooks'
|
cfg.experiment_name = 'test_custom_hooks'
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
self.assertEqual(len(runner._hooks), 6)
|
self.assertEqual(len(runner._hooks), 7)
|
||||||
custom_hooks = [dict(type='ToyHook')]
|
custom_hooks = [dict(type='ToyHook')]
|
||||||
runner.register_custom_hooks(custom_hooks)
|
runner.register_custom_hooks(custom_hooks)
|
||||||
self.assertEqual(len(runner._hooks), 7)
|
self.assertEqual(len(runner._hooks), 8)
|
||||||
self.assertTrue(isinstance(runner._hooks[6], ToyHook))
|
self.assertTrue(isinstance(runner._hooks[7], ToyHook))
|
||||||
|
|
||||||
def test_register_hooks(self):
|
def test_register_hooks(self):
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
@ -1075,9 +1076,9 @@ class TestRunner(TestCase):
|
||||||
runner._hooks = []
|
runner._hooks = []
|
||||||
custom_hooks = [dict(type='ToyHook')]
|
custom_hooks = [dict(type='ToyHook')]
|
||||||
runner.register_hooks(custom_hooks=custom_hooks)
|
runner.register_hooks(custom_hooks=custom_hooks)
|
||||||
# six default hooks + custom hook (ToyHook)
|
# 7 default hooks + custom hook (ToyHook)
|
||||||
self.assertEqual(len(runner._hooks), 7)
|
self.assertEqual(len(runner._hooks), 8)
|
||||||
self.assertTrue(isinstance(runner._hooks[6], ToyHook))
|
self.assertTrue(isinstance(runner._hooks[7], ToyHook))
|
||||||
|
|
||||||
def test_custom_loop(self):
|
def test_custom_loop(self):
|
||||||
# test custom loop with additional hook
|
# test custom loop with additional hook
|
||||||
|
|
Loading…
Reference in New Issue