# Copyright (c) OpenMMLab. All rights reserved.
import copy
import shutil
import tempfile
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine.config import Config
from mmengine.evaluator import BaseMetric
from mmengine.hooks import Hook
from mmengine.model import BaseModel
from mmengine.registry import DATASETS, HOOKS, METRICS, MODELS
from mmengine.runner import Runner
from torch.utils.data import Dataset

from mmocr.engine.runner import MultiTestLoop, MultiValLoop
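

# Toy components, registered so that Runner.from_cfg can build them from
# the config dicts assembled in the tests below.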
@MODELS.register_module()
class ToyModel(BaseModel):

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 1)

    def forward(self, batch_inputs, labels, mode='tensor'):
        labels = torch.stack(labels)
        outputs = self.linear1(batch_inputs)
        outputs = self.linear2(outputs)
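
        # BaseModel's three forward modes: 'tensor' returns raw outputs,
        # 'loss' returns a dict of losses, 'predict' returns predictions.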
        if mode == 'tensor':
            return outputs
        elif mode == 'loss':
            loss = (labels - outputs).sum()
            outputs = dict(loss=loss)
            return outputs
        elif mode == 'predict':
            return outputs


@DATASETS.register_module()
class ToyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)
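    # 12 samples at batch_size=3 -> 4 batches per dataloader; the hook-call
    # counts asserted in the multi-loop tests below depend on this.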

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_sample=self.label[index])


@METRICS.register_module()
class ToyMetric3(BaseMetric):

    def __init__(self, collect_device='cpu', prefix=''):
        super().__init__(collect_device=collect_device, prefix=prefix)
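
    # Dummy metric: every dataset scores acc=1, which shows up under the
    # evaluator's prefix as '<prefix>/acc' in the merged metrics.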
    def process(self, data_samples, predictions):
        result = {'acc': 1}
        self.results.append(result)

    def compute_metrics(self, results):
        return dict(acc=1)


class TestRunner(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()
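        # Epoch-based base config; each test overrides the val/test
        # sections it exercises before building a Runner from it.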
        epoch_based_cfg = dict(
            default_scope='mmocr',
            model=dict(type='ToyModel'),
            work_dir=self.temp_dir,
            train_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            test_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            auto_scale_lr=dict(base_batch_size=16, enable=False),
            optim_wrapper=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
            # ToyMetric3 is the only metric this file registers.
            val_evaluator=dict(type='ToyMetric3'),
            test_evaluator=dict(type='ToyMetric3'),
            train_cfg=dict(
                by_epoch=True, max_epochs=3, val_interval=1, val_begin=1),
            val_cfg=dict(),
            test_cfg=dict(),
            custom_hooks=[],
            default_hooks=dict(
                runtime_info=dict(type='RuntimeInfoHook'),
                timer=dict(type='IterTimerHook'),
                logger=dict(type='LoggerHook'),
                param_scheduler=dict(type='ParamSchedulerHook'),
                checkpoint=dict(
                    type='CheckpointHook', interval=1, by_epoch=True),
                sampler_seed=dict(type='DistSamplerSeedHook')),
            launcher='none',
            env_cfg=dict(dist_cfg=dict(backend='nccl')),
        )
        self.epoch_based_cfg = Config(epoch_based_cfg)
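        # Iteration-based variant: infinite sampler, a fixed budget of 12
        # iterations, and iteration-based checkpointing.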
        self.iter_based_cfg = copy.deepcopy(self.epoch_based_cfg)
        self.iter_based_cfg.train_dataloader = dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type='InfiniteSampler', shuffle=True),
            batch_size=3,
            num_workers=0)
        self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
        self.iter_based_cfg.default_hooks = dict(
            runtime_info=dict(type='RuntimeInfoHook'),
            timer=dict(type='IterTimerHook'),
            logger=dict(type='LoggerHook'),
            param_scheduler=dict(type='ParamSchedulerHook'),
            checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
            sampler_seed=dict(type='DistSamplerSeedHook'))

    def tearDown(self):
        shutil.rmtree(self.temp_dir)

    def test_multi_val_loop(self):

        before_val_iter_results = []
        after_val_iter_results = []
        multi_metrics = dict()

        @HOOKS.register_module()
        class Fake_1(Hook):
            """Record hook invocations made by the multi val loop."""

            def before_val_iter(self, runner, batch_idx, data_batch=None):
                before_val_iter_results.append('before')

            def after_val_iter(self,
                               runner,
                               batch_idx,
                               data_batch=None,
                               outputs=None):
                after_val_iter_results.append('after')

            def after_val_epoch(self, runner, metrics=None) -> None:
                multi_metrics.update(metrics)
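
        # MultiValLoop runs validation over a list of dataloaders with a
        # matching list of evaluators; distinct prefixes keep the merged
        # metrics apart.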
        self.iter_based_cfg.val_cfg = dict(type='MultiValLoop')
        self.iter_based_cfg.val_dataloader = [
            dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0)
        ]
        self.iter_based_cfg.val_evaluator = [
            dict(type='ToyMetric3', prefix='tmp1'),
            dict(type='ToyMetric3', prefix='tmp2')
        ]
        self.iter_based_cfg.custom_hooks = [dict(type='Fake_1', priority=50)]
        self.iter_based_cfg.experiment_name = 'test_multi_val_loop'
        runner = Runner.from_cfg(self.iter_based_cfg)
        runner.val()

        self.assertIsInstance(runner.val_loop, MultiValLoop)

        # The custom hook should fire once per batch:
        # 2 dataloaders x 4 batches each = 8 calls.
        self.assertEqual(len(before_val_iter_results), 8)
        self.assertEqual(len(after_val_iter_results), 8)
        for before, after in zip(before_val_iter_results,
                                 after_val_iter_results):
            self.assertEqual(before, 'before')
            self.assertEqual(after, 'after')
        self.assertDictEqual(multi_metrics, {'tmp1/acc': 1, 'tmp2/acc': 1})

        # Duplicate prefixes must be rejected, since the merged metric
        # keys would collide across datasets.
        self.iter_based_cfg.val_evaluator = [
            dict(type='ToyMetric3', prefix='tmp1'),
            dict(type='ToyMetric3', prefix='tmp1')
        ]
        self.iter_based_cfg.experiment_name = 'test_multi_val_loop_same_prefix'
        runner = Runner.from_cfg(self.iter_based_cfg)
        with self.assertRaisesRegex(ValueError,
                                    ('Please set different'
                                     ' prefix for different datasets'
                                     ' in `val_evaluator`')):
            runner.val()

    def test_multi_test_loop(self):

        before_test_iter_results = []
        after_test_iter_results = []
        multi_metrics = dict()

        @HOOKS.register_module()
        class Fake_2(Hook):
            """Record hook invocations made by the multi test loop."""

            def before_test_iter(self, runner, batch_idx, data_batch=None):
                before_test_iter_results.append('before')

            def after_test_iter(self,
                                runner,
                                batch_idx,
                                data_batch=None,
                                outputs=None):
                after_test_iter_results.append('after')

            def after_test_epoch(self, runner, metrics=None) -> None:
                multi_metrics.update(metrics)
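
        # Mirror of the val-loop setup: MultiTestLoop pairs each test
        # dataloader with its own prefixed evaluator.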
        self.iter_based_cfg.test_cfg = dict(type='MultiTestLoop')
        self.iter_based_cfg.test_dataloader = [
            dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0)
        ]
        self.iter_based_cfg.test_evaluator = [
            dict(type='ToyMetric3', prefix='tmp1'),
            dict(type='ToyMetric3', prefix='tmp2')
        ]
        self.iter_based_cfg.custom_hooks = [dict(type='Fake_2', priority=50)]
        self.iter_based_cfg.experiment_name = 'multi_test_loop'
        runner = Runner.from_cfg(self.iter_based_cfg)
        runner.test()

        self.assertIsInstance(runner.test_loop, MultiTestLoop)

        # The custom hook should fire once per batch:
        # 2 dataloaders x 4 batches each = 8 calls.
        self.assertEqual(len(before_test_iter_results), 8)
        self.assertEqual(len(after_test_iter_results), 8)
        for before, after in zip(before_test_iter_results,
                                 after_test_iter_results):
            self.assertEqual(before, 'before')
            self.assertEqual(after, 'after')
        self.assertDictEqual(multi_metrics, {'tmp1/acc': 1, 'tmp2/acc': 1})

        # Duplicate prefixes must be rejected here as well.
        self.iter_based_cfg.test_evaluator = [
            dict(type='ToyMetric3', prefix='tmp1'),
            dict(type='ToyMetric3', prefix='tmp1')
        ]
        self.iter_based_cfg.experiment_name = 'multi_test_loop_same_prefix'
        runner = Runner.from_cfg(self.iter_based_cfg)
        with self.assertRaisesRegex(ValueError,
                                    ('Please set different'
                                     ' prefix for different datasets'
                                     ' in `test_evaluator`')):
            runner.test()