mirror of https://github.com/open-mmlab/mmocr.git
42 lines
1.5 KiB
Python
42 lines
1.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
from mmocr.utils import remove_pipeline_elements
|
|
|
|
|
|
class TestTransformUtils(unittest.TestCase):
|
|
|
|
def test_remove_pipeline_elements(self):
|
|
data = dict(img=np.random.random((30, 40, 3)))
|
|
results = remove_pipeline_elements(copy.deepcopy(data), [0, 1, 2])
|
|
self.assertTrue(np.array_equal(results['img'], data['img']))
|
|
self.assertEqual(len(data), len(results))
|
|
|
|
data['gt_polygons'] = [
|
|
np.array([0., 0., 10., 10., 10., 0.]),
|
|
np.array([0., 0., 10., 0., 0., 10.]),
|
|
np.array([0, 10, 0, 10, 1, 2, 3, 4]),
|
|
np.array([0, 10, 0, 10, 10, 0, 0, 10]),
|
|
]
|
|
data['dummy'] = [
|
|
np.array([0., 0., 10., 10., 10., 0.]),
|
|
]
|
|
data['gt_ignored'] = np.array([True, True, False, False], dtype=bool)
|
|
data['gt_bboxes_labels'] = np.array([0, 1, 2, 3])
|
|
data['gt_bboxes'] = np.array([[1, 2, 3, 4], [5, 6, 7, 8],
|
|
[0, 0, 10, 10], [0, 0, 0, 0]])
|
|
data['gt_texts'] = ['t1', 't2', 't3', 't4']
|
|
keys = [
|
|
'gt_polygons', 'gt_bboxes', 'gt_ignored', 'gt_texts',
|
|
'gt_bboxes_labels'
|
|
]
|
|
results = remove_pipeline_elements(copy.deepcopy(data), [0, 1, 2])
|
|
|
|
for key in keys:
|
|
self.assertTrue(np.array_equal(results[key][0], data[key][3]))
|
|
self.assertTrue(np.array_equal(results['img'], data['img']))
|
|
self.assertTrue(np.array_equal(results['dummy'], data['dummy']))
|