[Fix] KeyError is thrown in _collect_scalars when log_with_hierarchy is True (#1085)

* Fix log processor

* Fix custom key
pull/1098/head
Mashiro 2023-04-20 10:52:32 +08:00 committed by GitHub
parent a7d4b7c742
commit be347df770
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 9 deletions

View File

@ -341,6 +341,7 @@ class LogProcessor:
Returns:
dict: Statistical values of logs.
"""
custom_cfg = copy.deepcopy(custom_cfg)
tag = OrderedDict()
# history_scalars of train/val/test phase.
history_scalars = runner.message_hub.log_scalars

View File

@ -5,12 +5,14 @@ from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import torch
from parameterized import parameterized
from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
from mmengine.runner import LogProcessor
from mmengine.testing import RunnerTestCase
class TestLogProcessor:
class TestLogProcessor(RunnerTestCase):
def test_init(self):
log_processor = LogProcessor(
@ -69,12 +71,13 @@ class TestLogProcessor:
with pytest.raises(TypeError):
log_processor._parse_windows_size(self.runner, 1, custom_cfg)
@pytest.mark.parametrize(
'by_epoch,mode,log_with_hierarchy',
# yapf: disable
@parameterized.expand(
([True, 'train', True], [True, 'train', False], [False, 'train', True],
[False, 'train', False], [True, 'val', True], [True, 'val', False],
[False, 'val', True], [False, 'val', False], [True, 'test', True],
[True, 'test', False], [False, 'test', True], [False, 'test', False]))
# yapf: enable
def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
# Prepare LoggerHook
log_processor = LogProcessor(
@ -141,8 +144,7 @@ class TestLogProcessor:
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
assert out == log_str
@pytest.mark.parametrize(
'by_epoch,mode,log_with_hierarchy',
@parameterized.expand(
([True, 'val', True], [True, 'val', False], [False, 'val', True],
[False, 'val', False], [True, 'test', True], [False, 'test', False]))
def test_log_val(self, by_epoch, mode, log_with_hierarchy):
@ -188,7 +190,7 @@ class TestLogProcessor:
custom_cfg = [
dict(data_src='time', method_name='max', log_name='time_max')
]
logger_hook = LogProcessor(custom_cfg=custom_cfg)
log_processor = LogProcessor(custom_cfg=custom_cfg)
# Collect with prefix.
log_scalars = {
'train/time': history_time_buffer,
@ -197,7 +199,7 @@ class TestLogProcessor:
'val/metric': history_metric_buffer
}
self.runner.message_hub._log_scalars = log_scalars
tag = logger_hook._collect_scalars(
tag = log_processor._collect_scalars(
copy.deepcopy(custom_cfg), self.runner, mode='train')
# Test training key in tag.
assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
@ -206,7 +208,7 @@ class TestLogProcessor:
assert tag['time_max'] == time_scalars.max()
assert tag['loss_cls'] == loss_cls_scalars[-10:].mean()
tag = logger_hook._collect_scalars(
tag = log_processor._collect_scalars(
copy.deepcopy(custom_cfg), self.runner, mode='val')
assert list(tag.keys()) == ['metric']
assert tag['metric'] == metric_scalars[-1]
@ -268,7 +270,7 @@ class TestLogProcessor:
loop = log_processor._get_cur_loop(self.runner, 'test')
assert len(loop.dataloader) == 5
def setup_method(self):
def setUp(self):
runner = MagicMock()
runner.epoch = 1
runner.max_epochs = 10
@ -289,3 +291,36 @@ class TestLogProcessor:
message_hub.update_scalar('val/acc', i * 0.1)
runner.message_hub = message_hub
self.runner = runner
super().setUp()
def test_with_runner(self):
cfg = self.epoch_based_cfg.copy()
cfg.log_processor = dict(
custom_cfg=[
dict(
data_src='time',
window_size='epoch',
log_name='iter_time',
method_name='mean')
],
log_with_hierarchy=True)
runner = self.build_runner(cfg)
runner.train()
runner.val()
runner.test()
cfg = self.iter_based_cfg.copy()
cfg.log_processor = dict(
by_epoch=False,
custom_cfg=[
dict(
data_src='time',
window_size=100,
log_name='iter_time',
method_name='mean')
],
log_with_hierarchy=True)
runner = self.build_runner(cfg)
runner.train()
runner.val()
runner.test()