[Fix]: fix log processor to log average time and grad norm (#292)

Authored by Mashiro on 2022-06-17 10:54:20 +08:00, committed by GitHub
parent 7b55c5bdbf
commit 7129a98e36
2 changed files with 23 additions and 11 deletions


@@ -260,7 +260,8 @@ class LogProcessor:
             mode_history_scalars[key] = log_buffer
         for key in mode_history_scalars:
             # Update the latest learning rate and smoothed time logs.
-            if key.startswith('loss'):
+            if key.startswith('loss') or key in ('time', 'data_time',
+                                                 'grad_norm'):
                 tag[key] = mode_history_scalars[key].mean(self.window_size)
             else:
                 # Default statistic method is current.
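
What the change means in practice, shown as a minimal sketch that is not part of the diff: before this commit only keys starting with 'loss' were smoothed, so 'time', 'data_time' and 'grad_norm' were logged as the value of the latest iteration; after it they are averaged over the processor's window, which hides one-off spikes. The sketch relies only on mmengine's HistoryBuffer as it is used in the test further down; the sample numbers are made up.

import numpy as np
from mmengine.logging import HistoryBuffer

# Simulated per-iteration times with one slow outlier.
times = np.array([0.30, 0.31, 0.29, 0.95, 0.30])
buffer = HistoryBuffer(times, np.ones_like(times))

print(buffer.current())  # 0.30 -> old behaviour: latest value only
print(buffer.mean(5))    # 0.43 -> new behaviour: mean over the window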


@@ -2,10 +2,11 @@
 import copy
 from unittest.mock import MagicMock, patch
 
+import numpy as np
 import pytest
 import torch
 
-from mmengine.logging import LogProcessor, MessageHub, MMLogger
+from mmengine.logging import HistoryBuffer, LogProcessor, MessageHub, MMLogger
 
 
 class TestLogProcessor:
@@ -154,17 +155,27 @@ class TestLogProcessor:
         assert out == 'Iter(val) [10/10] accuracy: 0.9000'
 
     def test_collect_scalars(self):
+        history_count = np.ones(100)
+        time_scalars = np.random.randn(100)
+        loss_cls_scalars = np.random.randn(100)
+        lr_scalars = np.random.randn(100)
+        metric_scalars = np.random.randn(100)
+
+        history_time_buffer = HistoryBuffer(time_scalars, history_count)
+        histroy_loss_cls = HistoryBuffer(loss_cls_scalars, history_count)
+        history_lr_buffer = HistoryBuffer(lr_scalars, history_count)
+        history_metric_buffer = HistoryBuffer(metric_scalars, history_count)
+
         custom_cfg = [
-            dict(data_src='time', method_name='mean', window_size=100),
             dict(data_src='time', method_name='max', log_name='time_max')
         ]
         logger_hook = LogProcessor(custom_cfg=custom_cfg)
         # Collect with prefix.
         log_scalars = {
-            'train/time': MagicMock(),
-            'lr': MagicMock(),
-            'train/loss_cls': MagicMock(),
-            'val/metric': MagicMock()
+            'train/time': history_time_buffer,
+            'lr': history_lr_buffer,
+            'train/loss_cls': histroy_loss_cls,
+            'val/metric': history_metric_buffer
         }
         self.runner.message_hub._log_scalars = log_scalars
         tag = logger_hook._collect_scalars(
@@ -172,14 +183,14 @@ class TestLogProcessor:
         # Test training key in tag.
         assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
         # Test statistics lr with `current`, loss and time with 'mean'
-        log_scalars['train/time'].statistics.assert_called_with(
-            method_name='max')
-        log_scalars['train/loss_cls'].mean.assert_called()
+        assert tag['time'] == time_scalars[-10:].mean()
+        assert tag['time_max'] == time_scalars.max()
+        assert tag['loss_cls'] == loss_cls_scalars[-10:].mean()
 
         tag = logger_hook._collect_scalars(
             copy.deepcopy(custom_cfg), self.runner, mode='val')
         assert list(tag.keys()) == ['metric']
-        log_scalars['val/metric'].current.assert_called()
+        assert tag['metric'] == metric_scalars[-1]
 
     @patch('torch.cuda.max_memory_allocated', MagicMock())
     @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
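
For reference, a rough sketch of the statistics the rewritten test now asserts directly on HistoryBuffer instead of on MagicMock call records. It assumes, as the [-10:] slices in the test imply, a default LogProcessor window_size of 10 and a count of one per sample; the buffer methods used are the ones exercised by the diff above.

import numpy as np
from mmengine.logging import HistoryBuffer

scalars = np.random.randn(100)
buffer = HistoryBuffer(scalars, np.ones(100))

# Smoothed default for 'time', 'data_time', 'grad_norm' and loss keys.
assert buffer.mean(10) == scalars[-10:].mean()
# Extra statistic registered under a separate key via custom_cfg's log_name.
assert buffer.statistics(method_name='max') == scalars.max()
# Keys without a special rule still report the latest value.
assert buffer.current() == scalars[-1]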