mirror of https://github.com/open-mmlab/mmocr.git
300 lines
11 KiB
Python
300 lines
11 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import torch
|
|
from shapely.geometry import MultiPolygon, Polygon
|
|
|
|
from mmocr.utils import (boundary_iou, crop_polygon, offset_polygon, poly2bbox,
|
|
poly2shapely, poly_intersection, poly_iou,
|
|
poly_make_valid, poly_union, polys2shapely,
|
|
rescale_polygon, rescale_polygons)
|
|
|
|
|
|
class TestCropPolygon(unittest.TestCase):
|
|
|
|
def test_crop_polygon(self):
|
|
# polygon cross box
|
|
polygon = np.array([20., -10., 40., 10., 10., 40., -10., 20.])
|
|
crop_box = np.array([0., 0., 60., 60.])
|
|
target_poly_cropped = np.array(
|
|
[10, 40, 0, 30, 0, 10, 10, 0, 30, 0, 40, 10])
|
|
poly_cropped = crop_polygon(polygon, crop_box)
|
|
self.assertTrue(
|
|
poly2shapely(poly_cropped).equals(
|
|
poly2shapely(target_poly_cropped)))
|
|
|
|
# polygon inside box
|
|
polygon = np.array([0., 0., 30., 0., 30., 30., 0., 30.])
|
|
crop_box = np.array([0., 0., 60., 60.])
|
|
target_poly_cropped = polygon
|
|
poly_cropped = crop_polygon(polygon, crop_box)
|
|
self.assertTrue(
|
|
poly2shapely(poly_cropped).equals(
|
|
poly2shapely(target_poly_cropped)))
|
|
|
|
# polygon outside box
|
|
polygon = np.array([0., 0., 30., 0., 30., 30., 0., 30.])
|
|
crop_box = np.array([80., 80., 90., 90.])
|
|
poly_cropped = crop_polygon(polygon, crop_box)
|
|
self.assertEqual(poly_cropped, None)
|
|
|
|
|
|
class TestPolygonUtils(unittest.TestCase):
|
|
|
|
def test_rescale_polygon(self):
|
|
scale_factor = (0.3, 0.4)
|
|
|
|
with self.assertRaises(AssertionError):
|
|
polygons = [0, 0, 1, 0, 1, 1, 0]
|
|
rescale_polygon(polygons, scale_factor)
|
|
|
|
polygons = [0, 0, 1, 0, 1, 1, 0, 1]
|
|
self.assertTrue(
|
|
np.allclose(
|
|
rescale_polygon(polygons, scale_factor, mode='div'),
|
|
np.array([0, 0, 1 / 0.3, 0, 1 / 0.3, 1 / 0.4, 0, 1 / 0.4])))
|
|
self.assertTrue(
|
|
np.allclose(
|
|
rescale_polygon(polygons, scale_factor, mode='mul'),
|
|
np.array([0, 0, 0.3, 0, 0.3, 0.4, 0, 0.4])))
|
|
|
|
def test_rescale_polygons(self):
|
|
polygons = [
|
|
np.array([0, 0, 1, 0, 1, 1, 0, 1]),
|
|
np.array([1, 1, 2, 1, 2, 2, 1, 2])
|
|
]
|
|
scale_factor = (0.5, 0.5)
|
|
self.assertTrue(
|
|
np.allclose(
|
|
rescale_polygons(polygons, scale_factor, mode='div'), [
|
|
np.array([0, 0, 2, 0, 2, 2, 0, 2]),
|
|
np.array([2, 2, 4, 2, 4, 4, 2, 4])
|
|
]))
|
|
self.assertTrue(
|
|
np.allclose(
|
|
rescale_polygons(polygons, scale_factor, mode='mul'), [
|
|
np.array([0, 0, 0.5, 0, 0.5, 0.5, 0, 0.5]),
|
|
np.array([0.5, 0.5, 1, 0.5, 1, 1, 0.5, 1])
|
|
]))
|
|
|
|
polygons = [torch.Tensor([0, 0, 1, 0, 1, 1, 0, 1])]
|
|
scale_factor = (0.3, 0.4)
|
|
self.assertTrue(
|
|
np.allclose(
|
|
rescale_polygons(polygons, scale_factor, mode='div'),
|
|
[np.array([0, 0, 1 / 0.3, 0, 1 / 0.3, 1 / 0.4, 0, 1 / 0.4])]))
|
|
self.assertTrue(
|
|
np.allclose(
|
|
rescale_polygons(polygons, scale_factor, mode='mul'),
|
|
[np.array([0, 0, 0.3, 0, 0.3, 0.4, 0, 0.4])]))
|
|
|
|
def test_poly2bbox(self):
|
|
# test np.array
|
|
polygon = np.array([0, 0, 1, 0, 1, 1, 0, 1])
|
|
self.assertTrue(np.all(poly2bbox(polygon) == np.array([0, 0, 1, 1])))
|
|
# test list
|
|
polygon = [0, 0, 1, 0, 1, 1, 0, 1]
|
|
self.assertTrue(np.all(poly2bbox(polygon) == np.array([0, 0, 1, 1])))
|
|
# test tensor
|
|
polygon = torch.Tensor([0, 0, 1, 0, 1, 1, 0, 1])
|
|
self.assertTrue(np.all(poly2bbox(polygon) == np.array([0, 0, 1, 1])))
|
|
|
|
def test_poly2shapely(self):
|
|
polygon = Polygon([[0, 0], [1, 0], [1, 1], [0, 1]])
|
|
# test np.array
|
|
poly = np.array([0, 0, 1, 0, 1, 1, 0, 1])
|
|
self.assertEqual(poly2shapely(poly), polygon)
|
|
# test list
|
|
poly = [0, 0, 1, 0, 1, 1, 0, 1]
|
|
self.assertEqual(poly2shapely(poly), polygon)
|
|
# test tensor
|
|
poly = torch.Tensor([0, 0, 1, 0, 1, 1, 0, 1])
|
|
self.assertEqual(poly2shapely(poly), polygon)
|
|
# test invalid
|
|
poly = [0, 0, 1]
|
|
with self.assertRaises(AssertionError):
|
|
poly2shapely(poly)
|
|
poly = [0, 0, 1, 0, 1, 1, 0, 1, 1]
|
|
with self.assertRaises(AssertionError):
|
|
poly2shapely(poly)
|
|
|
|
def test_polys2shapely(self):
|
|
polygons = [
|
|
Polygon([[0, 0], [1, 0], [1, 1], [0, 1]]),
|
|
Polygon([[1, 0], [1, 1], [0, 1], [0, 0]])
|
|
]
|
|
# test np.array
|
|
polys = np.array([[0, 0, 1, 0, 1, 1, 0, 1], [1, 0, 1, 1, 0, 1, 0, 0]])
|
|
self.assertEqual(polys2shapely(polys), polygons)
|
|
# test list
|
|
polys = [[0, 0, 1, 0, 1, 1, 0, 1], [1, 0, 1, 1, 0, 1, 0, 0]]
|
|
self.assertEqual(polys2shapely(polys), polygons)
|
|
# test tensor
|
|
polys = torch.Tensor([[0, 0, 1, 0, 1, 1, 0, 1],
|
|
[1, 0, 1, 1, 0, 1, 0, 0]])
|
|
self.assertEqual(polys2shapely(polys), polygons)
|
|
# test invalid
|
|
polys = [0, 0, 1]
|
|
with self.assertRaises(AssertionError):
|
|
polys2shapely(polys)
|
|
polys = [0, 0, 1, 0, 1, 1, 0, 1, 1]
|
|
with self.assertRaises(AssertionError):
|
|
polys2shapely(polys)
|
|
|
|
def test_poly_make_valid(self):
|
|
poly = Polygon([[0, 0], [1, 1], [1, 0], [0, 1]])
|
|
self.assertFalse(poly.is_valid)
|
|
poly = poly_make_valid(poly)
|
|
self.assertTrue(poly.is_valid)
|
|
# invalid input
|
|
with self.assertRaises(AssertionError):
|
|
poly_make_valid([0, 0, 1, 1, 1, 0, 0, 1])
|
|
|
|
def test_poly_intersection(self):
|
|
|
|
# test unsupported type
|
|
with self.assertRaises(AssertionError):
|
|
poly_intersection(0, 1)
|
|
|
|
# test non-overlapping polygons
|
|
points = [0, 0, 0, 1, 1, 1, 1, 0]
|
|
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
|
|
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
|
|
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
|
|
points4 = [0.5, 0, 1.5, 0, 1.5, 1, 0.5, 1]
|
|
poly = poly2shapely(points)
|
|
poly1 = poly2shapely(points1)
|
|
poly2 = poly2shapely(points2)
|
|
poly3 = poly2shapely(points3)
|
|
poly4 = poly2shapely(points4)
|
|
|
|
area_inters = poly_intersection(poly, poly1)
|
|
self.assertEqual(area_inters, 0.)
|
|
|
|
# test overlapping polygons
|
|
area_inters = poly_intersection(poly, poly)
|
|
self.assertEqual(area_inters, 1)
|
|
area_inters = poly_intersection(poly, poly4)
|
|
self.assertEqual(area_inters, 0.5)
|
|
|
|
# test invalid polygons
|
|
self.assertEqual(poly_intersection(poly2, poly2), 0)
|
|
self.assertEqual(poly_intersection(poly3, poly3, invalid_ret=1), 1)
|
|
self.assertEqual(
|
|
poly_intersection(poly3, poly3, invalid_ret=None), 0.25)
|
|
|
|
# test poly return
|
|
_, poly = poly_intersection(poly, poly4, return_poly=True)
|
|
self.assertTrue(isinstance(poly, Polygon))
|
|
_, poly = poly_intersection(
|
|
poly3, poly3, invalid_ret=None, return_poly=True)
|
|
self.assertTrue(isinstance(poly, Polygon))
|
|
_, poly = poly_intersection(
|
|
poly2, poly3, invalid_ret=1, return_poly=True)
|
|
self.assertTrue(poly is None)
|
|
|
|
def test_poly_union(self):
|
|
|
|
# test unsupported type
|
|
with self.assertRaises(AssertionError):
|
|
poly_union(0, 1)
|
|
|
|
# test non-overlapping polygons
|
|
|
|
points = [0, 0, 0, 1, 1, 1, 1, 0]
|
|
points1 = [2, 2, 2, 3, 3, 3, 3, 2]
|
|
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
|
|
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
|
|
points4 = [0.5, 0.5, 1, 0, 1, 1, 0.5, 0.5]
|
|
poly = poly2shapely(points)
|
|
poly1 = poly2shapely(points1)
|
|
poly2 = poly2shapely(points2)
|
|
poly3 = poly2shapely(points3)
|
|
poly4 = poly2shapely(points4)
|
|
|
|
assert poly_union(poly, poly1) == 2
|
|
|
|
# test overlapping polygons
|
|
assert poly_union(poly, poly) == 1
|
|
|
|
# test invalid polygons
|
|
self.assertEqual(poly_union(poly2, poly2), 0)
|
|
self.assertEqual(poly_union(poly3, poly3, invalid_ret=1), 1)
|
|
|
|
# The return value depends on the implementation of the package
|
|
self.assertEqual(poly_union(poly3, poly3, invalid_ret=None), 0.25)
|
|
self.assertEqual(poly_union(poly2, poly3), 0.25)
|
|
self.assertEqual(poly_union(poly3, poly4), 0.5)
|
|
|
|
# test poly return
|
|
_, poly = poly_union(poly, poly1, return_poly=True)
|
|
self.assertTrue(isinstance(poly, MultiPolygon))
|
|
_, poly = poly_union(poly3, poly3, return_poly=True)
|
|
self.assertTrue(isinstance(poly, Polygon))
|
|
_, poly = poly_union(poly2, poly3, invalid_ret=0, return_poly=True)
|
|
self.assertTrue(poly is None)
|
|
|
|
def test_poly_iou(self):
|
|
# test unsupported type
|
|
with self.assertRaises(AssertionError):
|
|
poly_iou([1], [2])
|
|
|
|
points = [0, 0, 0, 1, 1, 1, 1, 0]
|
|
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
|
|
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
|
|
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
|
|
|
|
poly = poly2shapely(points)
|
|
poly1 = poly2shapely(points1)
|
|
poly2 = poly2shapely(points2)
|
|
poly3 = poly2shapely(points3)
|
|
|
|
self.assertEqual(poly_iou(poly, poly1), 0)
|
|
|
|
# test overlapping polygons
|
|
self.assertEqual(poly_iou(poly, poly), 1)
|
|
|
|
# test invalid polygons
|
|
self.assertEqual(poly_iou(poly2, poly2), 0)
|
|
self.assertEqual(poly_iou(poly3, poly3, zero_division=1), 1)
|
|
self.assertEqual(poly_iou(poly2, poly3), 0)
|
|
|
|
def test_offset_polygon(self):
|
|
# usual case
|
|
polygons = np.array([0, 0, 0, 1, 1, 1, 1, 0], dtype=np.float32)
|
|
expanded_polygon = offset_polygon(polygons, 1)
|
|
self.assertTrue(
|
|
poly2shapely(expanded_polygon).equals(
|
|
poly2shapely(
|
|
np.array(
|
|
[2, 0, 2, 1, 1, 2, 0, 2, -1, 1, -1, 0, 0, -1, 1,
|
|
-1]))))
|
|
|
|
# Overshrunk polygon doesn't exist
|
|
shrunk_polygon = offset_polygon(polygons, -10)
|
|
self.assertEqual(len(shrunk_polygon), 0)
|
|
|
|
# When polygon is shrunk into two polygons, it is regarded as invalid
|
|
# and an empty array is returned.
|
|
polygons = np.array([0, 0, 0, 3, 1, 2, 2, 3, 2, 0, 1, 1],
|
|
dtype=np.float32)
|
|
shrunk = offset_polygon(polygons, -1)
|
|
self.assertEqual(len(shrunk), 0)
|
|
|
|
def test_boundary_iou(self):
|
|
points = [0, 0, 0, 1, 1, 1, 1, 0]
|
|
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
|
|
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
|
|
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
|
|
|
|
self.assertEqual(boundary_iou(points, points1), 0)
|
|
|
|
# test overlapping boundaries
|
|
self.assertEqual(boundary_iou(points, points), 1)
|
|
|
|
# test invalid boundaries
|
|
self.assertEqual(boundary_iou(points2, points2), 0)
|
|
self.assertEqual(boundary_iou(points3, points3, zero_division=1), 1)
|
|
self.assertEqual(boundary_iou(points2, points3), 0)
|