mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix the bug that the training log and evaluating log are mixed (#1252)
* [Fix] Fix the bug that training log and evaluating log are mixed
* fix comment
* fix import error
* refactor
* clear log_buffer before evaluation
* fix error
* add unittest
parent 18c64d5fb0
commit 846d3a4ac6
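
Why the logs mix: the runner calls every registered hook's `after_train_iter` in priority order, and `EvalHook` (typically registered at NORMAL or LOW priority) fires before the logger hooks (registered at VERY_LOW), so evaluation results land in `runner.log_buffer` before the pending training log has been flushed. Below is a minimal sketch of that dispatch order; `MiniRunner` and `NamedHook` are simplified stand-ins for illustration, not mmcv's implementation (mmcv's convention: a lower priority value runs earlier).

# Simplified stand-in for mmcv-style priority dispatch (illustrative,
# not mmcv's actual runner). Lower priority value == called earlier.
PRIORITIES = {'NORMAL': 50, 'LOW': 70, 'VERY_LOW': 90}


class NamedHook:

    def __init__(self, name):
        self.name = name

    def after_train_iter(self, runner):
        print(self.name)


class MiniRunner:

    def __init__(self):
        self._hooks = []

    def register_hook(self, hook, priority='NORMAL'):
        hook.priority = PRIORITIES[priority]
        self._hooks.append(hook)
        # Keep hooks sorted so lower priority values fire first.
        self._hooks.sort(key=lambda h: h.priority)

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)


runner = MiniRunner()
runner.register_hook(NamedHook('EvalHook'), priority='LOW')
runner.register_hook(NamedHook('TextLoggerHook'), priority='VERY_LOW')
runner.call_hook('after_train_iter')  # prints EvalHook, then TextLoggerHook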
@@ -10,6 +10,7 @@ from torch.utils.data import DataLoader

 from mmcv.utils import is_seq_of
 from .hook import Hook
+from .logger import LoggerHook


 class EvalHook(Hook):
@@ -212,19 +213,31 @@ class EvalHook(Hook):

     def after_train_iter(self, runner):
         """Called after every training iter to evaluate the results."""
-        if not self.by_epoch:
+        if not self.by_epoch and self._should_evaluate(runner):
+            # Because the priority of EvalHook is higher than LoggerHook, the
+            # training log and the evaluation log would get mixed. Therefore,
+            # we dump the training log and clear the buffer before the
+            # evaluation log is generated. This problem only appears in
+            # `IterBasedRunner`, whose `self.by_epoch` is False, because
+            # `EpochBasedRunner` (`self.by_epoch` is True) calls
+            # `_do_evaluate` in the `after_train_epoch` stage, by which time
+            # the training log has already been printed, so no mixing can
+            # occur. More details at
+            # https://github.com/open-mmlab/mmsegmentation/issues/694
+            for hook in runner._hooks:
+                if isinstance(hook, LoggerHook):
+                    hook.after_train_iter(runner)
+            runner.log_buffer.clear()
+
             self._do_evaluate(runner)

     def after_train_epoch(self, runner):
         """Called after every training epoch to evaluate the results."""
-        if self.by_epoch:
+        if self.by_epoch and self._should_evaluate(runner):
             self._do_evaluate(runner)

     def _do_evaluate(self, runner):
         """Perform evaluation and save ckpt."""
-        if not self._should_evaluate(runner):
-            return
-
         results = self.test_fn(runner.model, self.dataloader)
         runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
         key_score = self.evaluate(runner, results)
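
Note that the `_should_evaluate(runner)` check is hoisted out of `_do_evaluate` and into the two `after_train_*` conditions, so the logger flush and `log_buffer.clear()` above only run on steps that will actually evaluate; flushing on every iteration would dump partial training logs. A stripped-down sketch of the kind of gate `_should_evaluate` provides (illustrative only; the real method also honors the hook's `start` argument and a few edge cases):

# Illustrative interval gate, not mmcv's implementation.
def should_evaluate(progress, interval, start=None):
    current = progress + 1  # the runner counts completed epochs/iters
    if start is not None and current < start:
        return False
    return current % interval == 0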
@@ -419,9 +432,6 @@ class DistEvalHook(EvalHook):
                     dist.broadcast(module.running_var, 0)
                     dist.broadcast(module.running_mean, 0)

-        if not self._should_evaluate(runner):
-            return
-
         tmpdir = self.tmpdir
         if tmpdir is None:
             tmpdir = osp.join(runner.work_dir, '.eval_hook')
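
That concludes the change to the `EvalHook` implementation; the remaining hunks update the unit tests. For context, `runner.log_buffer` accumulates scalar outputs between logger flushes, and `clear()` drops everything so evaluation metrics cannot mix with leftover training values. A minimal stand-in for the contract the fix relies on (the real class is `mmcv.runner.LogBuffer`, which has more machinery; this sketch is illustrative only):

from collections import defaultdict


# Minimal stand-in for the log-buffer contract; illustrative only,
# not mmcv.runner.LogBuffer itself.
class MiniLogBuffer:

    def __init__(self):
        self.val_history = defaultdict(list)  # raw per-iter values
        self.output = {}  # averaged values ready for the loggers

    def update(self, variables):
        for key, value in variables.items():
            self.val_history[key].append(value)

    def clear(self):
        # Called right after the training log is dumped: dropping both
        # dicts guarantees the next flush sees only evaluation output.
        self.val_history.clear()
        self.output.clear()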
@@ -1,3 +1,4 @@
+import json
 import os.path as osp
 import tempfile
 import unittest.mock as mock

@@ -7,13 +8,14 @@ from unittest.mock import MagicMock, patch
 import pytest
 import torch
 import torch.nn as nn
+import torch.optim as optim
 from torch.utils.data import DataLoader, Dataset

 from mmcv.runner import DistEvalHook as BaseDistEvalHook
 from mmcv.runner import EpochBasedRunner
 from mmcv.runner import EvalHook as BaseEvalHook
 from mmcv.runner import IterBasedRunner
-from mmcv.utils import get_logger
+from mmcv.utils import get_logger, scandir


 class ExampleDataset(Dataset):
@@ -48,18 +50,16 @@ class Model(nn.Module):

     def __init__(self):
         super().__init__()
-        self.linear = nn.Linear(2, 1)
+        self.param = nn.Parameter(torch.tensor([1.0]))

     def forward(self, x, **kwargs):
-        return x
+        return self.param * x

     def train_step(self, data_batch, optimizer, **kwargs):
         if not isinstance(data_batch, dict):
             data_batch = dict(x=data_batch)
-        return data_batch
+        return {'loss': torch.sum(self(data_batch['x']))}

-    def val_step(self, x, optimizer, **kwargs):
-        return dict(loss=self(x))
+    def val_step(self, data_batch, optimizer, **kwargs):
+        return {'loss': torch.sum(self(data_batch['x']))}


 def _build_epoch_runner():
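
The toy `Model` now produces a genuine scalar loss tied to a learnable parameter, which the new logger test requires: the optimizer must have a parameter to step, and the loss must be differentiable. A quick sanity check against the class as defined above:

import torch

model = Model()
out = model.train_step(torch.ones((1, 2)), optimizer=None)
out['loss'].backward()  # the gradient flows into model.param
assert model.param.grad is not None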
@@ -307,7 +307,7 @@ def test_eval_hook():
                           (_build_iter_runner, False)])
 def test_start_param(EvalHookParam, _build_demo_runner, by_epoch):
     # create dummy data
-    dataloader = DataLoader(torch.ones((5, 2)))
+    dataloader = DataLoader(EvalDataset())

     # 0.1. dataloader is not a DataLoader object
     with pytest.raises(TypeError):
@@ -389,3 +389,36 @@ def test_start_param(EvalHookParam, _build_demo_runner, by_epoch):
     runner._iter = 1
     runner.run([dataloader], [('train', 1)], 3)
     assert evalhook.evaluate.call_count == 2  # after epoch 2 & 3
+
+
+@pytest.mark.parametrize('runner,by_epoch,eval_hook_priority',
+                         [(EpochBasedRunner, True, 'NORMAL'),
+                          (EpochBasedRunner, True, 'LOW'),
+                          (IterBasedRunner, False, 'LOW')])
+def test_logger(runner, by_epoch, eval_hook_priority):
+    loader = DataLoader(EvalDataset())
+    model = Model()
+    data_loader = DataLoader(EvalDataset())
+    eval_hook = EvalHook(
+        data_loader, interval=1, by_epoch=by_epoch, save_best='acc')
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        logger = get_logger('test_logger')
+        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+        # use the parametrized runner class so the IterBasedRunner case
+        # actually exercises the iter-based code path
+        runner = runner(
+            model=model, optimizer=optimizer, work_dir=tmpdir, logger=logger)
+        runner.register_logger_hooks(
+            dict(
+                interval=1,
+                hooks=[dict(type='TextLoggerHook', by_epoch=by_epoch)]))
+        runner.register_timer_hook(dict(type='IterTimerHook'))
+        runner.register_hook(eval_hook, priority=eval_hook_priority)
+        runner.run([loader], [('train', 1)], 1)
+
+        path = osp.join(tmpdir, next(scandir(tmpdir, '.json')))
+        with open(path) as fr:
+            fr.readline()  # skip the first line, which is the hook message
+            train_log = json.loads(fr.readline())
+            assert train_log['mode'] == 'train' and 'time' in train_log
+            val_log = json.loads(fr.readline())
+            assert val_log['mode'] == 'val' and 'time' not in val_log
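
For reference, `TextLoggerHook` writes one JSON record per line; after the fix, the two lines the test reads might look roughly like this (values invented for illustration):

{"mode": "train", "epoch": 1, "iter": 1, "lr": 0.001, "time": 0.01, "loss": 2.0}
{"mode": "val", "epoch": 1, "iter": 1, "lr": 0.001, "acc": 1.0}

The assertions pin down the fix: the train record carries the 'time' key added by `IterTimerHook`, while the val record does not, which can only hold if the training log was flushed and the buffer cleared before evaluation output entered it.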