# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Optional, Union

import numpy as np
import torch

from mmengine.registry import HOOKS
from mmengine.utils import get_git_hash
from mmengine.version import __version__

from .hook import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]


def _is_scalar(value: Any) -> bool:
    """Determine whether the given value is a scalar.

    Args:
        value (Any): The logged value to check.

    Returns:
        bool: Whether the value is a scalar.
    """
    if isinstance(value, np.ndarray):
        return value.size == 1
    elif isinstance(value, (int, float)):
        return True
    elif isinstance(value, torch.Tensor):
        return value.numel() == 1
    return False
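
# A quick sanity sketch of ``_is_scalar`` (illustrative values only):
# Python numbers and single-element arrays/tensors count as scalars,
# while multi-element containers and strings do not.
#
#   >>> _is_scalar(0.5)                # True
#   >>> _is_scalar(np.array(1.0))      # True  (size == 1)
#   >>> _is_scalar(torch.ones(1))      # True  (numel() == 1)
#   >>> _is_scalar(np.arange(3))       # False
#   >>> _is_scalar('accuracy')         # False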


@HOOKS.register_module()
class RuntimeInfoHook(Hook):
    """A hook that updates runtime information into the message hub.

    E.g. ``epoch``, ``iter``, ``max_epochs``, and ``max_iters`` for the
    training state. Components that cannot access the runner can get
    runtime information through the message hub.
    """

    priority = 'VERY_HIGH'
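
    # A minimal sketch of the consumer side (hypothetical component code),
    # assuming the standard ``MessageHub`` API: a component with no access
    # to the runner can still read the state this hook publishes.
    #
    #   from mmengine.logging import MessageHub
    #   message_hub = MessageHub.get_current_instance()
    #   cur_iter = message_hub.get_info('iter')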

    def before_run(self, runner) -> None:
        """Update metainfo.

        Args:
            runner (Runner): The runner of the training process.
        """
        metainfo = dict(
            cfg=runner.cfg.pretty_text,
            seed=runner.seed,
            experiment_name=runner.experiment_name,
            mmengine_version=__version__ + get_git_hash())
        runner.message_hub.update_info_dict(metainfo)
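
    # Shape of the resulting metainfo (values are illustrative only):
    #
    #   {'cfg': '<full config text>',
    #    'seed': 2022,
    #    'experiment_name': 'my_experiment',
    #    'mmengine_version': '0.x.y+<git hash>'}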

    def before_train(self, runner) -> None:
        """Update resumed training state.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.message_hub.update_info('epoch', runner.epoch)
        runner.message_hub.update_info('iter', runner.iter)
        runner.message_hub.update_info('max_epochs', runner.max_epochs)
        runner.message_hub.update_info('max_iters', runner.max_iters)
        if hasattr(runner.train_dataloader.dataset, 'metainfo'):
            runner.message_hub.update_info(
                'dataset_meta', runner.train_dataloader.dataset.metainfo)

    def before_train_epoch(self, runner) -> None:
        """Update current epoch information before every epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.message_hub.update_info('epoch', runner.epoch)

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None) -> None:
        """Update current iter and learning rate information before every
        iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train
                loop.
            data_batch (dict or tuple or list, optional): Data from the
                dataloader. Defaults to None.
        """
        runner.message_hub.update_info('iter', runner.iter)
        lr_dict = runner.optim_wrapper.get_lr()
        assert isinstance(lr_dict, dict), (
            '`runner.optim_wrapper.get_lr()` should return a dict of '
            'learning rates when training with an OptimWrapper (single '
            'optimizer) or an OptimWrapperDict (multiple optimizers), '
            f'but got {type(lr_dict)}. Please check that your optimizer '
            'constructor returns an `OptimWrapper` or `OptimWrapperDict` '
            'instance.')
        for name, lr in lr_dict.items():
            runner.message_hub.update_scalar(f'train/{name}', lr[0])
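
    # Illustrative shape of ``get_lr()`` output (assumed values): a dict
    # mapping names to lists of learning rates, e.g. ``{'lr': [0.01]}``
    # for a single OptimWrapper, or ``{'generator.lr': [0.0002],
    # 'discriminator.lr': [0.0001]}`` for an OptimWrapperDict; only the
    # first entry of each list is logged as ``train/<name>``.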

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Update ``log_vars`` in model outputs every iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train
                loop.
            data_batch (dict or tuple or list, optional): Data from the
                dataloader. Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        if outputs is not None:
            for key, value in outputs.items():
                runner.message_hub.update_scalar(f'train/{key}', value)
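
    # Illustrative ``outputs`` (assumed values): the log variables returned
    # by the model's train step, e.g. ``{'loss': 0.32, 'loss_cls': 0.21}``,
    # are logged as the scalars ``train/loss`` and ``train/loss_cls``.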

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """All subclasses should override this method if they need any
        operations after each validation epoch.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on the validation dataset. The keys are the names
                of the metrics, and the values are the corresponding
                results.
        """
        if metrics is not None:
            for key, value in metrics.items():
                if _is_scalar(value):
                    runner.message_hub.update_scalar(f'val/{key}', value)
                else:
                    runner.message_hub.update_info(f'val/{key}', value)
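
    # Dispatch sketch (illustrative metric values): a scalar such as
    # ``{'accuracy': 0.91}`` goes through ``update_scalar`` and lands in
    # the scalar logs, while a non-scalar value such as a confusion matrix
    # ``{'confusion': np.zeros((10, 10))}`` is stored via ``update_info``
    # so it is preserved as-is instead of being treated as a scalar.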

    def after_test_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """All subclasses should override this method if they need any
        operations after each test epoch.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on the test dataset. The keys are the names of the
                metrics, and the values are the corresponding results.
        """
        if metrics is not None:
            for key, value in metrics.items():
                if _is_scalar(value):
                    runner.message_hub.update_scalar(f'test/{key}', value)
                else:
                    runner.message_hub.update_info(f'test/{key}', value)
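
# Usage sketch (hypothetical config snippet): ``RuntimeInfoHook`` is one of
# mmengine's default hooks, so a ``Runner`` normally registers it on its
# own; enabling it explicitly would look like:
#
#   default_hooks = dict(runtime_info=dict(type='RuntimeInfoHook'))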