# Mirror of https://github.com/open-mmlab/mmocr.git
# (synced 2025-06-03 21:54:47 +08:00; upstream file: 139 lines, 5.4 KiB, Python)
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from mmdet.structures.mask import PolygonMasks, bitmap_to_polygon
|
|
|
|
from mmocr.datasets import MMDet2MMOCR, MMOCR2MMDet, Resize
|
|
from mmocr.utils import poly2shapely
|
|
|
|
|
|
class TestMMDet2MMOCR(unittest.TestCase):
    """Tests for the ``MMOCR2MMDet`` and ``MMDet2MMOCR`` transforms.

    ``setUp`` stores the same two-instance annotation in three layouts
    (MMOCR-style keys, MMDet-style keys with polygon masks, MMDet-style
    keys with bitmap masks). Each test feeds one layout through a
    transform and compares the output with the matching fixture.
    """

    def setUp(self):
        """Build the three fixture dicts shared by all tests."""
        img = np.zeros((15, 30, 3))
        img_shape = (15, 30)
        # Flat (x, y, x, y, ...) coordinate arrays: one 4-point and one
        # 5-point polygon.
        polygons = [
            np.array([10., 5., 20., 5., 20., 10., 10., 10.]),
            np.array([10., 5., 20., 5., 20., 10., 10., 10., 8., 7.])
        ]
        ignores = np.array([True, False])
        bboxes = np.array([[10., 5., 20., 10.], [0., 0., 10., 10.]])

        # MMOCR layout: ``gt_polygons`` / ``gt_ignored``.
        self.data_info_ocr = dict(
            img=img,
            gt_polygons=polygons,
            gt_bboxes=bboxes,
            img_shape=img_shape,
            gt_ignored=ignores)

        # MMDet layout with polygon masks: ``gt_masks`` / ``gt_ignore_flags``.
        # Each instance is wrapped in its own single-polygon list.
        _polygons = [[polygon] for polygon in polygons]
        masks = PolygonMasks(_polygons, *img_shape)
        self.data_info_det_polygon = dict(
            img=img,
            gt_masks=masks,
            gt_bboxes=bboxes,
            gt_ignore_flags=ignores,
            img_shape=img_shape)

        # MMDet layout again, with the polygon masks converted to bitmaps.
        masks = masks.to_bitmap()
        self.data_info_det_mask = dict(
            img=img,
            gt_masks=masks,
            gt_bboxes=bboxes,
            gt_ignore_flags=ignores,
            img_shape=img_shape)

    def test_ocr2det_polygonmasks(self):
        """MMOCR2MMDet with default arguments yields the polygon fixture."""
        transform = MMOCR2MMDet()
        results = transform(self.data_info_ocr.copy())
        expected = self.data_info_det_polygon

        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertEqual(results['img_shape'], (15, 30))
        # Check every instance. Previously index 0 was asserted twice and
        # the second polygon mask was never verified.
        for idx in (0, 1):
            self.assertTrue(
                np.allclose(results['gt_masks'].masks[idx][0],
                            expected['gt_masks'].masks[idx][0]))
        self.assertTrue(
            np.allclose(results['gt_bboxes'], expected['gt_bboxes']))
        self.assertTrue(
            np.allclose(results['gt_ignore_flags'],
                        expected['gt_ignore_flags']))

    def test_ocr2det_bitmapmasks(self):
        """MMOCR2MMDet(poly2mask=True) yields the bitmap fixture."""
        transform = MMOCR2MMDet(poly2mask=True)
        results = transform(self.data_info_ocr.copy())
        expected = self.data_info_det_mask

        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertEqual(results['img_shape'], (15, 30))

        # Compare the first instance's bitmaps geometrically: trace each
        # bitmap back to a polygon and test the shapely geometries for
        # equality.
        got_poly = bitmap_to_polygon(results['gt_masks'].masks[0])[0][0]
        want_poly = bitmap_to_polygon(expected['gt_masks'].masks[0])[0][0]
        self.assertTrue(
            poly2shapely(got_poly.flatten()).equals(
                poly2shapely(want_poly.flatten())))

        self.assertTrue(
            np.allclose(results['gt_bboxes'], expected['gt_bboxes']))
        self.assertTrue(
            np.allclose(results['gt_ignore_flags'],
                        expected['gt_ignore_flags']))

    def test_det2ocr_polygonmasks(self):
        """MMDet2MMOCR on polygon masks reproduces the MMOCR fixture."""
        transform = MMDet2MMOCR()
        results = transform(self.data_info_det_polygon.copy())
        expected = self.data_info_ocr

        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertEqual(results['img_shape'], (15, 30))
        for idx in (0, 1):
            self.assertTrue(
                np.allclose(results['gt_polygons'][idx],
                            expected['gt_polygons'][idx]))
        self.assertTrue(
            np.allclose(results['gt_bboxes'], expected['gt_bboxes']))
        self.assertTrue(
            np.allclose(results['gt_ignored'], expected['gt_ignored']))

    def test_det2ocr_bitmapmasks(self):
        """MMDet2MMOCR on bitmap masks keeps image, boxes and ignore flags.

        The recovered polygons are not compared here.
        """
        transform = MMDet2MMOCR()
        results = transform(self.data_info_det_mask.copy())
        expected = self.data_info_ocr

        self.assertEqual(results['img'].shape, (15, 30, 3))
        self.assertEqual(results['img_shape'], (15, 30))
        self.assertTrue(
            np.allclose(results['gt_bboxes'], expected['gt_bboxes']))
        self.assertTrue(
            np.allclose(results['gt_ignored'], expected['gt_ignored']))

    def test_ocr2det2ocr(self):
        """Round trip: OCR -> Det -> MMDet Resize -> OCR -> MMOCR Resize.

        The data is resized to (60, 60) by MMDet's ``Resize`` and then back
        to (30, 15) by MMOCR's ``Resize``, so the final annotations are
        expected to match the original MMOCR fixture.
        """
        from mmdet.datasets.transforms import Resize as MMDet_Resize

        pipeline = (
            MMOCR2MMDet(),
            MMDet_Resize(scale=(60, 60)),
            MMDet2MMOCR(),
            Resize(scale=(30, 15)),
        )
        results = self.data_info_ocr.copy()
        for step in pipeline:
            results = step(results)
        expected = self.data_info_ocr

        self.assertEqual(results['img'].shape, (15, 30, 3))
        for idx in (0, 1):
            self.assertTrue(
                np.allclose(results['gt_polygons'][idx],
                            expected['gt_polygons'][idx]))
        self.assertTrue(
            np.allclose(results['gt_bboxes'], expected['gt_bboxes']))
        # Compare the flags element-wise. Comparing ``a.all()`` with
        # ``b.all()`` reduced both arrays to a single bool and would accept
        # e.g. [False, True] in place of [True, False].
        self.assertTrue(
            np.array_equal(results['gt_ignored'], expected['gt_ignored']))

    def test_repr_det2ocr(self):
        """``repr`` of MMDet2MMOCR is the bare class name."""
        transform = MMDet2MMOCR()
        self.assertEqual(repr(transform), 'MMDet2MMOCR')

    def test_repr_ocr2det(self):
        """``repr`` of MMOCR2MMDet reports the ``poly2mask`` setting."""
        transform = MMOCR2MMDet(poly2mask=True)
        self.assertEqual(repr(transform), 'MMOCR2MMDet(poly2mask = True)')
|