mirror of https://github.com/open-mmlab/mmocr.git
163 lines
6.3 KiB
Python
163 lines
6.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import unittest
|
|
from typing import Dict, List, Optional
|
|
|
|
import numpy as np
|
|
from shapely.geometry import Polygon
|
|
|
|
from mmocr.datasets.transforms import ImgAugWrapper, TorchVisionWrapper
|
|
|
|
|
|
class TestImgAug(unittest.TestCase):
|
|
|
|
def test_init(self):
|
|
with self.assertRaises(AssertionError):
|
|
ImgAugWrapper(args=[])
|
|
with self.assertRaises(AssertionError):
|
|
ImgAugWrapper(args=['test'])
|
|
|
|
def _create_dummy_data(self):
|
|
img = np.random.rand(50, 50, 3)
|
|
poly = np.array([[[0, 0, 50, 0, 50, 50, 0, 50]],
|
|
[[20, 20, 50, 20, 50, 50, 20, 50]]])
|
|
box = np.array([[0, 0, 50, 50], [20, 20, 50, 50]])
|
|
# It shall always be 0 in MMOCR, but we assign different labels to
|
|
# dummy instances for testing
|
|
labels = np.array([0, 1], dtype=np.int64)
|
|
ignored = np.array([False, True], dtype=bool)
|
|
texts = ['text1', 'text2']
|
|
return dict(
|
|
img=img,
|
|
img_shape=(50, 50),
|
|
gt_polygons=poly,
|
|
gt_bboxes=box,
|
|
gt_bboxes_labels=labels,
|
|
gt_ignored=ignored,
|
|
gt_texts=texts)
|
|
|
|
def assertPolyEqual(self, poly1: List[np.ndarray],
|
|
poly2: List[np.ndarray]) -> None:
|
|
for p1, p2 in zip(poly1, poly2):
|
|
self.assertTrue(
|
|
Polygon(p1.reshape(-1, 2)).equals(Polygon(p2.reshape(-1, 2))))
|
|
|
|
def assert_result_equal(self,
|
|
results: Dict,
|
|
poly_targets: List[np.ndarray],
|
|
bbox_targets: np.ndarray,
|
|
bbox_label_targets: np.ndarray,
|
|
ignore_targets: np.ndarray,
|
|
text_targets: Optional[List[str]] = None) -> None:
|
|
self.assertPolyEqual(poly_targets, results['gt_polygons'])
|
|
self.assertTrue(np.array_equal(bbox_targets, results['gt_bboxes']))
|
|
self.assertTrue(
|
|
np.array_equal(bbox_label_targets, results['gt_bboxes_labels']))
|
|
self.assertTrue(np.array_equal(ignore_targets, results['gt_ignored']))
|
|
self.assertEqual(text_targets, results['gt_texts'])
|
|
self.assertEqual(results['img_shape'],
|
|
(results['img'].shape[0], results['img'].shape[1]))
|
|
|
|
def test_transform(self):
|
|
|
|
# Test empty transform
|
|
imgaug_transform = ImgAugWrapper()
|
|
results = self._create_dummy_data()
|
|
origin_results = copy.deepcopy(results)
|
|
results = imgaug_transform(results)
|
|
self.assert_result_equal(results, origin_results['gt_polygons'],
|
|
origin_results['gt_bboxes'],
|
|
origin_results['gt_bboxes_labels'],
|
|
origin_results['gt_ignored'],
|
|
origin_results['gt_texts'])
|
|
|
|
args = [dict(cls='Affine', translate_px=dict(x=-10, y=-10))]
|
|
imgaug_transform = ImgAugWrapper(args)
|
|
results = self._create_dummy_data()
|
|
results = imgaug_transform(results)
|
|
|
|
# Polygons and bboxes are partially outside the image after
|
|
# transformation
|
|
poly_target = [
|
|
np.array([0, 0, 40, 0, 40, 40, 0, 40]),
|
|
np.array([10, 10, 40, 10, 40, 40, 10, 40])
|
|
]
|
|
box_target = np.array([[0, 0, 40, 40], [10, 10, 40, 40]])
|
|
label_target = np.array([0, 1], dtype=np.int64)
|
|
ignored = np.array([False, True], dtype=bool)
|
|
texts = ['text1', 'text2']
|
|
self.assert_result_equal(results, poly_target, box_target,
|
|
label_target, ignored, texts)
|
|
|
|
# Some polygons and bboxes are no longer inside the image after
|
|
# transformation
|
|
args = [
|
|
dict(cls='Affine', translate_px=dict(x=30, y=30)), ['Fliplr', 1]
|
|
]
|
|
poly_target = [np.array([0, 30, 20, 30, 20, 50, 0, 50])]
|
|
box_target = np.array([[0, 30, 20, 50]])
|
|
label_target = np.array([0], dtype=np.int64)
|
|
ignored = np.array([False], dtype=bool)
|
|
texts = ['text1']
|
|
imgaug_transform = ImgAugWrapper(args)
|
|
results = self._create_dummy_data()
|
|
results = imgaug_transform(results)
|
|
self.assert_result_equal(results, poly_target, box_target,
|
|
label_target, ignored, texts)
|
|
|
|
# All polygons and bboxes are no longer inside the image after
|
|
# transformation
|
|
|
|
# When some transforms result in empty polygons
|
|
args = [dict(cls='Affine', translate_px=dict(x=100, y=100))]
|
|
results = self._create_dummy_data()
|
|
invalid_transform = ImgAugWrapper(args)
|
|
results = invalid_transform(results)
|
|
self.assertIsNone(results)
|
|
|
|
# Everything should work well without gt_texts
|
|
results = self._create_dummy_data()
|
|
del results['gt_texts']
|
|
results = imgaug_transform(results)
|
|
self.assertNotIn('gt_texts', results)
|
|
|
|
# Everything should work well without keys required from text detection
|
|
results = imgaug_transform(
|
|
dict(
|
|
img=np.random.rand(10, 20, 3),
|
|
img_shape=(10, 20),
|
|
gt_texts=['text1', 'text2']))
|
|
self.assertEqual(results['gt_texts'], ['text1', 'text2'])
|
|
|
|
def test_repr(self):
|
|
args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]]
|
|
transform = ImgAugWrapper(args)
|
|
print(repr(transform))
|
|
self.assertEqual(
|
|
repr(transform),
|
|
("ImgAugWrapper(args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]])"
|
|
))
|
|
|
|
|
|
class TestTorchVisionWrapper(unittest.TestCase):
    """Unit tests for ``TorchVisionWrapper``."""

    def test_transform(self):
        """Construction errors, input validation and a Grayscale pass."""
        x = {'img': np.ones((128, 100, 3), dtype=np.uint8)}
        # object not found error: an unknown op name cannot be resolved.
        with self.assertRaises(Exception):
            TorchVisionWrapper(op='NonExist')
        # ``op`` is a required argument.
        with self.assertRaises(TypeError):
            TorchVisionWrapper()
        f = TorchVisionWrapper('Grayscale')
        # A sample without an ``img`` entry is rejected.
        with self.assertRaises(AssertionError):
            f({})
        results = f(x)
        # unittest assertions (as used elsewhere in this file) instead of
        # bare ``assert``, which is stripped when Python runs with -O.
        # Grayscale with default settings yields a 2-D (H, W) image.
        self.assertEqual(results['img'].shape, (128, 100))
        self.assertEqual(results['img_shape'], (128, 100))

    def test_repr(self):
        """``repr`` shows the op name and its keyword arguments."""
        f = TorchVisionWrapper('Grayscale', num_output_channels=3)
        self.assertEqual(
            repr(f),
            'TorchVisionWrapper(op = Grayscale, num_output_channels = 3)')
|