diff --git a/mmocr/datasets/pipelines/processing.py b/mmocr/datasets/pipelines/processing.py index 6cfdb165..9a2e1577 100644 --- a/mmocr/datasets/pipelines/processing.py +++ b/mmocr/datasets/pipelines/processing.py @@ -170,7 +170,8 @@ class Resize(MMCV_Resize): if polygon is not None: polygons_resize.append(polygon.astype(np.float32)) else: - polygons_resize.append(np.zeros_like(polygons[idx])) + polygons_resize.append( + np.zeros_like(polygons[idx], dtype=np.float32)) results['gt_polygons'] = polygons_resize def transform(self, results: dict) -> dict: diff --git a/mmocr/utils/polygon_utils.py b/mmocr/utils/polygon_utils.py index e661b580..40368525 100644 --- a/mmocr/utils/polygon_utils.py +++ b/mmocr/utils/polygon_utils.py @@ -100,8 +100,8 @@ def poly2shapely(polygon: ArrayLike) -> Polygon: Returns: polygon (Polygon): A polygon object. """ - assert len(polygon) % 2 == 0 and len(polygon) >= 8 polygon = np.array(polygon, dtype=np.float32) + assert polygon.size % 2 == 0 and polygon.size >= 6 polygon = polygon.reshape([-1, 2]) return Polygon(polygon) @@ -119,7 +119,8 @@ def polys2shapely(polygons: Sequence[ArrayLike]) -> Sequence[Polygon]: return [poly2shapely(polygon) for polygon in polygons] -def crop_polygon(polygon: ArrayLike, crop_box: np.ndarray) -> np.ndarray: +def crop_polygon(polygon: ArrayLike, + crop_box: np.ndarray) -> Union[np.ndarray, None]: """Crop polygon to be within a box region. Args: @@ -127,19 +128,18 @@ def crop_polygon(polygon: ArrayLike, crop_box: np.ndarray) -> np.ndarray: crop_box (ndarray): target box region in shape (4, ). Returns: - np.array or None: Cropped polygon. + np.array or None: Cropped polygon. If the polygon is not within the + crop box, return None. """ - polygon = np.asarray(polygon, dtype=np.float32) - crop_box = np.asarray(crop_box, dtype=np.float32) - poly = Polygon(polygon.reshape(-1, 2)) - crop_poly = Polygon(bbox2poly(crop_box).reshape(-1, 2)) + poly = poly2shapely(polygon) + crop_poly = poly2shapely(bbox2poly(crop_box)) poly_cropped = poly.intersection(crop_poly) if poly_cropped.area == 0.: # If polygon is outside crop_box region, return None. return None else: - poly_cropped = np.array(poly_cropped.boundary.xy)[:, :-1] - return poly_cropped.reshape(-1) + poly_cropped = np.array(poly_cropped.boundary.xy, dtype=np.float32) + return poly_cropped[:, :-1].T.reshape(-1) def poly_make_valid(poly: Polygon) -> Polygon: diff --git a/tests/test_utils/test_polygon_utils.py b/tests/test_utils/test_polygon_utils.py index 06de7b6e..338ac91d 100644 --- a/tests/test_utils/test_polygon_utils.py +++ b/tests/test_utils/test_polygon_utils.py @@ -17,22 +17,25 @@ class TestCropPolygon(unittest.TestCase): # polygon cross box polygon = np.array([20., -10., 40., 10., 10., 40., -10., 20.]) crop_box = np.array([0., 0., 60., 60.]) - target_poly_cropped = np.array([[10., 40., 30., 10., 0., 0., 10.], - [40., 10., 0., 0., 10., 30., 40.]]) + target_poly_cropped = np.array( + [10, 40, 0, 30, 0, 10, 10, 0, 30, 0, 40, 10]) poly_cropped = crop_polygon(polygon, crop_box) - self.assertTrue(target_poly_cropped.all() == poly_cropped.all()) + self.assertTrue( + poly2shapely(poly_cropped).equals( + poly2shapely(target_poly_cropped))) # polygon inside box polygon = np.array([0., 0., 30., 0., 30., 30., 0., 30.]).reshape(-1, 2) crop_box = np.array([0., 0., 60., 60.]) target_poly_cropped = polygon poly_cropped = crop_polygon(polygon, crop_box) - self.assertTrue(target_poly_cropped.all() == poly_cropped.all()) + self.assertTrue( + poly2shapely(poly_cropped).equals( + poly2shapely(target_poly_cropped))) # polygon outside box polygon = np.array([0., 0., 30., 0., 30., 30., 0., 30.]).reshape(-1, 2) crop_box = np.array([80., 80., 90., 90.]) - target_poly_cropped = polygon poly_cropped = crop_polygon(polygon, crop_box) self.assertEqual(poly_cropped, None) @@ -133,10 +136,10 @@ class TestPolygonUtils(unittest.TestCase): self.assertEqual(polys2shapely(polys), polygons) # test invalid polys = [0, 0, 1] - with self.assertRaises(TypeError): + with self.assertRaises(AssertionError): polys2shapely(polys) polys = [0, 0, 1, 0, 1, 1, 0, 1, 1] - with self.assertRaises(TypeError): + with self.assertRaises(AssertionError): polys2shapely(polys) def test_poly_make_valid(self):