# Mirror of https://github.com/open-mmlab/mmocr.git
# 131 lines, 4.0 KiB, Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
|
|
import pytest
|
|
from mmdet.datasets import DATASETS
|
|
|
|
from mmocr.datasets import UniformConcatDataset
|
|
from mmocr.utils import list_from_file
|
|
|
|
|
|
def test_uniform_concat_dataset_pipeline():
    """Test how ``UniformConcatDataset`` assigns pipelines to its datasets.

    Three configurations are exercised on the toy OCR dataset:

    * a flat (1d) pipeline list combined with ``force_apply=True``;
    * ``pipeline=None`` with both a flat and a nested list of datasets;
    * a nested (2d) pipeline list paired with a nested list of datasets.
    """
    pipeline1 = [dict(type='LoadImageFromFile')]
    pipeline2 = [dict(type='LoadImageFromFile'), dict(type='ColorJitter')]

    img_prefix = 'tests/data/ocr_toy_dataset/imgs'
    ann_file = 'tests/data/ocr_toy_dataset/label.txt'
    train1 = dict(
        type='OCRDataset',
        img_prefix=img_prefix,
        ann_file=ann_file,
        loader=dict(
            type='HardDiskLoader',
            repeat=1,
            parser=dict(
                type='LineStrParser',
                keys=['filename', 'text'],
                keys_idx=[0, 1],
                separator=' ')),
        pipeline=None,
        test_mode=False)

    # A shallow copy is sufficient: only the top-level 'pipeline' entry
    # differs, and every config passed on below is deep-copied first.
    train2 = dict(train1)
    train2['pipeline'] = pipeline2

    # pipeline is 1d list
    tmp_dataset = UniformConcatDataset(
        datasets=[copy.deepcopy(train1),
                  copy.deepcopy(train2)],
        pipeline=pipeline1,
        force_apply=True)

    assert len(tmp_dataset) == 2 * len(list_from_file(ann_file))
    # Both datasets are expected to end up with equally long pipelines even
    # though train2 was configured with the longer pipeline2.
    assert len(tmp_dataset.datasets[0].pipeline.transforms) == len(
        tmp_dataset.datasets[1].pipeline.transforms)

    # pipeline is None: the dataset keeps the pipeline from its own config
    tmp_dataset = UniformConcatDataset(
        datasets=[copy.deepcopy(train2)], pipeline=None)
    assert len(tmp_dataset.datasets[0].pipeline.transforms) == len(pipeline2)

    # Give each sub-list its own config object so the two entries never
    # alias the same dict.
    tmp_dataset = UniformConcatDataset(
        datasets=[[copy.deepcopy(train2)], [copy.deepcopy(train2)]],
        pipeline=None)
    assert len(tmp_dataset.datasets[0].pipeline.transforms) == len(pipeline2)
    assert len(tmp_dataset.datasets[1].pipeline.transforms) == len(pipeline2)

    # pipeline is 2d list: each sub-list of datasets is paired with the
    # pipeline at the same position
    tmp_dataset = UniformConcatDataset(
        datasets=[[copy.deepcopy(train1)], [copy.deepcopy(train2)]],
        pipeline=[pipeline1, pipeline2])
    assert len(tmp_dataset.datasets[0].pipeline.transforms) == len(pipeline1)
    assert len(tmp_dataset.datasets[1].pipeline.transforms) == len(pipeline2)
|
|
|
|
|
|
def test_uniform_concat_dataset_eval():
    """Test the evaluation results of ``UniformConcatDataset``.

    Dummy datasets whose ``evaluate`` echoes the first result are used to
    check the per-dataset ``'<idx>_n'`` keys, the presence or absence of
    ``'mean_n'`` for the different ``show_mean_scores`` settings, and the
    ``NotImplementedError`` cases.
    """

    @DATASETS.register_module()
    class DummyDataset:
        """Single-sample dataset whose metric ``n`` is its first result."""

        def __init__(self):
            self.CLASSES = 0
            self.ann_file = 'empty'

        def __len__(self):
            return 1

        def evaluate(self, res, logger, **kwargs):
            return dict(n=res[0])

    # Test 'auto': with a single dataset no mean score is expected
    fake_inputs = [10]
    datasets = [dict(type='DummyDataset')]
    tmp_dataset = UniformConcatDataset(datasets)
    results = tmp_dataset.evaluate(fake_inputs)
    assert results['0_n'] == 10
    assert 'mean_n' not in results

    # Explicitly requesting the mean yields it even for a single dataset
    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=True)
    results = tmp_dataset.evaluate(fake_inputs)
    assert results['mean_n'] == 10

    # Test 'auto': with two datasets the mean score is expected
    fake_inputs = [10, 20]
    datasets = [dict(type='DummyDataset'), dict(type='DummyDataset')]
    tmp_dataset = UniformConcatDataset(datasets)
    results = tmp_dataset.evaluate(fake_inputs)
    assert results['0_n'] == 10
    assert results['1_n'] == 20
    assert results['mean_n'] == 15

    # Explicitly disabling the mean suppresses it
    tmp_dataset = UniformConcatDataset(datasets, show_mean_scores=False)
    results = tmp_dataset.evaluate(fake_inputs)
    assert results['0_n'] == 10
    assert results['1_n'] == 20
    assert 'mean_n' not in results

    # separate_eval=False is expected to be rejected
    with pytest.raises(NotImplementedError):
        ds = UniformConcatDataset(datasets, separate_eval=False)
        ds.evaluate(fake_inputs)

    # Define and register the second dataset type outside the ``raises``
    # block so that only the constructor call below is expected to raise.
    @DATASETS.register_module()
    class DummyDataset2:
        """Second dummy type, identical in behaviour to ``DummyDataset``."""

        def __init__(self):
            self.CLASSES = 0
            self.ann_file = 'empty'

        def __len__(self):
            return 1

        def evaluate(self, res, logger, **kwargs):
            return dict(n=res[0])

    # Mean scores across datasets of different types are expected to be
    # rejected
    with pytest.raises(NotImplementedError):
        UniformConcatDataset(
            [dict(type='DummyDataset'),
             dict(type='DummyDataset2')],
            show_mean_scores=True)
|