# mmocr/tests/test_engine/test_runner/test_multi_loop.py
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import shutil
import tempfile
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine.config import Config
from mmengine.evaluator import BaseMetric
from mmengine.hooks import Hook
from mmengine.model import BaseModel
from mmengine.registry import DATASETS, HOOKS, METRICS, MODELS
from mmengine.runner import Runner
from torch.utils.data import Dataset

from mmocr.engine.runner import MultiTestLoop, MultiValLoop
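# MultiTestLoop and MultiValLoop are imported for the ``isinstance``
# assertions below; importing them also ensures both loops are registered, so
# the Runner can build them from the string names given in
# ``val_cfg``/``test_cfg``.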


@MODELS.register_module()
class ToyModel(BaseModel):
    """Minimal model for the loop tests.

    ``forward`` supports the three standard modes: ``'tensor'`` returns raw
    outputs, ``'loss'`` returns a loss dict and ``'predict'`` returns
    predictions.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 1)

    def forward(self, batch_inputs, labels, mode='tensor'):
        labels = torch.stack(labels)
        outputs = self.linear1(batch_inputs)
        outputs = self.linear2(outputs)
        if mode == 'tensor':
            return outputs
        elif mode == 'loss':
            loss = (labels - outputs).sum()
            outputs = dict(loss=loss)
            return outputs
        elif mode == 'predict':
            return outputs


@DATASETS.register_module()
class ToyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_sample=self.label[index])


@METRICS.register_module()
class ToyMetric3(BaseMetric):
    """Dummy metric that always reports ``acc=1``, so the expected metric
    values in the tests below are deterministic."""

    def __init__(self, collect_device='cpu', prefix=''):
        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_samples, predictions):
        result = {'acc': 1}
        self.results.append(result)

    def compute_metrics(self, results):
        return dict(acc=1)
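

# The registrations above are what let ``Runner.from_cfg`` resolve config
# entries such as ``dict(type='ToyModel')`` by name; the equivalent manual
# lookups would be, roughly:
#
#   model = MODELS.build(dict(type='ToyModel'))
#   metric = METRICS.build(dict(type='ToyMetric3', prefix='tmp'))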


class TestRunner(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()
        epoch_based_cfg = dict(
            default_scope='mmocr',
            model=dict(type='ToyModel'),
            work_dir=self.temp_dir,
            train_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            test_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            auto_scale_lr=dict(base_batch_size=16, enable=False),
            optim_wrapper=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
            # Placeholder evaluators; each test below overrides them with a
            # list of per-dataset metrics.
            val_evaluator=dict(type='ToyMetric3'),
            test_evaluator=dict(type='ToyMetric3'),
            train_cfg=dict(
                by_epoch=True, max_epochs=3, val_interval=1, val_begin=1),
            val_cfg=dict(),
            test_cfg=dict(),
            custom_hooks=[],
            default_hooks=dict(
                runtime_info=dict(type='RuntimeInfoHook'),
                timer=dict(type='IterTimerHook'),
                logger=dict(type='LoggerHook'),
                param_scheduler=dict(type='ParamSchedulerHook'),
                checkpoint=dict(
                    type='CheckpointHook', interval=1, by_epoch=True),
                sampler_seed=dict(type='DistSamplerSeedHook')),
            launcher='none',
            env_cfg=dict(dist_cfg=dict(backend='nccl')),
        )
        self.epoch_based_cfg = Config(epoch_based_cfg)
        self.iter_based_cfg = copy.deepcopy(self.epoch_based_cfg)
        self.iter_based_cfg.train_dataloader = dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type='InfiniteSampler', shuffle=True),
            batch_size=3,
            num_workers=0)
        self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
        self.iter_based_cfg.default_hooks = dict(
            runtime_info=dict(type='RuntimeInfoHook'),
            timer=dict(type='IterTimerHook'),
            logger=dict(type='LoggerHook'),
            param_scheduler=dict(type='ParamSchedulerHook'),
            checkpoint=dict(
                type='CheckpointHook', interval=1, by_epoch=False),
            sampler_seed=dict(type='DistSamplerSeedHook'))

    def tearDown(self):
        shutil.rmtree(self.temp_dir)

    def test_multi_val_loop(self):
        before_val_iter_results = []
        after_val_iter_results = []
        multi_metrics = dict()

        @HOOKS.register_module()
        class Fake_1(Hook):
            """Hook that records val-loop events for the assertions below."""

            def before_val_iter(self, runner, batch_idx, data_batch=None):
                before_val_iter_results.append('before')

            def after_val_iter(self,
                               runner,
                               batch_idx,
                               data_batch=None,
                               outputs=None):
                after_val_iter_results.append('after')

            def after_val_epoch(self, runner, metrics=None) -> None:
                multi_metrics.update(metrics)

        self.iter_based_cfg.val_cfg = dict(type='MultiValLoop')
        self.iter_based_cfg.val_dataloader = [
            dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0)
        ]
        self.iter_based_cfg.val_evaluator = [
            dict(type='ToyMetric3', prefix='tmp1'),
            dict(type='ToyMetric3', prefix='tmp2')
        ]
        self.iter_based_cfg.custom_hooks = [dict(type='Fake_1', priority=50)]
        self.iter_based_cfg.experiment_name = 'test_multi_val_loop'
        runner = Runner.from_cfg(self.iter_based_cfg)
        runner.val()
        self.assertIsInstance(runner.val_loop, MultiValLoop)
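        # Each ToyDataset has 12 samples and batch_size is 3, so every
        # dataloader yields 4 batches; with 2 val dataloaders the hook should
        # fire 2 * 4 = 8 times.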
        # test custom hook triggered as expected
        self.assertEqual(len(before_val_iter_results), 8)
        self.assertEqual(len(after_val_iter_results), 8)
        for before, after in zip(before_val_iter_results,
                                 after_val_iter_results):
            self.assertEqual(before, 'before')
            self.assertEqual(after, 'after')
        self.assertDictEqual(multi_metrics, {'tmp1/acc': 1, 'tmp2/acc': 1})

        # test same prefix
        self.iter_based_cfg.val_evaluator = [
            dict(type='ToyMetric3', prefix='tmp1'),
            dict(type='ToyMetric3', prefix='tmp1')
        ]
        self.iter_based_cfg.experiment_name = 'test_multi_val_loop_same_prefix'
        runner = Runner.from_cfg(self.iter_based_cfg)
        with self.assertRaisesRegex(ValueError,
                                    ('Please set different'
                                     ' prefix for different datasets'
                                     ' in `val_evaluator`')):
            runner.val()

    def test_multi_test_loop(self):
        before_test_iter_results = []
        after_test_iter_results = []
        multi_metrics = dict()

        @HOOKS.register_module()
        class Fake_2(Hook):
            """Hook that records test-loop events for the assertions below."""

            def before_test_iter(self, runner, batch_idx, data_batch=None):
                before_test_iter_results.append('before')

            def after_test_iter(self,
                                runner,
                                batch_idx,
                                data_batch=None,
                                outputs=None):
                after_test_iter_results.append('after')

            def after_test_epoch(self, runner, metrics=None) -> None:
                multi_metrics.update(metrics)

        self.iter_based_cfg.test_cfg = dict(type='MultiTestLoop')
        self.iter_based_cfg.test_dataloader = [
            dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0)
        ]
        self.iter_based_cfg.test_evaluator = [
            dict(type='ToyMetric3', prefix='tmp1'),
            dict(type='ToyMetric3', prefix='tmp2')
        ]
        self.iter_based_cfg.custom_hooks = [dict(type='Fake_2', priority=50)]
        self.iter_based_cfg.experiment_name = 'multi_test_loop'
        runner = Runner.from_cfg(self.iter_based_cfg)
        runner.test()
        self.assertIsInstance(runner.test_loop, MultiTestLoop)
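        # As in the val-loop test: 2 test dataloaders x 4 batches each = 8
        # hook invocations.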
        # test custom hook triggered as expected
        self.assertEqual(len(before_test_iter_results), 8)
        self.assertEqual(len(after_test_iter_results), 8)
        for before, after in zip(before_test_iter_results,
                                 after_test_iter_results):
            self.assertEqual(before, 'before')
            self.assertEqual(after, 'after')
        self.assertDictEqual(multi_metrics, {'tmp1/acc': 1, 'tmp2/acc': 1})

        # test same prefix
        self.iter_based_cfg.test_evaluator = [
            dict(type='ToyMetric3', prefix='tmp1'),
            dict(type='ToyMetric3', prefix='tmp1')
        ]
        self.iter_based_cfg.experiment_name = 'multi_test_loop_same_prefix'
        runner = Runner.from_cfg(self.iter_based_cfg)
        with self.assertRaisesRegex(ValueError,
                                    ('Please set different'
                                     ' prefix for different datasets'
                                     ' in `test_evaluator`')):
            runner.test()
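

# Usage note (not executed by the tests): enabling the multi-dataset loops in
# a real config follows the same pattern exercised above, e.g.
#
#   val_cfg = dict(type='MultiValLoop')
#   val_dataloader = [dataloader_a, dataloader_b]  # one cfg per dataset
#   val_evaluator = [
#       dict(type='ToyMetric3', prefix='set_a'),   # distinct prefixes
#       dict(type='ToyMetric3', prefix='set_b'),   # are required
#   ]
#
# where ``dataloader_a``/``dataloader_b`` stand in for dataloader configs.
# This test module can be run with, e.g.:
#   pytest tests/test_engine/test_runner/test_multi_loop.py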