From fcd783fcb2a9b13255a0282b30bc73da64143a5a Mon Sep 17 00:00:00 2001
From: Ma Zerun
Date: Thu, 12 Jan 2023 20:28:55 +0800
Subject: [PATCH] [Enhance] Support non-scalar type metric value. (#827)

* [Enhance] Support non-scalar type metric value.

* Refactor support.

* Fix the non-scalar tags problem during validation.

* Update tag processor.
---
 mmengine/hooks/logger_hook.py           | 32 ++++++++++++++++--
 mmengine/hooks/runtime_info_hook.py     | 33 +++++++++++++++++--
 mmengine/runner/log_processor.py        | 44 +++++++++++++++++++++++--
 tests/test_hooks/test_logger_hook.py    |  9 ++++-
 tests/test_runner/test_log_processor.py | 36 ++++++++++++++++----
 5 files changed, 139 insertions(+), 15 deletions(-)

diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
index 71752be1..c14a8e7a 100644
--- a/mmengine/hooks/logger_hook.py
+++ b/mmengine/hooks/logger_hook.py
@@ -2,9 +2,13 @@
 import os
 import os.path as osp
 import warnings
+from collections import OrderedDict
 from pathlib import Path
 from typing import Dict, Optional, Sequence, Union
 
+import numpy as np
+import torch
+
 from mmengine.fileio import FileClient, dump
 from mmengine.fileio.io import get_file_backend
 from mmengine.hooks import Hook
@@ -252,9 +256,33 @@ class LoggerHook(Hook):
             metrics, and the values are corresponding results.
         """
         tag, log_str = runner.log_processor.get_log_after_epoch(
-            runner, len(runner.test_dataloader), 'test')
+            runner, len(runner.test_dataloader), 'test', with_non_scalar=True)
         runner.logger.info(log_str)
-        dump(tag, osp.join(runner.log_dir, self.json_log_path))  # type: ignore
+        dump(
+            self._process_tags(tag),
+            osp.join(runner.log_dir, self.json_log_path))  # type: ignore
+
+    @staticmethod
+    def _process_tags(tags: dict):
+        """Convert tag values to JSON-friendly types."""
+
+        def process_val(value):
+            if isinstance(value, (list, tuple)):
+                # Array type of json
+                return [process_val(item) for item in value]
+            elif isinstance(value, dict):
+                # Object type of json
+                return {k: process_val(v) for k, v in value.items()}
+            elif isinstance(value, (str, int, float, bool)) or value is None:
+                # Other supported type of json
+                return value
+            elif isinstance(value, (torch.Tensor, np.ndarray)):
+                return value.tolist()
+            # Unsupported values fall through and are mapped to None
+            # (json null).
+
+        processed_tags = OrderedDict(process_val(tags))
+
+        return processed_tags
 
     def after_run(self, runner) -> None:
         """Copy logs to ``self.out_dir`` if ``self.out_dir is not None``
diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py
index 64ecafa9..0600b6ae 100644
--- a/mmengine/hooks/runtime_info_hook.py
+++ b/mmengine/hooks/runtime_info_hook.py
@@ -1,5 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
 
 from mmengine.registry import HOOKS
 from mmengine.utils import get_git_hash
@@ -9,6 +12,24 @@ from .hook import Hook
 DATA_BATCH = Optional[Union[dict, tuple, list]]
 
 
+def _is_scalar(value: Any) -> bool:
+    """Determine whether the given value is a scalar.
+
+    Args:
+        value (Any): The log value to check.
+
+    Returns:
+        bool: Whether the value is a scalar.
+    """
+    if isinstance(value, np.ndarray):
+        return value.size == 1
+    elif isinstance(value, (int, float)):
+        return True
+    elif isinstance(value, torch.Tensor):
+        return value.numel() == 1
+    return False
+
+
 @HOOKS.register_module()
 class RuntimeInfoHook(Hook):
     """A hook that updates runtime information into message hub.
@@ -112,7 +133,10 @@ class RuntimeInfoHook(Hook):
         """
         if metrics is not None:
             for key, value in metrics.items():
-                runner.message_hub.update_scalar(f'val/{key}', value)
+                if _is_scalar(value):
+                    runner.message_hub.update_scalar(f'val/{key}', value)
+                else:
+                    runner.message_hub.update_info(f'val/{key}', value)
 
     def after_test_epoch(self,
                          runner,
@@ -128,4 +152,7 @@ class RuntimeInfoHook(Hook):
         """
         if metrics is not None:
             for key, value in metrics.items():
-                runner.message_hub.update_scalar(f'test/{key}', value)
+                if _is_scalar(value):
+                    runner.message_hub.update_scalar(f'test/{key}', value)
+                else:
+                    runner.message_hub.update_info(f'test/{key}', value)
diff --git a/mmengine/runner/log_processor.py b/mmengine/runner/log_processor.py
index a1e408f4..41fc3bc5 100644
--- a/mmengine/runner/log_processor.py
+++ b/mmengine/runner/log_processor.py
@@ -2,8 +2,12 @@
 import copy
 import datetime
 from collections import OrderedDict
+from itertools import chain
 from typing import List, Optional, Tuple
 
+import numpy as np
+import torch
+
 from mmengine.device import get_max_cuda_memory, is_cuda_available
 from mmengine.registry import LOG_PROCESSORS
 
@@ -206,8 +210,11 @@ class LogProcessor:
         log_str += ' '.join(log_items)
         return tag, log_str
 
-    def get_log_after_epoch(self, runner, batch_idx: int,
-                            mode: str) -> Tuple[dict, str]:
+    def get_log_after_epoch(self,
+                            runner,
+                            batch_idx: int,
+                            mode: str,
+                            with_non_scalar: bool = False) -> Tuple[dict, str]:
         """Format log string after validation or testing epoch.
 
         Args:
@@ -215,6 +222,8 @@ class LogProcessor:
             batch_idx (int): The index of the current batch in the current
                 loop.
             mode (str): Current mode of runner.
+            with_non_scalar (bool): Whether to include non-scalar infos in the
+                returned tag. Defaults to False.
 
         Return:
             Tuple(dict, str): Formatted log dict/string which will be
@@ -230,6 +239,7 @@ class LogProcessor:
         custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
         # tag is used to write log information to different backends.
         tag = self._collect_scalars(custom_cfg_copy, runner, mode)
+        non_scalar_tag = self._collect_non_scalars(runner, mode)
         tag.pop('time', None)
         tag.pop('data_time', None)
         # By epoch:
@@ -252,11 +262,17 @@ class LogProcessor:
         # `time` and `data_time` will not be recorded in after epoch log
         # message.
         log_items = []
-        for name, val in tag.items():
+        for name, val in chain(tag.items(), non_scalar_tag.items()):
             if isinstance(val, float):
                 val = f'{val:.{self.num_digits}f}'
+            if isinstance(val, (torch.Tensor, np.ndarray)):
+                # Put tensors and arrays on their own lines for readability.
+                val = f'\n{val}\n'
             log_items.append(f'{name}: {val}')
         log_str += ' '.join(log_items)
+
+        if with_non_scalar:
+            tag.update(non_scalar_tag)
         return tag, log_str
 
     def _collect_scalars(self, custom_cfg: List[dict], runner,
@@ -305,6 +321,28 @@ class LogProcessor:
                                                      **log_cfg)
         return tag
 
+    def _collect_non_scalars(self, runner, mode: str) -> dict:
+        """Collect non-scalar log information according to mode.
+
+        Args:
+            runner (Runner): The runner of the training/testing/validation
+                process.
+            mode (str): Current mode of runner.
+
+        Returns:
+            dict: Non-scalar infos of the specified mode.
+        """
+        # Runtime infos of the train/val/test phases.
+        infos = runner.message_hub.runtime_info
+        # Infos of the current mode.
+        mode_infos = OrderedDict()
+        # Extract infos of the current mode and strip the mode prefix.
+        for prefix_key, value in infos.items():
+            if prefix_key.startswith(mode):
+                key = prefix_key.partition('/')[-1]
+                mode_infos[key] = value
+        return mode_infos
+
     def _check_custom_cfg(self) -> None:
         """Check the legality of ``self.custom_cfg``."""
 
diff --git a/tests/test_hooks/test_logger_hook.py b/tests/test_hooks/test_logger_hook.py
index 3a3ddb37..5da29715 100644
--- a/tests/test_hooks/test_logger_hook.py
+++ b/tests/test_hooks/test_logger_hook.py
@@ -3,7 +3,9 @@
 import os.path as osp
 from unittest.mock import ANY, MagicMock
 
 import pytest
+import torch
 
+from mmengine.fileio import load
 from mmengine.fileio.file_client import HardDiskBackend
 from mmengine.hooks import LoggerHook
@@ -178,12 +180,17 @@ class TestLoggerHook:
         runner.log_dir = tmp_path
         runner.timestamp = 'test_after_test_epoch'
         runner.log_processor.get_log_after_epoch = MagicMock(
-            return_value=(dict(a=1, b=2), 'log_str'))
+            return_value=(
+                dict(a=1, b=2, c={'list': [1, 2]}, d=torch.tensor([1, 2, 3])),
+                'log_str'))
         logger_hook.before_run(runner)
         logger_hook.after_test_epoch(runner)
         runner.log_processor.get_log_after_epoch.assert_called()
         runner.logger.info.assert_called()
         osp.isfile(osp.join(runner.log_dir, 'test_after_test_epoch.json'))
+        json_content = load(
+            osp.join(runner.log_dir, 'test_after_test_epoch.json'))
+        assert json_content == dict(a=1, b=2, c={'list': [1, 2]}, d=[1, 2, 3])
 
     def test_after_val_iter(self):
         logger_hook = LoggerHook()
diff --git a/tests/test_runner/test_log_processor.py b/tests/test_runner/test_log_processor.py
index 2affa73f..6683e58d 100644
--- a/tests/test_runner/test_log_processor.py
+++ b/tests/test_runner/test_log_processor.py
@@ -144,19 +144,28 @@ class TestLogProcessor:
         # Prepare LoggerHook
         log_processor = LogProcessor(by_epoch=by_epoch)
         # Prepare validation information.
-        val_logs = dict(accuracy=0.9, data_time=1.0)
-        log_processor._collect_scalars = MagicMock(return_value=val_logs)
+        scalar_logs = dict(accuracy=0.9, data_time=1.0)
+        non_scalar_logs = dict(
+            recall={
+                'cat': 1,
+                'dog': 0
+            }, cm=torch.tensor([1, 2, 3]))
+        log_processor._collect_scalars = MagicMock(return_value=scalar_logs)
+        log_processor._collect_non_scalars = MagicMock(
+            return_value=non_scalar_logs)
         _, out = log_processor.get_log_after_epoch(self.runner, 2, mode)
+        expect_metric_str = ("accuracy: 0.9000 recall: {'cat': 1, 'dog': 0} "
+                             'cm: \ntensor([1, 2, 3])\n')
         if by_epoch:
             if mode == 'test':
-                assert out == 'Epoch(test) [5/5] accuracy: 0.9000'
+                assert out == 'Epoch(test) [5/5] ' + expect_metric_str
             else:
-                assert out == 'Epoch(val) [1][10/10] accuracy: 0.9000'
+                assert out == 'Epoch(val) [1][10/10] ' + expect_metric_str
         else:
             if mode == 'test':
-                assert out == 'Iter(test) [5/5] accuracy: 0.9000'
+                assert out == 'Iter(test) [5/5] ' + expect_metric_str
             else:
-                assert out == 'Iter(val) [10/10] accuracy: 0.9000'
+                assert out == 'Iter(val) [10/10] ' + expect_metric_str
 
     def test_collect_scalars(self):
         history_count = np.ones(100)
@@ -196,6 +205,21 @@ class TestLogProcessor:
         assert list(tag.keys()) == ['metric']
         assert tag['metric'] == metric_scalars[-1]
 
+    def test_collect_non_scalars(self):
+        metric1 = np.random.rand(10)
+        metric2 = torch.tensor(10)
+
+        log_processor = LogProcessor()
+        # Collect infos with the `test/` prefix.
+        log_infos = {'test/metric1': metric1, 'test/metric2': metric2}
+        self.runner.message_hub._runtime_info = log_infos
+        tag = log_processor._collect_non_scalars(self.runner, mode='test')
+        # The mode prefix should be stripped from the returned keys.
+        assert list(tag.keys()) == ['metric1', 'metric2']
+        # Non-scalar values should be returned by reference, with no
+        # statistics (e.g. 'mean' or 'current') applied.
+        assert tag['metric1'] is metric1
+        assert tag['metric2'] is metric2
+
     @patch('torch.cuda.max_memory_allocated', MagicMock())
     @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
    def test_get_max_memory(self):
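
Reviewer note (illustrative, not part of the patch): taken together, the two new code paths mean a metric can now return non-scalar values such as per-class dicts or a confusion matrix, and they will reach both the message hub and the JSON log. A minimal sketch of how the helpers added above behave; the `metrics` dict below is a made-up example:

    import torch

    from mmengine.hooks import LoggerHook
    from mmengine.hooks.runtime_info_hook import _is_scalar

    # A made-up metric result mixing scalar and non-scalar values.
    metrics = dict(
        accuracy=0.9,
        recall={'cat': 1.0, 'dog': 0.5},
        confusion_matrix=torch.tensor([[5, 1], [0, 4]]),
    )

    # RuntimeInfoHook dispatch: scalars go to `update_scalar` (and on to the
    # scalar log backends), everything else goes to `update_info`.
    for key, value in metrics.items():
        target = 'update_scalar' if _is_scalar(value) else 'update_info'
        print(f'test/{key} -> message_hub.{target}')

    # LoggerHook._process_tags converts tensors and arrays to nested lists
    # so the tag dict can be dumped to the JSON log.
    print(LoggerHook._process_tags(metrics))
    # OrderedDict([('accuracy', 0.9),
    #              ('recall', {'cat': 1.0, 'dog': 0.5}),
    #              ('confusion_matrix', [[5, 1], [0, 4]])])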