[Enhance] Support non-scalar type metric value. (#827)
* [Enhance] Support non-scalar type metric value.
* Refactor support.
* Fix non-scalar tags problem during validation.
* Update tag processor.
pull/875/head
parent
79067e4628
commit
fcd783fcb2
|
@ -2,9 +2,13 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmengine.fileio import FileClient, dump
|
||||
from mmengine.fileio.io import get_file_backend
|
||||
from mmengine.hooks import Hook
|
||||
|
@ -252,9 +256,33 @@ class LoggerHook(Hook):
|
|||
metrics, and the values are corresponding results.
|
||||
"""
|
||||
tag, log_str = runner.log_processor.get_log_after_epoch(
|
||||
runner, len(runner.test_dataloader), 'test')
|
||||
runner, len(runner.test_dataloader), 'test', with_non_scalar=True)
|
||||
runner.logger.info(log_str)
|
||||
dump(tag, osp.join(runner.log_dir, self.json_log_path)) # type: ignore
|
||||
dump(
|
||||
self._process_tags(tag),
|
||||
osp.join(runner.log_dir, self.json_log_path)) # type: ignore
|
||||
|
||||
@staticmethod
def _process_tags(tags: dict):
    """Convert tag values to json-friendly type.

    Args:
        tags (dict): Mapping from tag name to value. Values may be nested
            lists/tuples/dicts, plain Python scalars, ``None``,
            ``torch.Tensor``, ``np.ndarray`` or NumPy scalars.

    Returns:
        OrderedDict: A new mapping with the same keys in which tensors,
        arrays and NumPy scalars are converted with ``tolist()``,
        tuples become lists, and values of any other unsupported type are
        replaced by ``None``.
    """

    def process_val(value):
        if isinstance(value, (list, tuple)):
            # Array type of json
            return [process_val(item) for item in value]
        elif isinstance(value, dict):
            # Object type of json
            return {k: process_val(v) for k, v in value.items()}
        elif isinstance(value, (str, int, float, bool)) or value is None:
            # Other supported type of json
            return value
        elif isinstance(value, (torch.Tensor, np.ndarray, np.generic)):
            # ``np.generic`` covers NumPy scalars such as ``np.float32``
            # and ``np.int64``, which are not instances of Python
            # ``int``/``float`` and would otherwise be lost.
            return value.tolist()
        # Unsupported values are kept under their key as ``None``.
        return None

    processed_tags = OrderedDict(process_val(tags))

    return processed_tags
|
||||
|
||||
def after_run(self, runner) -> None:
|
||||
"""Copy logs to ``self.out_dir`` if ``self.out_dir is not None``
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmengine.registry import HOOKS
|
||||
from mmengine.utils import get_git_hash
|
||||
|
@ -9,6 +12,24 @@ from .hook import Hook
|
|||
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
||||
|
||||
|
||||
def _is_scalar(value: Any) -> bool:
    """Determine the value is a scalar type value.

    Args:
        value (Any): value of log.

    Returns:
        bool: whether the value is a scalar type value. Python ``int`` and
        ``float``, NumPy scalar numbers, and arrays/tensors holding exactly
        one element are scalars; everything else is not.
    """
    if isinstance(value, np.ndarray):
        return value.size == 1
    elif isinstance(value, (int, float, np.number)):
        # ``np.number`` covers NumPy scalar numbers such as ``np.float32``
        # or ``np.int64``, which are not subclasses of Python int/float.
        return True
    elif isinstance(value, torch.Tensor):
        return value.numel() == 1
    return False
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class RuntimeInfoHook(Hook):
|
||||
"""A hook that updates runtime information into message hub.
|
||||
|
@ -112,7 +133,10 @@ class RuntimeInfoHook(Hook):
|
|||
"""
|
||||
if metrics is not None:
|
||||
for key, value in metrics.items():
|
||||
runner.message_hub.update_scalar(f'val/{key}', value)
|
||||
if _is_scalar(value):
|
||||
runner.message_hub.update_scalar(f'val/{key}', value)
|
||||
else:
|
||||
runner.message_hub.update_info(f'val/{key}', value)
|
||||
|
||||
def after_test_epoch(self,
|
||||
runner,
|
||||
|
@ -128,4 +152,7 @@ class RuntimeInfoHook(Hook):
|
|||
"""
|
||||
if metrics is not None:
|
||||
for key, value in metrics.items():
|
||||
runner.message_hub.update_scalar(f'test/{key}', value)
|
||||
if _is_scalar(value):
|
||||
runner.message_hub.update_scalar(f'test/{key}', value)
|
||||
else:
|
||||
runner.message_hub.update_info(f'test/{key}', value)
|
||||
|
|
|
@ -2,8 +2,12 @@
|
|||
import copy
|
||||
import datetime
|
||||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmengine.device import get_max_cuda_memory, is_cuda_available
|
||||
from mmengine.registry import LOG_PROCESSORS
|
||||
|
||||
|
@ -206,8 +210,11 @@ class LogProcessor:
|
|||
log_str += ' '.join(log_items)
|
||||
return tag, log_str
|
||||
|
||||
def get_log_after_epoch(self, runner, batch_idx: int,
|
||||
mode: str) -> Tuple[dict, str]:
|
||||
def get_log_after_epoch(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
mode: str,
|
||||
with_non_scalar: bool = False) -> Tuple[dict, str]:
|
||||
"""Format log string after validation or testing epoch.
|
||||
|
||||
Args:
|
||||
|
@ -215,6 +222,8 @@ class LogProcessor:
|
|||
batch_idx (int): The index of the current batch in the current
|
||||
loop.
|
||||
mode (str): Current mode of runner.
|
||||
with_non_scalar (bool): Whether to include non-scalar infos in the
|
||||
returned tag. Defaults to False.
|
||||
|
||||
Return:
|
||||
Tuple(dict, str): Formatted log dict/string which will be
|
||||
|
@ -230,6 +239,7 @@ class LogProcessor:
|
|||
custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
|
||||
# tag is used to write log information to different backends.
|
||||
tag = self._collect_scalars(custom_cfg_copy, runner, mode)
|
||||
non_scalar_tag = self._collect_non_scalars(runner, mode)
|
||||
tag.pop('time', None)
|
||||
tag.pop('data_time', None)
|
||||
# By epoch:
|
||||
|
@ -252,11 +262,17 @@ class LogProcessor:
|
|||
# `time` and `data_time` will not be recorded in after epoch log
|
||||
# message.
|
||||
log_items = []
|
||||
for name, val in tag.items():
|
||||
for name, val in chain(tag.items(), non_scalar_tag.items()):
|
||||
if isinstance(val, float):
|
||||
val = f'{val:.{self.num_digits}f}'
|
||||
if isinstance(val, (torch.Tensor, np.ndarray)):
|
||||
# newline to display tensor and array.
|
||||
val = f'\n{val}\n'
|
||||
log_items.append(f'{name}: {val}')
|
||||
log_str += ' '.join(log_items)
|
||||
|
||||
if with_non_scalar:
|
||||
tag.update(non_scalar_tag)
|
||||
return tag, log_str
|
||||
|
||||
def _collect_scalars(self, custom_cfg: List[dict], runner,
|
||||
|
@ -305,6 +321,28 @@ class LogProcessor:
|
|||
**log_cfg)
|
||||
return tag
|
||||
|
||||
def _collect_non_scalars(self, runner, mode: str) -> dict:
    """Collect non-scalar runtime infos belonging to ``mode``.

    Args:
        runner (Runner): The runner of the training/testing/validation
            process. Only ``runner.message_hub.runtime_info`` is read.
        mode (str): Current mode of runner.

    Returns:
        dict: Infos whose key starts with ``'{mode}/'``, keyed by the
        remainder of the original key, in the order they appear in
        ``runtime_info``.
    """
    # infos of train/val/test phase.
    infos = runner.message_hub.runtime_info
    # Match on the ``mode/`` separator so that keys which merely share the
    # leading characters (e.g. ``test_extra/x`` for mode ``test``) or have
    # no sub-key at all are not collected by mistake.
    prefix = f'{mode}/'
    mode_infos = OrderedDict()
    for prefix_key, value in infos.items():
        if prefix_key.startswith(prefix):
            # Strip the ``mode/`` prefix and keep the rest of the key.
            mode_infos[prefix_key[len(prefix):]] = value
    return mode_infos
|
||||
|
||||
def _check_custom_cfg(self) -> None:
|
||||
"""Check the legality of ``self.custom_cfg``."""
|
||||
|
||||
|
|
|
@ -3,7 +3,9 @@ import os.path as osp
|
|||
from unittest.mock import ANY, MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine.fileio import load
|
||||
from mmengine.fileio.file_client import HardDiskBackend
|
||||
from mmengine.hooks import LoggerHook
|
||||
|
||||
|
@ -178,12 +180,17 @@ class TestLoggerHook:
|
|||
runner.log_dir = tmp_path
|
||||
runner.timestamp = 'test_after_test_epoch'
|
||||
runner.log_processor.get_log_after_epoch = MagicMock(
|
||||
return_value=(dict(a=1, b=2), 'log_str'))
|
||||
return_value=(
|
||||
dict(a=1, b=2, c={'list': [1, 2]}, d=torch.tensor([1, 2, 3])),
|
||||
'log_str'))
|
||||
logger_hook.before_run(runner)
|
||||
logger_hook.after_test_epoch(runner)
|
||||
runner.log_processor.get_log_after_epoch.assert_called()
|
||||
runner.logger.info.assert_called()
|
||||
osp.isfile(osp.join(runner.log_dir, 'test_after_test_epoch.json'))
|
||||
json_content = load(
|
||||
osp.join(runner.log_dir, 'test_after_test_epoch.json'))
|
||||
assert json_content == dict(a=1, b=2, c={'list': [1, 2]}, d=[1, 2, 3])
|
||||
|
||||
def test_after_val_iter(self):
|
||||
logger_hook = LoggerHook()
|
||||
|
|
|
@ -144,19 +144,28 @@ class TestLogProcessor:
|
|||
# Prepare LoggerHook
|
||||
log_processor = LogProcessor(by_epoch=by_epoch)
|
||||
# Prepare validation information.
|
||||
val_logs = dict(accuracy=0.9, data_time=1.0)
|
||||
log_processor._collect_scalars = MagicMock(return_value=val_logs)
|
||||
scalar_logs = dict(accuracy=0.9, data_time=1.0)
|
||||
non_scalar_logs = dict(
|
||||
recall={
|
||||
'cat': 1,
|
||||
'dog': 0
|
||||
}, cm=torch.tensor([1, 2, 3]))
|
||||
log_processor._collect_scalars = MagicMock(return_value=scalar_logs)
|
||||
log_processor._collect_non_scalars = MagicMock(
|
||||
return_value=non_scalar_logs)
|
||||
_, out = log_processor.get_log_after_epoch(self.runner, 2, mode)
|
||||
expect_metric_str = ("accuracy: 0.9000 recall: {'cat': 1, 'dog': 0} "
|
||||
'cm: \ntensor([1, 2, 3])\n')
|
||||
if by_epoch:
|
||||
if mode == 'test':
|
||||
assert out == 'Epoch(test) [5/5] accuracy: 0.9000'
|
||||
assert out == 'Epoch(test) [5/5] ' + expect_metric_str
|
||||
else:
|
||||
assert out == 'Epoch(val) [1][10/10] accuracy: 0.9000'
|
||||
assert out == 'Epoch(val) [1][10/10] ' + expect_metric_str
|
||||
else:
|
||||
if mode == 'test':
|
||||
assert out == 'Iter(test) [5/5] accuracy: 0.9000'
|
||||
assert out == 'Iter(test) [5/5] ' + expect_metric_str
|
||||
else:
|
||||
assert out == 'Iter(val) [10/10] accuracy: 0.9000'
|
||||
assert out == 'Iter(val) [10/10] ' + expect_metric_str
|
||||
|
||||
def test_collect_scalars(self):
|
||||
history_count = np.ones(100)
|
||||
|
@ -196,6 +205,21 @@ class TestLogProcessor:
|
|||
assert list(tag.keys()) == ['metric']
|
||||
assert tag['metric'] == metric_scalars[-1]
|
||||
|
||||
def test_collect_non_scalars(self):
    """Check ``_collect_non_scalars`` strips the mode prefix from keys and
    hands back the stored objects themselves."""
    array_metric = np.random.rand(10)
    tensor_metric = torch.tensor(10)

    processor = LogProcessor()
    # Store both values in the runtime info under the ``test/`` prefix.
    self.runner.message_hub._runtime_info = {
        'test/metric1': array_metric,
        'test/metric2': tensor_metric,
    }
    collected = processor._collect_non_scalars(self.runner, mode='test')
    # Keys come back without the prefix, in insertion order.
    assert list(collected.keys()) == ['metric1', 'metric2']
    # Values are the very same objects, not copies.
    assert collected['metric1'] is array_metric
    assert collected['metric2'] is tensor_metric
|
||||
|
||||
@patch('torch.cuda.max_memory_allocated', MagicMock())
|
||||
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
|
||||
def test_get_max_memory(self):
|
||||
|
|
Loading…
Reference in New Issue