# mmocr/tests/test_datasets/test_transforms/test_wrappers.py
#
# 163 lines
# 6.3 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import unittest
from typing import Dict, List, Optional
import numpy as np
from shapely.geometry import Polygon
from mmocr.datasets.transforms import ImgAugWrapper, TorchVisionWrapper
class TestImgAug(unittest.TestCase):
    """Unit tests for ``ImgAugWrapper``."""

    def test_init(self):
        """An empty list and a bare-string entry both fail construction."""
        with self.assertRaises(AssertionError):
            ImgAugWrapper(args=[])
        with self.assertRaises(AssertionError):
            ImgAugWrapper(args=['test'])

    def _create_dummy_data(self):
        """Build a fresh 50x50 sample holding two annotated instances.

        Returns:
            dict: A results dict with the keys ``img``, ``img_shape``,
            ``gt_polygons``, ``gt_bboxes``, ``gt_bboxes_labels``,
            ``gt_ignored`` and ``gt_texts``.
        """
        img = np.random.rand(50, 50, 3)
        poly = np.array([[[0, 0, 50, 0, 50, 50, 0, 50]],
                         [[20, 20, 50, 20, 50, 50, 20, 50]]])
        box = np.array([[0, 0, 50, 50], [20, 20, 50, 50]])
        # It shall always be 0 in MMOCR, but we assign different labels to
        # dummy instances for testing
        labels = np.array([0, 1], dtype=np.int64)
        ignored = np.array([False, True], dtype=bool)
        texts = ['text1', 'text2']
        return dict(
            img=img,
            img_shape=(50, 50),
            gt_polygons=poly,
            gt_bboxes=box,
            gt_bboxes_labels=labels,
            gt_ignored=ignored,
            gt_texts=texts)

    def assertPolyEqual(self, poly1: List[np.ndarray],
                        poly2: List[np.ndarray]) -> None:
        """Assert that two polygon sequences match element by element.

        Args:
            poly1 (list[np.ndarray]): Polygons stored as flat coordinate
                arrays; each one is reshaped to ``(-1, 2)`` before comparing.
            poly2 (list[np.ndarray]): Polygons to compare against, in the
                same format.
        """
        # ``zip`` stops at the shorter input, so without this check a
        # missing or surplus polygon would go unnoticed.
        self.assertEqual(len(poly1), len(poly2))
        for p1, p2 in zip(poly1, poly2):
            self.assertTrue(
                Polygon(p1.reshape(-1, 2)).equals(Polygon(p2.reshape(-1, 2))))

    def assert_result_equal(self,
                            results: Dict,
                            poly_targets: List[np.ndarray],
                            bbox_targets: np.ndarray,
                            bbox_label_targets: np.ndarray,
                            ignore_targets: np.ndarray,
                            text_targets: Optional[List[str]] = None) -> None:
        """Check every annotation field of ``results`` against its target.

        Args:
            results (dict): Output of the transform under test.
            poly_targets (list[np.ndarray]): Expected ``gt_polygons``.
            bbox_targets (np.ndarray): Expected ``gt_bboxes``.
            bbox_label_targets (np.ndarray): Expected ``gt_bboxes_labels``.
            ignore_targets (np.ndarray): Expected ``gt_ignored``.
            text_targets (list[str], optional): Expected ``gt_texts``.
                Defaults to None.
        """
        self.assertPolyEqual(poly_targets, results['gt_polygons'])
        self.assertTrue(np.array_equal(bbox_targets, results['gt_bboxes']))
        self.assertTrue(
            np.array_equal(bbox_label_targets, results['gt_bboxes_labels']))
        self.assertTrue(np.array_equal(ignore_targets, results['gt_ignored']))
        self.assertEqual(text_targets, results['gt_texts'])
        # ``img_shape`` must stay in sync with the actual image array.
        self.assertEqual(results['img_shape'],
                         (results['img'].shape[0], results['img'].shape[1]))

    def test_transform(self):
        """Exercise the wrapper with no-op, shifting and flipping pipelines."""
        # Test empty transform
        imgaug_transform = ImgAugWrapper()
        results = self._create_dummy_data()
        origin_results = copy.deepcopy(results)
        results = imgaug_transform(results)
        self.assert_result_equal(results, origin_results['gt_polygons'],
                                 origin_results['gt_bboxes'],
                                 origin_results['gt_bboxes_labels'],
                                 origin_results['gt_ignored'],
                                 origin_results['gt_texts'])

        args = [dict(cls='Affine', translate_px=dict(x=-10, y=-10))]
        imgaug_transform = ImgAugWrapper(args)
        results = self._create_dummy_data()
        results = imgaug_transform(results)

        # Polygons and bboxes are partially outside the image after
        # transformation
        poly_target = [
            np.array([0, 0, 40, 0, 40, 40, 0, 40]),
            np.array([10, 10, 40, 10, 40, 40, 10, 40])
        ]
        box_target = np.array([[0, 0, 40, 40], [10, 10, 40, 40]])
        label_target = np.array([0, 1], dtype=np.int64)
        ignored = np.array([False, True], dtype=bool)
        texts = ['text1', 'text2']
        self.assert_result_equal(results, poly_target, box_target,
                                 label_target, ignored, texts)

        # Some polygons and bboxes are no longer inside the image after
        # transformation
        args = [
            dict(cls='Affine', translate_px=dict(x=30, y=30)), ['Fliplr', 1]
        ]
        # Only the first dummy instance is expected to survive.
        poly_target = [np.array([0, 30, 20, 30, 20, 50, 0, 50])]
        box_target = np.array([[0, 30, 20, 50]])
        label_target = np.array([0], dtype=np.int64)
        ignored = np.array([False], dtype=bool)
        texts = ['text1']
        imgaug_transform = ImgAugWrapper(args)
        results = self._create_dummy_data()
        results = imgaug_transform(results)
        self.assert_result_equal(results, poly_target, box_target,
                                 label_target, ignored, texts)

        # All polygons and bboxes are no longer inside the image after
        # transformation
        # When some transforms result in empty polygons
        args = [dict(cls='Affine', translate_px=dict(x=100, y=100))]
        results = self._create_dummy_data()
        invalid_transform = ImgAugWrapper(args)
        results = invalid_transform(results)
        self.assertIsNone(results)

        # Everything should work well without gt_texts
        results = self._create_dummy_data()
        del results['gt_texts']
        results = imgaug_transform(results)
        self.assertNotIn('gt_texts', results)

        # Everything should work well without keys required from text
        # detection
        results = imgaug_transform(
            dict(
                img=np.random.rand(10, 20, 3),
                img_shape=(10, 20),
                gt_texts=['text1', 'text2']))
        self.assertEqual(results['gt_texts'], ['text1', 'text2'])

    def test_repr(self):
        """``repr`` must echo the ``args`` the wrapper was built with."""
        args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]]
        transform = ImgAugWrapper(args)
        self.assertEqual(
            repr(transform),
            ("ImgAugWrapper(args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]])"
             ))
class TestTorchVisionWrapper(unittest.TestCase):
    """Unit tests for ``TorchVisionWrapper``."""

    def test_transform(self):
        """Check construction errors and the output of a ``Grayscale`` op."""
        x = {'img': np.ones((128, 100, 3), dtype=np.uint8)}

        # object not found error
        with self.assertRaises(Exception):
            TorchVisionWrapper(op='NonExist')
        # ``op`` is a required argument
        with self.assertRaises(TypeError):
            TorchVisionWrapper()

        f = TorchVisionWrapper('Grayscale')
        # An input dict without the expected content is rejected
        with self.assertRaises(AssertionError):
            f({})

        results = f(x)
        # Use unittest assertions rather than bare ``assert`` so the checks
        # survive ``python -O`` and report both values on failure.
        self.assertEqual(results['img'].shape, (128, 100))
        self.assertEqual(results['img_shape'], (128, 100))

    def test_repr(self):
        """``repr`` must list the op name and its keyword arguments."""
        f = TorchVisionWrapper('Grayscale', num_output_channels=3)
        self.assertEqual(
            repr(f),
            'TorchVisionWrapper(op = Grayscale, num_output_channels = 3)')