mirror of https://github.com/open-mmlab/mmocr.git
35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
from mmocr.datasets import UniformConcatDataset
|
|
from mmocr.utils import list_from_file
|
|
|
|
|
|
def test_dataset_warpper():
    """Exercise ``UniformConcatDataset`` with two toy OCR dataset configs.

    Two configs that differ only in their ``pipeline`` entry (``None`` vs.
    an explicit two-step pipeline) are wrapped together with a one-step
    wrapper-level pipeline. The test then asserts that:

    * the wrapper's length is twice the number of entries that
      ``list_from_file`` returns for the annotation file, and
    * the two wrapped datasets end up with pipelines holding a different
      number of transforms.
    """
    # Pipeline handed to the wrapper itself.
    # NOTE(review): presumably it is only applied to datasets whose own
    # ``pipeline`` is None -- confirm against UniformConcatDataset.
    shared_pipeline = [{'type': 'LoadImageFromFile'}]
    # Pipeline set explicitly on the second dataset config.
    own_pipeline = [{'type': 'LoadImageFromFile'}, {'type': 'ColorJitter'}]

    ann_file = 'tests/data/ocr_toy_dataset/label.txt'
    img_prefix = 'tests/data/ocr_toy_dataset/imgs'

    parser_cfg = {
        'type': 'LineStrParser',
        'keys': ['filename', 'text'],
        'keys_idx': [0, 1],
        'separator': ' ',
    }
    loader_cfg = {
        'type': 'HardDiskLoader',
        'repeat': 1,
        'parser': parser_cfg,
    }
    train1 = {
        'type': 'OCRDataset',
        'img_prefix': img_prefix,
        'ann_file': ann_file,
        'loader': loader_cfg,
        'pipeline': None,
        'test_mode': False,
    }
    # Shallow copy of the first config with only the pipeline replaced.
    train2 = dict(train1, pipeline=own_pipeline)

    concat = UniformConcatDataset(
        datasets=[train1, train2], pipeline=shared_pipeline)

    # Both configs read the same annotation file, so the total is doubled.
    assert len(concat) == len(list_from_file(ann_file)) * 2
    # The two wrapped datasets must not share the same transform count.
    assert (len(concat.datasets[0].pipeline.transforms) !=
            len(concat.datasets[1].pipeline.transforms))
|