mmocr/tests/test_datasets/test_transforms/test_loading.py

218 lines
7.9 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from unittest import TestCase
import mmcv
import numpy as np
from mmocr.datasets.transforms import (InferencerLoader, LoadImageFromFile,
LoadKIEAnnotations, LoadOCRAnnotations)
class TestLoadImageFromFile(TestCase):
def test_load_img(self):
data_prefix = osp.join(
osp.dirname(__file__), '../../data/rec_toy_dataset/imgs/')
results = dict(img_path=osp.join(data_prefix, '1036169.jpg'))
transform = LoadImageFromFile(min_size=0)
results = transform(copy.deepcopy(results))
self.assertEquals(results['img_path'],
osp.join(data_prefix, '1036169.jpg'))
self.assertEquals(results['img'].shape, (25, 119, 3))
self.assertEquals(results['img'].dtype, np.uint8)
self.assertEquals(results['img_shape'], (25, 119))
self.assertEquals(results['ori_shape'], (25, 119))
self.assertEquals(
repr(transform),
('LoadImageFromFile(ignore_empty=False, min_size=0, '
"to_float32=False, color_type='color', imdecode_backend='cv2', "
'backend_args=None)'))
# to_float32
transform = LoadImageFromFile(to_float32=True)
results = transform(copy.deepcopy(results))
self.assertEquals(results['img'].dtype, np.float32)
# min_size
transform = LoadImageFromFile(min_size=26, ignore_empty=True)
self.assertIsNone(transform(copy.deepcopy(results)))
transform = LoadImageFromFile(min_size=26)
with self.assertRaises(IOError):
transform(copy.deepcopy(results))
# test load empty
fake_img_path = osp.join(data_prefix, 'fake.jpg')
results = dict(img_path=fake_img_path)
transform = LoadImageFromFile(ignore_empty=False)
with self.assertRaises(FileNotFoundError):
transform(copy.deepcopy(results))
transform = LoadImageFromFile(ignore_empty=True)
results = transform(copy.deepcopy(results))
self.assertIsNone(results)
data_prefix = osp.join(osp.dirname(__file__), '../../data')
broken_img_path = osp.join(data_prefix, 'broken.jpg')
results = dict(img_path=broken_img_path)
transform = LoadImageFromFile(ignore_empty=False)
with self.assertRaises(IOError):
transform(copy.deepcopy(results))
transform = LoadImageFromFile(ignore_empty=True)
results = transform(copy.deepcopy(results))
self.assertIsNone(results)
class TestLoadOCRAnnotations(TestCase):
def setUp(self):
self.results = {
'height':
288,
'width':
512,
'instances': [{
'bbox': [0, 0, 10, 20],
'bbox_label': 1,
'polygon': [0, 0, 0, 20, 10, 20, 10, 0],
'text': 'tmp1',
'ignore': False
}, {
'bbox': [10, 10, 110, 120],
'bbox_label': 2,
'polygon': [10, 10, 10, 120, 110, 120, 110, 10],
'text': 'tmp2',
'ignore': False
}, {
'bbox': [0, 0, 10, 20],
'bbox_label': 1,
'polygon': [0, 0, 0, 20, 10, 20, 10, 0],
'text': 'tmp3',
'ignore': True
}, {
'bbox': [10, 10, 110, 120],
'bbox_label': 2,
'polygon': [10, 10, 10, 120, 110, 120, 110, 10],
'text': 'tmp4',
'ignore': True
}]
}
def test_load_polygon(self):
transform = LoadOCRAnnotations(
with_bbox=False, with_label=False, with_polygon=True)
results = transform(copy.deepcopy(self.results))
self.assertIn('gt_polygons', results)
self.assertIsInstance(results['gt_polygons'], list)
self.assertEqual(len(results['gt_polygons']), 4)
for gt_polygon in results['gt_polygons']:
self.assertIsInstance(gt_polygon, np.ndarray)
self.assertEqual(gt_polygon.dtype, np.float32)
self.assertIn('gt_ignored', results)
self.assertEqual(results['gt_ignored'].dtype, np.bool_)
self.assertTrue((results['gt_ignored'],
np.array([False, False, True, True], dtype=np.bool_)))
def test_load_text(self):
transform = LoadOCRAnnotations(
with_bbox=False, with_label=False, with_text=True)
results = transform(copy.deepcopy(self.results))
self.assertIn('gt_texts', results)
self.assertListEqual(results['gt_texts'],
['tmp1', 'tmp2', 'tmp3', 'tmp4'])
def test_repr(self):
transform = LoadOCRAnnotations(
with_bbox=True, with_label=True, with_polygon=True, with_text=True)
self.assertEqual(
repr(transform),
('LoadOCRAnnotations(with_bbox=True, with_label=True, '
'with_polygon=True, with_text=True, '
"imdecode_backend='cv2', backend_args=None)"))
class TestLoadKIEAnnotations(TestCase):
def setUp(self):
self.results = {
'bboxes': np.random.rand(2, 4).astype(np.float32),
'bbox_labels': np.random.randint(0, 10, (2, )),
'edge_labels': np.random.randint(0, 10, (2, 2)),
'texts': ['text1', 'text2'],
'ori_shape': (288, 512)
}
self.results = {
'img_shape': (288, 512),
'ori_shape': (288, 512),
'instances': [{
'bbox': [0, 0, 10, 20],
'bbox_label': 1,
'edge_label': 1,
'text': 'tmp1',
}, {
'bbox': [10, 10, 110, 120],
'bbox_label': 2,
'edge_label': 1,
'text': 'tmp2',
}]
}
self.load = LoadKIEAnnotations()
def test_transform(self):
results = self.load(copy.deepcopy(self.results))
self.assertIn('gt_bboxes', results)
self.assertIsInstance(results['gt_bboxes'], np.ndarray)
self.assertEqual(results['gt_bboxes'].shape, (2, 4))
self.assertEqual(results['gt_bboxes'].dtype, np.float32)
self.assertIn('gt_bboxes_labels', results)
self.assertIsInstance(results['gt_bboxes_labels'], np.ndarray)
self.assertEqual(results['gt_bboxes_labels'].shape, (2, ))
self.assertEqual(results['gt_bboxes_labels'].dtype, np.int64)
self.assertIn('gt_edges_labels', results)
self.assertIsInstance(results['gt_edges_labels'], np.ndarray)
self.assertEqual(results['gt_edges_labels'].shape, (2, 2))
self.assertEqual(results['gt_edges_labels'].dtype, np.int64)
self.assertIn('ori_shape', results)
self.assertEqual(results['ori_shape'], (288, 512))
load = LoadKIEAnnotations(key_node_idx=1, value_node_idx=2)
results = load(copy.deepcopy(self.results))
self.assertEqual(results['gt_edges_labels'][0, 1], 1)
self.assertEqual(results['gt_edges_labels'][1, 0], -1)
def test_repr(self):
self.assertEqual(
repr(self.load),
'LoadKIEAnnotations(with_bbox=True, with_label=True, '
'with_text=True)')
class TestInferencerLoader(TestCase):
def test_transform(self):
loader = InferencerLoader()
# load from path
img_path = 'tests/data/det_toy_dataset/imgs/test/img_1.jpg'
res = loader(img_path)
self.assertIsInstance(res['img'], np.ndarray)
# load from ndarray
img = mmcv.imread(img_path)
res = loader(img)
self.assertIsInstance(res['img'], np.ndarray)
# load from dict
res = loader(dict(img=img))
self.assertIsInstance(res['img'], np.ndarray)
# invalid input
with self.assertRaises(NotImplementedError):
loader(['hello'])