mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhancement] Support writing data to vis_backend
with prefix (#972)
* Log with prefix * Fix test of loggerhook * minor refine * minor refine * Fix unit test * clean the code * deepcopy in method * replace regex * Fix as comment * Enhance readable * rename reserve_prefix to remove_prefix * Fix as comment * Refine unit test * Adjust sequence * clean the code * clean the code * revert renaming reserve prefix * Count the dataloader length in _get_dataloader_size
This commit is contained in:
parent
0d25625ba2
commit
8063d2cce7
@ -234,14 +234,8 @@ class LoggerHook(Hook):
|
|||||||
runner, len(runner.val_dataloader), 'val')
|
runner, len(runner.val_dataloader), 'val')
|
||||||
runner.logger.info(log_str)
|
runner.logger.info(log_str)
|
||||||
if self.log_metric_by_epoch:
|
if self.log_metric_by_epoch:
|
||||||
# when `log_metric_by_epoch` is set to True, it's expected
|
|
||||||
# that validation metric can be logged by epoch rather than
|
|
||||||
# by iter. At the same time, scalars related to time should
|
|
||||||
# still be logged by iter to avoid messy visualized result.
|
|
||||||
# see details in PR #278.
|
|
||||||
metric_tags = {k: v for k, v in tag.items() if 'time' not in k}
|
|
||||||
runner.visualizer.add_scalars(
|
runner.visualizer.add_scalars(
|
||||||
metric_tags, step=runner.epoch, file_path=self.json_log_path)
|
tag, step=runner.epoch, file_path=self.json_log_path)
|
||||||
else:
|
else:
|
||||||
runner.visualizer.add_scalars(
|
runner.visualizer.add_scalars(
|
||||||
tag, step=runner.iter, file_path=self.json_log_path)
|
tag, step=runner.iter, file_path=self.json_log_path)
|
||||||
|
@ -52,6 +52,13 @@ class LogProcessor:
|
|||||||
`epoch` to statistics log value by epoch.
|
`epoch` to statistics log value by epoch.
|
||||||
num_digits (int): The number of significant digit shown in the
|
num_digits (int): The number of significant digit shown in the
|
||||||
logging message.
|
logging message.
|
||||||
|
log_with_hierarchy (bool): Whether to log with hierarchy. If it is
|
||||||
|
True, the information is written to visualizer backend such as
|
||||||
|
:obj:`LocalVisBackend` and :obj:`TensorboardBackend`
|
||||||
|
with hierarchy. For example, ``loss`` will be saved as
|
||||||
|
``train/loss``, and accuracy will be saved as ``val/accuracy``.
|
||||||
|
Defaults to False.
|
||||||
|
`New in version 0.7.0.`
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # `log_name` is defined, `loss_large_window` will be an additional
|
>>> # `log_name` is defined, `loss_large_window` will be an additional
|
||||||
@ -98,11 +105,13 @@ class LogProcessor:
|
|||||||
window_size=10,
|
window_size=10,
|
||||||
by_epoch=True,
|
by_epoch=True,
|
||||||
custom_cfg: Optional[List[dict]] = None,
|
custom_cfg: Optional[List[dict]] = None,
|
||||||
num_digits: int = 4):
|
num_digits: int = 4,
|
||||||
|
log_with_hierarchy: bool = False):
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.by_epoch = by_epoch
|
self.by_epoch = by_epoch
|
||||||
self.custom_cfg = custom_cfg if custom_cfg else []
|
self.custom_cfg = custom_cfg if custom_cfg else []
|
||||||
self.num_digits = num_digits
|
self.num_digits = num_digits
|
||||||
|
self.log_with_hierarchy = log_with_hierarchy
|
||||||
self._check_custom_cfg()
|
self._check_custom_cfg()
|
||||||
|
|
||||||
def get_log_after_iter(self, runner, batch_idx: int,
|
def get_log_after_iter(self, runner, batch_idx: int,
|
||||||
@ -120,18 +129,26 @@ class LogProcessor:
|
|||||||
recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
|
recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
|
||||||
"""
|
"""
|
||||||
assert mode in ['train', 'test', 'val']
|
assert mode in ['train', 'test', 'val']
|
||||||
current_loop = self._get_cur_loop(runner, mode)
|
|
||||||
cur_iter = self._get_iter(runner, batch_idx=batch_idx)
|
cur_iter = self._get_iter(runner, batch_idx=batch_idx)
|
||||||
# Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
|
# Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
|
||||||
custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
|
parsed_cfg = self._parse_windows_size(runner, batch_idx,
|
||||||
# tag is used to write log information to different backends.
|
self.custom_cfg)
|
||||||
tag = self._collect_scalars(custom_cfg_copy, runner, mode)
|
# log_tag is used to write log information to terminal
|
||||||
# `log_tag` will pop 'lr' and loop other keys to `log_str`.
|
# If `self.log_with_hierarchy` is False, the tag is the same as
|
||||||
log_tag = copy.deepcopy(tag)
|
# log_tag. Otherwise, each key in tag starts with prefix `train`,
|
||||||
|
# `test` or `val`
|
||||||
|
log_tag = self._collect_scalars(parsed_cfg, runner, mode)
|
||||||
|
|
||||||
|
if not self.log_with_hierarchy:
|
||||||
|
tag = copy.deepcopy(log_tag)
|
||||||
|
else:
|
||||||
|
tag = self._collect_scalars(parsed_cfg, runner, mode, True)
|
||||||
|
|
||||||
# Record learning rate.
|
# Record learning rate.
|
||||||
lr_str_list = []
|
lr_str_list = []
|
||||||
for key, value in tag.items():
|
for key, value in tag.items():
|
||||||
if key.endswith('lr'):
|
if key.endswith('lr'):
|
||||||
|
key = self._remove_prefix(key, f'{mode}/')
|
||||||
log_tag.pop(key)
|
log_tag.pop(key)
|
||||||
lr_str_list.append(f'{key}: '
|
lr_str_list.append(f'{key}: '
|
||||||
f'{value:.{self.num_digits}e}')
|
f'{value:.{self.num_digits}e}')
|
||||||
@ -148,7 +165,7 @@ class LogProcessor:
|
|||||||
# Epoch(train) [ 9][010/270]
|
# Epoch(train) [ 9][010/270]
|
||||||
# ... ||| |||
|
# ... ||| |||
|
||||||
# Epoch(train) [ 10][100/270]
|
# Epoch(train) [ 10][100/270]
|
||||||
dataloader_len = len(current_loop.dataloader)
|
dataloader_len = self._get_dataloader_size(runner, mode)
|
||||||
cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len)))
|
cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len)))
|
||||||
|
|
||||||
if mode in ['train', 'val']:
|
if mode in ['train', 'val']:
|
||||||
@ -174,23 +191,22 @@ class LogProcessor:
|
|||||||
log_str = (f'Iter({mode}) '
|
log_str = (f'Iter({mode}) '
|
||||||
f'[{cur_iter_str}/{runner.max_iters}] ')
|
f'[{cur_iter_str}/{runner.max_iters}] ')
|
||||||
else:
|
else:
|
||||||
dataloader_len = len(current_loop.dataloader)
|
dataloader_len = self._get_dataloader_size(runner, mode)
|
||||||
cur_iter_str = str(batch_idx + 1).rjust(
|
cur_iter_str = str(batch_idx + 1).rjust(
|
||||||
len(str(dataloader_len)))
|
len(str(dataloader_len)))
|
||||||
log_str = (f'Iter({mode}) [{cur_iter_str}'
|
log_str = (f'Iter({mode}) [{cur_iter_str}/{dataloader_len}] ')
|
||||||
f'/{len(current_loop.dataloader)}] ')
|
|
||||||
# Concatenate lr, momentum string with log header.
|
# Concatenate lr, momentum string with log header.
|
||||||
log_str += f'{lr_str} '
|
log_str += f'{lr_str} '
|
||||||
# If IterTimerHook used in runner, eta, time, and data_time should be
|
# If IterTimerHook used in runner, eta, time, and data_time should be
|
||||||
# recorded.
|
# recorded.
|
||||||
if (all(item in tag for item in ['time', 'data_time'])
|
if (all(item in log_tag for item in ['time', 'data_time'])
|
||||||
and 'eta' in runner.message_hub.runtime_info):
|
and 'eta' in runner.message_hub.runtime_info):
|
||||||
eta = runner.message_hub.get_info('eta')
|
eta = runner.message_hub.get_info('eta')
|
||||||
eta_str = str(datetime.timedelta(seconds=int(eta)))
|
eta_str = str(datetime.timedelta(seconds=int(eta)))
|
||||||
log_str += f'eta: {eta_str} '
|
log_str += f'eta: {eta_str} '
|
||||||
log_str += (f'time: {tag["time"]:.{self.num_digits}f} '
|
log_str += (f'time: {log_tag["time"]:.{self.num_digits}f} '
|
||||||
f'data_time: '
|
f'data_time: '
|
||||||
f'{tag["data_time"]:.{self.num_digits}f} ')
|
f'{log_tag["data_time"]:.{self.num_digits}f} ')
|
||||||
# Pop recorded keys
|
# Pop recorded keys
|
||||||
log_tag.pop('time')
|
log_tag.pop('time')
|
||||||
log_tag.pop('data_time')
|
log_tag.pop('data_time')
|
||||||
@ -235,15 +251,8 @@ class LogProcessor:
|
|||||||
'test', 'val'
|
'test', 'val'
|
||||||
], ('`_get_metric_log_str` only accept val or test mode, but got '
|
], ('`_get_metric_log_str` only accept val or test mode, but got '
|
||||||
f'{mode}')
|
f'{mode}')
|
||||||
cur_loop = self._get_cur_loop(runner, mode)
|
dataloader_len = self._get_dataloader_size(runner, mode)
|
||||||
dataloader_len = len(cur_loop.dataloader)
|
|
||||||
|
|
||||||
custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
|
|
||||||
# tag is used to write log information to different backends.
|
|
||||||
tag = self._collect_scalars(custom_cfg_copy, runner, mode)
|
|
||||||
non_scalar_tag = self._collect_non_scalars(runner, mode)
|
|
||||||
tag.pop('time', None)
|
|
||||||
tag.pop('data_time', None)
|
|
||||||
# By epoch:
|
# By epoch:
|
||||||
# Epoch(val) [10][1000/1000] ...
|
# Epoch(val) [10][1000/1000] ...
|
||||||
# Epoch(test) [1000/1000] ...
|
# Epoch(test) [1000/1000] ...
|
||||||
@ -261,8 +270,42 @@ class LogProcessor:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
log_str = (f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ')
|
log_str = (f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ')
|
||||||
# `time` and `data_time` will not be recorded in after epoch log
|
|
||||||
# message.
|
custom_cfg_copy = copy.deepcopy(self.custom_cfg)
|
||||||
|
# remove prefix
|
||||||
|
custom_keys = [
|
||||||
|
self._remove_prefix(cfg['data_src'], f'{mode}/')
|
||||||
|
for cfg in custom_cfg_copy
|
||||||
|
]
|
||||||
|
# Count the averaged time and data_time by epoch
|
||||||
|
if 'time' not in custom_keys:
|
||||||
|
custom_cfg_copy.append(
|
||||||
|
dict(
|
||||||
|
data_src=f'{mode}/time',
|
||||||
|
window_size='epoch',
|
||||||
|
method_name='mean'))
|
||||||
|
if 'data_time' not in custom_keys:
|
||||||
|
custom_cfg_copy.append(
|
||||||
|
dict(
|
||||||
|
data_src=f'{mode}/data_time',
|
||||||
|
window_size='epoch',
|
||||||
|
method_name='mean'))
|
||||||
|
parsed_cfg = self._parse_windows_size(runner, batch_idx,
|
||||||
|
custom_cfg_copy)
|
||||||
|
# tag is used to write log information to different backends.
|
||||||
|
ori_tag = self._collect_scalars(parsed_cfg, runner, mode,
|
||||||
|
self.log_with_hierarchy)
|
||||||
|
non_scalar_tag = self._collect_non_scalars(runner, mode)
|
||||||
|
# move `time` or `data_time` to the end of the log
|
||||||
|
tag = OrderedDict()
|
||||||
|
time_tag = OrderedDict()
|
||||||
|
for key, value in ori_tag.items():
|
||||||
|
if key in (f'{mode}/time', f'{mode}/data_time', 'time',
|
||||||
|
'data_time'):
|
||||||
|
time_tag[key] = value
|
||||||
|
else:
|
||||||
|
tag[key] = value
|
||||||
|
# Log other messages.
|
||||||
log_items = []
|
log_items = []
|
||||||
for name, val in chain(tag.items(), non_scalar_tag.items()):
|
for name, val in chain(tag.items(), non_scalar_tag.items()):
|
||||||
if isinstance(val, float):
|
if isinstance(val, float):
|
||||||
@ -273,12 +316,19 @@ class LogProcessor:
|
|||||||
log_items.append(f'{name}: {val}')
|
log_items.append(f'{name}: {val}')
|
||||||
log_str += ' '.join(log_items)
|
log_str += ' '.join(log_items)
|
||||||
|
|
||||||
|
for name, val in time_tag.items():
|
||||||
|
log_str += f'{name}: {val:.{self.num_digits}f} '
|
||||||
|
|
||||||
if with_non_scalar:
|
if with_non_scalar:
|
||||||
tag.update(non_scalar_tag)
|
tag.update(non_scalar_tag)
|
||||||
|
tag.update(time_tag)
|
||||||
return tag, log_str
|
return tag, log_str
|
||||||
|
|
||||||
def _collect_scalars(self, custom_cfg: List[dict], runner,
|
def _collect_scalars(self,
|
||||||
mode: str) -> dict:
|
custom_cfg: List[dict],
|
||||||
|
runner,
|
||||||
|
mode: str,
|
||||||
|
reserve_prefix: bool = False) -> dict:
|
||||||
"""Collect log information to compose a dict according to mode.
|
"""Collect log information to compose a dict according to mode.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -287,6 +337,7 @@ class LogProcessor:
|
|||||||
runner (Runner): The runner of the training/testing/validation
|
runner (Runner): The runner of the training/testing/validation
|
||||||
process.
|
process.
|
||||||
mode (str): Current mode of runner.
|
mode (str): Current mode of runner.
|
||||||
|
reserve_prefix (bool): Whether to reserve the prefix of the key.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Statistical values of logs.
|
dict: Statistical values of logs.
|
||||||
@ -300,7 +351,10 @@ class LogProcessor:
|
|||||||
# according to mode.
|
# according to mode.
|
||||||
for prefix_key, log_buffer in history_scalars.items():
|
for prefix_key, log_buffer in history_scalars.items():
|
||||||
if prefix_key.startswith(mode):
|
if prefix_key.startswith(mode):
|
||||||
key = prefix_key.partition('/')[-1]
|
if not reserve_prefix:
|
||||||
|
key = self._remove_prefix(prefix_key, f'{mode}/')
|
||||||
|
else:
|
||||||
|
key = prefix_key
|
||||||
mode_history_scalars[key] = log_buffer
|
mode_history_scalars[key] = log_buffer
|
||||||
for key in mode_history_scalars:
|
for key in mode_history_scalars:
|
||||||
# Update the latest learning rate and smoothed time logs.
|
# Update the latest learning rate and smoothed time logs.
|
||||||
@ -341,10 +395,20 @@ class LogProcessor:
|
|||||||
# extract log info and remove prefix to `mode_infos` according to mode.
|
# extract log info and remove prefix to `mode_infos` according to mode.
|
||||||
for prefix_key, value in infos.items():
|
for prefix_key, value in infos.items():
|
||||||
if prefix_key.startswith(mode):
|
if prefix_key.startswith(mode):
|
||||||
key = prefix_key.partition('/')[-1]
|
if self.log_with_hierarchy:
|
||||||
|
key = prefix_key
|
||||||
|
else:
|
||||||
|
key = self._remove_prefix(prefix_key, f'{mode}/')
|
||||||
mode_infos[key] = value
|
mode_infos[key] = value
|
||||||
return mode_infos
|
return mode_infos
|
||||||
|
|
||||||
|
def _remove_prefix(self, string: str, prefix: str):
|
||||||
|
"""Remove the prefix ``train``, ``val`` and ``test`` of the key."""
|
||||||
|
if string.startswith(prefix):
|
||||||
|
return string[len(prefix):]
|
||||||
|
else:
|
||||||
|
return string
|
||||||
|
|
||||||
def _check_custom_cfg(self) -> None:
|
def _check_custom_cfg(self) -> None:
|
||||||
"""Check the legality of ``self.custom_cfg``."""
|
"""Check the legality of ``self.custom_cfg``."""
|
||||||
|
|
||||||
@ -377,16 +441,24 @@ class LogProcessor:
|
|||||||
_check_repeated_log_name()
|
_check_repeated_log_name()
|
||||||
_check_window_size()
|
_check_window_size()
|
||||||
|
|
||||||
def _parse_windows_size(self, runner, batch_idx: int) -> list:
|
def _parse_windows_size(self,
|
||||||
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
|
custom_cfg: Optional[list] = None) -> list:
|
||||||
"""Parse window_size defined in custom_cfg to int value.
|
"""Parse window_size defined in custom_cfg to int value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training/testing/validation
|
runner (Runner): The runner of the training/testing/validation
|
||||||
process.
|
process.
|
||||||
batch_idx (int): The iteration index of current dataloader.
|
batch_idx (int): The iteration index of current dataloader.
|
||||||
|
custom_cfg (list): A copy of ``self.custom_cfg``. Defaults to None
|
||||||
|
to keep backward compatibility.
|
||||||
"""
|
"""
|
||||||
custom_cfg_copy = copy.deepcopy(self.custom_cfg)
|
if custom_cfg is None:
|
||||||
for log_cfg in custom_cfg_copy:
|
custom_cfg = copy.deepcopy(self.custom_cfg)
|
||||||
|
else:
|
||||||
|
custom_cfg = copy.deepcopy(custom_cfg)
|
||||||
|
for log_cfg in custom_cfg:
|
||||||
window_size = log_cfg.get('window_size', None)
|
window_size = log_cfg.get('window_size', None)
|
||||||
if window_size is None or isinstance(window_size, int):
|
if window_size is None or isinstance(window_size, int):
|
||||||
continue
|
continue
|
||||||
@ -398,7 +470,7 @@ class LogProcessor:
|
|||||||
raise TypeError(
|
raise TypeError(
|
||||||
'window_size should be int, epoch or global, but got '
|
'window_size should be int, epoch or global, but got '
|
||||||
f'invalid {window_size}')
|
f'invalid {window_size}')
|
||||||
return custom_cfg_copy
|
return custom_cfg
|
||||||
|
|
||||||
def _get_max_memory(self, runner) -> int:
|
def _get_max_memory(self, runner) -> int:
|
||||||
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB)
|
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB)
|
||||||
@ -474,3 +546,15 @@ class LogProcessor:
|
|||||||
return runner.val_loop
|
return runner.val_loop
|
||||||
else:
|
else:
|
||||||
return runner.test_loop
|
return runner.test_loop
|
||||||
|
|
||||||
|
def _get_dataloader_size(self, runner, mode) -> int:
|
||||||
|
"""Get dataloader size of current loop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training/validation/testing
|
||||||
|
mode (str): Current mode of runner.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The dataloader size of current loop.
|
||||||
|
"""
|
||||||
|
return len(self._get_cur_loop(runner=runner, mode=mode).dataloader)
|
||||||
|
@ -147,7 +147,12 @@ class TestLoggerHook:
|
|||||||
logger_hook.after_val_epoch(runner)
|
logger_hook.after_val_epoch(runner)
|
||||||
args = {'step': ANY, 'file_path': ANY}
|
args = {'step': ANY, 'file_path': ANY}
|
||||||
# expect visualizer log `time` and `metric` respectively
|
# expect visualizer log `time` and `metric` respectively
|
||||||
runner.visualizer.add_scalars.assert_called_with({'acc': 0.8}, **args)
|
runner.visualizer.add_scalars.assert_called_with(
|
||||||
|
{
|
||||||
|
'time': 1,
|
||||||
|
'datatime': 1,
|
||||||
|
'acc': 0.8
|
||||||
|
}, **args)
|
||||||
|
|
||||||
# Test when `log_metric_by_epoch` is False
|
# Test when `log_metric_by_epoch` is False
|
||||||
logger_hook = LoggerHook(log_metric_by_epoch=False)
|
logger_hook = LoggerHook(log_metric_by_epoch=False)
|
||||||
|
@ -47,35 +47,38 @@ class TestLogProcessor:
|
|||||||
def test_parse_windows_size(self):
|
def test_parse_windows_size(self):
|
||||||
log_processor = LogProcessor()
|
log_processor = LogProcessor()
|
||||||
# Test parse 'epoch' window_size.
|
# Test parse 'epoch' window_size.
|
||||||
log_processor.custom_cfg = [
|
custom_cfg = [dict(data_src='loss_cls', window_size='epoch')]
|
||||||
dict(data_src='loss_cls', window_size='epoch')
|
custom_cfg = log_processor._parse_windows_size(self.runner, 1,
|
||||||
]
|
custom_cfg)
|
||||||
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
|
|
||||||
assert custom_cfg[0]['window_size'] == 2
|
assert custom_cfg[0]['window_size'] == 2
|
||||||
|
|
||||||
# Test parse 'global' window_size.
|
# Test parse 'global' window_size.
|
||||||
log_processor.custom_cfg = [
|
custom_cfg = [dict(data_src='loss_cls', window_size='global')]
|
||||||
dict(data_src='loss_cls', window_size='global')
|
custom_cfg = log_processor._parse_windows_size(self.runner, 1,
|
||||||
]
|
custom_cfg)
|
||||||
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
|
|
||||||
assert custom_cfg[0]['window_size'] == 11
|
assert custom_cfg[0]['window_size'] == 11
|
||||||
|
|
||||||
# Test parse int window_size
|
# Test parse int window_size
|
||||||
log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=100)]
|
custom_cfg = [dict(data_src='loss_cls', window_size=100)]
|
||||||
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
|
custom_cfg = log_processor._parse_windows_size(self.runner, 1,
|
||||||
|
custom_cfg)
|
||||||
assert custom_cfg[0]['window_size'] == 100
|
assert custom_cfg[0]['window_size'] == 100
|
||||||
|
|
||||||
# Invalid type window_size will raise TypeError.
|
# Invalid type window_size will raise TypeError.
|
||||||
log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=[])]
|
custom_cfg = [dict(data_src='loss_cls', window_size=[])]
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
log_processor._parse_windows_size(custom_cfg, self.runner)
|
log_processor._parse_windows_size(self.runner, 1, custom_cfg)
|
||||||
|
|
||||||
@pytest.mark.parametrize('by_epoch,mode',
|
@pytest.mark.parametrize(
|
||||||
([True, 'train'], [False, 'train'], [True, 'val'],
|
'by_epoch,mode,log_with_hierarchy',
|
||||||
[False, 'val'], [True, 'test'], [False, 'test']))
|
([True, 'train', True], [True, 'train', False], [False, 'train', True],
|
||||||
def test_get_log_after_iter(self, by_epoch, mode):
|
[False, 'train', False], [True, 'val', True], [True, 'val', False],
|
||||||
|
[False, 'val', True], [False, 'val', False], [True, 'test', True],
|
||||||
|
[True, 'test', False], [False, 'test', True], [False, 'test', False]))
|
||||||
|
def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
|
||||||
# Prepare LoggerHook
|
# Prepare LoggerHook
|
||||||
log_processor = LogProcessor(by_epoch=by_epoch)
|
log_processor = LogProcessor(
|
||||||
|
by_epoch=by_epoch, log_with_hierarchy=log_with_hierarchy)
|
||||||
log_processor._get_max_memory = MagicMock(return_value='100')
|
log_processor._get_max_memory = MagicMock(return_value='100')
|
||||||
eta = 40
|
eta = 40
|
||||||
self.runner.message_hub.update_info('eta', eta)
|
self.runner.message_hub.update_info('eta', eta)
|
||||||
@ -84,8 +87,9 @@ class TestLogProcessor:
|
|||||||
train_logs = dict(lr=0.1, time=1.0, data_time=1.0, loss_cls=1.0)
|
train_logs = dict(lr=0.1, time=1.0, data_time=1.0, loss_cls=1.0)
|
||||||
else:
|
else:
|
||||||
train_logs = dict(time=1.0, data_time=1.0, loss_cls=1.0)
|
train_logs = dict(time=1.0, data_time=1.0, loss_cls=1.0)
|
||||||
log_processor._collect_scalars = MagicMock(return_value=train_logs)
|
log_processor._collect_scalars = \
|
||||||
tag, out = log_processor.get_log_after_iter(self.runner, 1, mode)
|
lambda *args, **kwargs: copy.deepcopy(train_logs)
|
||||||
|
_, out = log_processor.get_log_after_iter(self.runner, 1, mode)
|
||||||
# Verify that the correct context have been logged.
|
# Verify that the correct context have been logged.
|
||||||
cur_loop = log_processor._get_cur_loop(self.runner, mode)
|
cur_loop = log_processor._get_cur_loop(self.runner, mode)
|
||||||
if by_epoch:
|
if by_epoch:
|
||||||
@ -138,11 +142,13 @@ class TestLogProcessor:
|
|||||||
assert out == log_str
|
assert out == log_str
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'by_epoch,mode',
|
'by_epoch,mode,log_with_hierarchy',
|
||||||
([True, 'val'], [False, 'val'], [True, 'test'], [False, 'test']))
|
([True, 'val', True], [True, 'val', False], [False, 'val', True],
|
||||||
def test_log_val(self, by_epoch, mode):
|
[False, 'val', False], [True, 'test', True], [False, 'test', False]))
|
||||||
|
def test_log_val(self, by_epoch, mode, log_with_hierarchy):
|
||||||
# Prepare LoggerHook
|
# Prepare LoggerHook
|
||||||
log_processor = LogProcessor(by_epoch=by_epoch)
|
log_processor = LogProcessor(
|
||||||
|
by_epoch=by_epoch, log_with_hierarchy=log_with_hierarchy)
|
||||||
# Prepare validation information.
|
# Prepare validation information.
|
||||||
scalar_logs = dict(accuracy=0.9, data_time=1.0)
|
scalar_logs = dict(accuracy=0.9, data_time=1.0)
|
||||||
non_scalar_logs = dict(
|
non_scalar_logs = dict(
|
||||||
@ -155,7 +161,7 @@ class TestLogProcessor:
|
|||||||
return_value=non_scalar_logs)
|
return_value=non_scalar_logs)
|
||||||
_, out = log_processor.get_log_after_epoch(self.runner, 2, mode)
|
_, out = log_processor.get_log_after_epoch(self.runner, 2, mode)
|
||||||
expect_metric_str = ("accuracy: 0.9000 recall: {'cat': 1, 'dog': 0} "
|
expect_metric_str = ("accuracy: 0.9000 recall: {'cat': 1, 'dog': 0} "
|
||||||
'cm: \ntensor([1, 2, 3])\n')
|
'cm: \ntensor([1, 2, 3])\ndata_time: 1.0000 ')
|
||||||
if by_epoch:
|
if by_epoch:
|
||||||
if mode == 'test':
|
if mode == 'test':
|
||||||
assert out == 'Epoch(test) [5/5] ' + expect_metric_str
|
assert out == 'Epoch(test) [5/5] ' + expect_metric_str
|
||||||
|
Loading…
x
Reference in New Issue
Block a user