mmocr/tests/test_datasets/test_dataset_wrapper.py

118 lines
4.2 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from copy import deepcopy
from unittest import TestCase
from unittest.mock import MagicMock
from mmocr.datasets import ConcatDataset, OCRDataset
from mmocr.registry import TRANSFORMS
from mmocr.utils import register_all_modules
class TestConcatDataset(TestCase):
@TRANSFORMS.register_module()
class MockTransform:
def __init__(self, return_value):
self.return_value = return_value
def __call__(self, *args, **kwargs):
return self.return_value
def setUp(self):
register_all_modules()
dataset = OCRDataset
# create dataset_a
data_info = dict(filename='img_1.jpg', height=720, width=1280)
dataset.parse_data_info = MagicMock(return_value=data_info)
self.dataset_a = dataset(
data_root=osp.join(
osp.dirname(__file__), '../data/det_toy_dataset'),
data_prefix=dict(img_path='imgs'),
ann_file='instances_test.json')
self.dataset_a_with_pipeline = dataset(
data_root=osp.join(
osp.dirname(__file__), '../data/det_toy_dataset'),
data_prefix=dict(img_path='imgs'),
ann_file='instances_test.json',
pipeline=[dict(type='MockTransform', return_value=1)])
# create dataset_b
data_info = dict(filename='img_2.jpg', height=720, width=1280)
dataset.parse_data_info = MagicMock(return_value=data_info)
self.dataset_b = dataset(
data_root=osp.join(
osp.dirname(__file__), '../data/det_toy_dataset'),
data_prefix=dict(img_path='imgs'),
ann_file='instances_test.json')
self.dataset_b_with_pipeline = dataset(
data_root=osp.join(
osp.dirname(__file__), '../data/det_toy_dataset'),
data_prefix=dict(img_path='imgs'),
ann_file='instances_test.json',
pipeline=[dict(type='MockTransform', return_value=2)])
def test_init(self):
with self.assertRaises(TypeError):
ConcatDataset(datasets=[0])
with self.assertRaises(ValueError):
ConcatDataset(
datasets=[
deepcopy(self.dataset_a_with_pipeline),
deepcopy(self.dataset_b)
],
pipeline=[dict(type='MockTransform', return_value=3)])
with self.assertRaises(ValueError):
ConcatDataset(
datasets=[
deepcopy(self.dataset_a),
deepcopy(self.dataset_b_with_pipeline)
],
pipeline=[dict(type='MockTransform', return_value=3)])
with self.assertRaises(ValueError):
dataset_a = deepcopy(self.dataset_a)
dataset_b = OCRDataset(
metainfo=dict(dummy='dummy'),
data_root=osp.join(
osp.dirname(__file__), '../data/det_toy_dataset'),
data_prefix=dict(img_path='imgs'),
ann_file='instances_test.json')
ConcatDataset(datasets=[dataset_a, dataset_b])
# test lazy init
ConcatDataset(
datasets=[deepcopy(self.dataset_a),
deepcopy(self.dataset_b)],
pipeline=[dict(type='MockTransform', return_value=3)],
lazy_init=True)
def test_getitem(self):
cat_datasets = ConcatDataset(
datasets=[deepcopy(self.dataset_a),
deepcopy(self.dataset_b)],
pipeline=[dict(type='MockTransform', return_value=3)])
for datum in cat_datasets:
self.assertEqual(datum, 3)
cat_datasets = ConcatDataset(
datasets=[
deepcopy(self.dataset_a_with_pipeline),
deepcopy(self.dataset_b)
],
pipeline=[dict(type='MockTransform', return_value=3)],
force_apply=True)
for datum in cat_datasets:
self.assertEqual(datum, 3)
cat_datasets = ConcatDataset(datasets=[
deepcopy(self.dataset_a_with_pipeline),
deepcopy(self.dataset_b_with_pipeline)
])
self.assertEqual(cat_datasets[0], 1)
self.assertEqual(cat_datasets[-1], 2)