[Fix] KeyError is thrown in _collect_scalars when log_with_hierarchy is True (#1085)
* Fix log processor * Fix custom keypull/1098/head
parent
a7d4b7c742
commit
be347df770
|
@ -341,6 +341,7 @@ class LogProcessor:
|
|||
Returns:
|
||||
dict: Statistical values of logs.
|
||||
"""
|
||||
custom_cfg = copy.deepcopy(custom_cfg)
|
||||
tag = OrderedDict()
|
||||
# history_scalars of train/val/test phase.
|
||||
history_scalars = runner.message_hub.log_scalars
|
||||
|
|
|
@ -5,12 +5,14 @@ from unittest.mock import MagicMock, patch
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
|
||||
from mmengine.runner import LogProcessor
|
||||
from mmengine.testing import RunnerTestCase
|
||||
|
||||
|
||||
class TestLogProcessor:
|
||||
class TestLogProcessor(RunnerTestCase):
|
||||
|
||||
def test_init(self):
|
||||
log_processor = LogProcessor(
|
||||
|
@ -69,12 +71,13 @@ class TestLogProcessor:
|
|||
with pytest.raises(TypeError):
|
||||
log_processor._parse_windows_size(self.runner, 1, custom_cfg)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'by_epoch,mode,log_with_hierarchy',
|
||||
# yapf: disable
|
||||
@parameterized.expand(
|
||||
([True, 'train', True], [True, 'train', False], [False, 'train', True],
|
||||
[False, 'train', False], [True, 'val', True], [True, 'val', False],
|
||||
[False, 'val', True], [False, 'val', False], [True, 'test', True],
|
||||
[True, 'test', False], [False, 'test', True], [False, 'test', False]))
|
||||
# yapf: enable
|
||||
def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
|
||||
# Prepare LoggerHook
|
||||
log_processor = LogProcessor(
|
||||
|
@ -141,8 +144,7 @@ class TestLogProcessor:
|
|||
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
|
||||
assert out == log_str
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'by_epoch,mode,log_with_hierarchy',
|
||||
@parameterized.expand(
|
||||
([True, 'val', True], [True, 'val', False], [False, 'val', True],
|
||||
[False, 'val', False], [True, 'test', True], [False, 'test', False]))
|
||||
def test_log_val(self, by_epoch, mode, log_with_hierarchy):
|
||||
|
@ -188,7 +190,7 @@ class TestLogProcessor:
|
|||
custom_cfg = [
|
||||
dict(data_src='time', method_name='max', log_name='time_max')
|
||||
]
|
||||
logger_hook = LogProcessor(custom_cfg=custom_cfg)
|
||||
log_processor = LogProcessor(custom_cfg=custom_cfg)
|
||||
# Collect with prefix.
|
||||
log_scalars = {
|
||||
'train/time': history_time_buffer,
|
||||
|
@ -197,7 +199,7 @@ class TestLogProcessor:
|
|||
'val/metric': history_metric_buffer
|
||||
}
|
||||
self.runner.message_hub._log_scalars = log_scalars
|
||||
tag = logger_hook._collect_scalars(
|
||||
tag = log_processor._collect_scalars(
|
||||
copy.deepcopy(custom_cfg), self.runner, mode='train')
|
||||
# Test training key in tag.
|
||||
assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
|
||||
|
@ -206,7 +208,7 @@ class TestLogProcessor:
|
|||
assert tag['time_max'] == time_scalars.max()
|
||||
assert tag['loss_cls'] == loss_cls_scalars[-10:].mean()
|
||||
|
||||
tag = logger_hook._collect_scalars(
|
||||
tag = log_processor._collect_scalars(
|
||||
copy.deepcopy(custom_cfg), self.runner, mode='val')
|
||||
assert list(tag.keys()) == ['metric']
|
||||
assert tag['metric'] == metric_scalars[-1]
|
||||
|
@ -268,7 +270,7 @@ class TestLogProcessor:
|
|||
loop = log_processor._get_cur_loop(self.runner, 'test')
|
||||
assert len(loop.dataloader) == 5
|
||||
|
||||
def setup_method(self):
|
||||
def setUp(self):
|
||||
runner = MagicMock()
|
||||
runner.epoch = 1
|
||||
runner.max_epochs = 10
|
||||
|
@ -289,3 +291,36 @@ class TestLogProcessor:
|
|||
message_hub.update_scalar('val/acc', i * 0.1)
|
||||
runner.message_hub = message_hub
|
||||
self.runner = runner
|
||||
super().setUp()
|
||||
|
||||
def test_with_runner(self):
|
||||
cfg = self.epoch_based_cfg.copy()
|
||||
cfg.log_processor = dict(
|
||||
custom_cfg=[
|
||||
dict(
|
||||
data_src='time',
|
||||
window_size='epoch',
|
||||
log_name='iter_time',
|
||||
method_name='mean')
|
||||
],
|
||||
log_with_hierarchy=True)
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
runner.val()
|
||||
runner.test()
|
||||
|
||||
cfg = self.iter_based_cfg.copy()
|
||||
cfg.log_processor = dict(
|
||||
by_epoch=False,
|
||||
custom_cfg=[
|
||||
dict(
|
||||
data_src='time',
|
||||
window_size=100,
|
||||
log_name='iter_time',
|
||||
method_name='mean')
|
||||
],
|
||||
log_with_hierarchy=True)
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
runner.val()
|
||||
runner.test()
|
||||
|
|
Loading…
Reference in New Issue