[Fix] Add zero division handler in poly utils, remove Polygon3 (#448)

* Add check to avoid zero div in iou computation

* replace polygon3 with shapely

* remove req of Polygon3
pull/472/head
Tong Gao 2021-08-25 13:14:58 +08:00 committed by GitHub
parent 7c1bf45c63
commit 0881c2d2a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 159 additions and 57 deletions

View File

@ -33,9 +33,9 @@ def compute_recall_precision(gt_polys, pred_polys):
gt = gt_polys[gt_id]
det = pred_polys[pred_id]
inter_area, _ = eval_utils.poly_intersection(det, gt)
gt_area = gt.area()
det_area = det.area()
inter_area = eval_utils.poly_intersection(det, gt)
gt_area = gt.area
det_area = det.area
if gt_area != 0:
recall[gt_id, pred_id] = inter_area / gt_area
if det_area != 0:

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import Polygon as plg
from shapely.geometry import Polygon as plg
import mmocr.utils as utils
@ -44,8 +44,8 @@ def ignore_pred(pred_boxes, gt_ignored_index, gt_polys, precision_thr):
# if its overlap with any ignored gt > precision_thr
for ignored_box_id in gt_ignored_index:
ignored_box = gt_polys[ignored_box_id]
inter_area, _ = poly_intersection(poly, ignored_box)
area = poly.area()
inter_area = poly_intersection(poly, ignored_box)
area = poly.area
precision = 0 if area == 0 else inter_area / area
if precision > precision_thr:
pred_ignored_index.append(box_id)
@ -113,7 +113,7 @@ def box2polygon(box):
[box[0], box[1], box[2], box[1], box[2], box[3], box[0], box[3]])
point_mat = boundary.reshape([-1, 2])
return plg.Polygon(point_mat)
return plg(point_mat)
def points2polygon(points):
@ -133,53 +133,85 @@ def points2polygon(points):
assert (points.size % 2 == 0) and (points.size >= 8)
point_mat = points.reshape([-1, 2])
return plg.Polygon(point_mat)
return plg(point_mat)
def poly_intersection(poly_det, poly_gt):
def poly_intersection(poly_det, poly_gt, invalid_ret=0, return_poly=False):
"""Calculate the intersection area between two polygon.
Args:
poly_det (Polygon): A polygon predicted by detector.
poly_gt (Polygon): A gt polygon.
invalid_ret (int|float): The return value when invalid polygon exists.
return_poly (bool): Whether to return the polygon of the intersection
area.
Returns:
intersection_area (float): The intersection area between two polygons.
poly_obj (Polygon, optional): The Polygon object of the intersection
area. Set as `None` if the input is
invalid.
"""
assert isinstance(poly_det, plg.Polygon)
assert isinstance(poly_gt, plg.Polygon)
assert isinstance(poly_det, plg)
assert isinstance(poly_gt, plg)
poly_inter = poly_det & poly_gt
if len(poly_inter) == 0:
return 0, poly_inter
return poly_inter.area(), poly_inter
if poly_det.is_valid and poly_gt.is_valid:
poly_obj = poly_det.intersection(poly_gt)
if return_poly:
return poly_obj.area, poly_obj
else:
return poly_obj.area
else:
if return_poly:
return invalid_ret, None
else:
return invalid_ret
def poly_union(poly_det, poly_gt):
def poly_union(poly_det, poly_gt, invalid_ret=0, return_poly=False):
"""Calculate the union area between two polygon.
Args:
poly_det (Polygon): A polygon predicted by detector.
poly_gt (Polygon): A gt polygon.
invalid_ret (int|float): The return value when invalid polygon exists.
return_poly (bool): Whether to return the polygon of the intersection
area.
Returns:
union_area (float): The union area between two polygons.
poly_obj (Polygon, optional): The polygon object of the union
area between two polygons. Set as
`None` if the input is invalid.
poly_obj (Polygon|MultiPolygon, optional): The Polygon or MultiPolygon
object of the union of the inputs. The type
of object depends on whether they intersect
or not. Set as `None` if the input is
invalid.
"""
assert isinstance(poly_det, plg.Polygon)
assert isinstance(poly_gt, plg.Polygon)
assert isinstance(poly_det, plg)
assert isinstance(poly_gt, plg)
area_det = poly_det.area()
area_gt = poly_gt.area()
area_inters, _ = poly_intersection(poly_det, poly_gt)
return area_det + area_gt - area_inters
if poly_det.is_valid and poly_gt.is_valid:
poly_obj = poly_det.union(poly_gt)
if return_poly:
return poly_obj.area, poly_obj
else:
return poly_obj.area
else:
if return_poly:
return invalid_ret, None
else:
return invalid_ret
def boundary_iou(src, target):
def boundary_iou(src, target, zero_division=0):
"""Calculate the IOU between two boundaries.
Args:
src (list): Source boundary.
target (list): Target boundary.
zero_division (int|float): The return value when invalid
boundary exists.
Returns:
iou (float): The iou between two boundaries.
@ -189,24 +221,26 @@ def boundary_iou(src, target):
src_poly = points2polygon(src)
target_poly = points2polygon(target)
return poly_iou(src_poly, target_poly)
return poly_iou(src_poly, target_poly, zero_division=zero_division)
def poly_iou(poly_det, poly_gt):
def poly_iou(poly_det, poly_gt, zero_division=0):
"""Calculate the IOU between two polygons.
Args:
poly_det (Polygon): A polygon predicted by detector.
poly_gt (Polygon): A gt polygon.
zero_division (int|float): The return value when invalid
polygon exists.
Returns:
iou (float): The IOU between two polygons.
"""
assert isinstance(poly_det, plg.Polygon)
assert isinstance(poly_gt, plg.Polygon)
area_inters, _ = poly_intersection(poly_det, poly_gt)
return area_inters / poly_union(poly_det, poly_gt)
assert isinstance(poly_det, plg)
assert isinstance(poly_gt, plg)
area_inters = poly_intersection(poly_det, poly_gt)
area_union = poly_union(poly_det, poly_gt)
return area_inters / area_union if area_union != 0 else zero_division
def one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, recall_thr,

View File

@ -3,9 +3,9 @@ import sys
import cv2
import numpy as np
import Polygon as plg
import pyclipper
from mmcv.utils import print_log
from shapely.geometry import Polygon as plg
import mmocr.utils.check_argument as check_argument
@ -110,7 +110,7 @@ class BaseTextDetTargets:
for text_ind, poly in enumerate(text_polys):
instance = poly[0].reshape(-1, 2).astype(np.int32)
area = plg.Polygon(instance).area()
area = plg(instance).area
peri = cv2.arcLength(instance, True)
distance = min(
int(area * (1 - shrink_ratio * shrink_ratio) / (peri + 0.001) +

View File

@ -4,12 +4,12 @@ import math
import cv2
import mmcv
import numpy as np
import Polygon as plg
import torchvision.transforms as transforms
from mmdet.core import BitmapMasks, PolygonMasks
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.transforms import Resize
from PIL import Image
from shapely.geometry import Polygon as plg
import mmocr.core.evaluation.utils as eval_utils
from mmocr.utils import check_argument
@ -91,10 +91,11 @@ class RandomCropInstances:
for idx, bbox in enumerate(bboxes):
poly = eval_utils.box2polygon(bbox)
area, inters = eval_utils.poly_intersection(poly, canvas_poly)
area, inters = eval_utils.poly_intersection(
poly, canvas_poly, return_poly=True)
if area == 0:
continue
xmin, xmax, ymin, ymax = inters.boundingBox()
xmin, ymin, xmax, ymax = inters.bounds
kept_bboxes += [
np.array(
[xmin - tl[0], ymin - tl[1], xmax - tl[0], ymax - tl[1]],
@ -847,28 +848,28 @@ class RandomCropFlip:
pts = np.stack([[xmin, xmax, xmax, xmin],
[ymin, ymin, ymax, ymax]]).T.astype(np.int32)
pp = plg.Polygon(pts)
pp = plg(pts)
fail_flag = False
for polygon in polygons:
ppi = plg.Polygon(polygon[0].reshape(-1, 2))
ppiou, _ = eval_utils.poly_intersection(ppi, pp)
if np.abs(ppiou - float(ppi.area())) > self.epsilon and \
ppi = plg(polygon[0].reshape(-1, 2))
ppiou = eval_utils.poly_intersection(ppi, pp)
if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
np.abs(ppiou) > self.epsilon:
fail_flag = True
break
elif np.abs(ppiou - float(ppi.area())) < self.epsilon:
elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
polys_new.append(polygon)
else:
polys_keep.append(polygon)
for polygon in ignore_polygons:
ppi = plg.Polygon(polygon[0].reshape(-1, 2))
ppiou, _ = eval_utils.poly_intersection(ppi, pp)
if np.abs(ppiou - float(ppi.area())) > self.epsilon and \
ppi = plg(polygon[0].reshape(-1, 2))
ppiou = eval_utils.poly_intersection(ppi, pp)
if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
np.abs(ppiou) > self.epsilon:
fail_flag = True
break
elif np.abs(ppiou - float(ppi.area())) < self.epsilon:
elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
ign_polys_new.append(polygon)
else:
ign_polys_keep.append(polygon)

View File

@ -503,7 +503,7 @@ def poly_nms(polygons, threshold):
for i in range(len(index)):
B = polygons[index[i]][:-1]
iou_list[i] = boundary_iou(A, B)
iou_list[i] = boundary_iou(A, B, 1)
remove_index = np.where(iou_list > threshold)
index = np.delete(index, remove_index)

View File

@ -1,5 +1,4 @@
# These must be installed before building mmocr
numpy
Polygon3
pyclipper
torch>=1.1

View File

@ -5,7 +5,6 @@ lmdb
matplotlib
mmcv
mmdet
Polygon3
pyclipper
rapidfuzz
regex

View File

@ -4,7 +4,6 @@ lmdb
matplotlib
numba>=0.45.1
numpy
Polygon3
pyclipper
rapidfuzz
scikit-image

View File

@ -4,7 +4,6 @@ flake8
isort
# Note: used for kwarray.group_items, this may be ported to mmcv in the future.
kwarray
Polygon3
pytest
pytest-cov
pytest-runner

View File

@ -20,7 +20,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmocr
known_third_party = PIL,Polygon,cv2,imgaug,lanms,lmdb,matplotlib,mmcv,mmdet,numpy,packaging,pyclipper,pytest,rapidfuzz,scipy,shapely,skimage,titlecase,torch,torchvision,yaml
known_third_party = PIL,cv2,imgaug,lanms,lmdb,matplotlib,mmcv,mmdet,numpy,packaging,pyclipper,pytest,rapidfuzz,scipy,shapely,skimage,titlecase,torch,torchvision,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

View File

@ -61,6 +61,21 @@ def test_random_crop_instances(mock_randint, mock_sample):
assert np.allclose(np.array([[0, 0], [0, 0], [0, 0]]), crop[0])
assert np.allclose(crop[1], [0, 0, 2, 3])
# test crop_bboxes
canvas_box = np.array([2, 3, 5, 5])
bboxes = np.array([[2, 3, 4, 4], [0, 0, 1, 1], [1, 2, 4, 4],
[0, 0, 10, 10]])
kept_bboxes, kept_idx = rci.crop_bboxes(bboxes, canvas_box)
assert np.allclose(kept_bboxes,
np.array([[0, 0, 2, 1], [0, 0, 2, 1], [0, 0, 3, 2]]))
assert kept_idx == [0, 2, 3]
bboxes = np.array([[10, 10, 11, 11], [0, 0, 1, 1]])
kept_bboxes, kept_idx = rci.crop_bboxes(bboxes, canvas_box)
assert kept_bboxes.size == 0
assert kept_bboxes.shape == (0, 4)
assert len(kept_idx) == 0
# test __call__
rci = transforms.RandomCropInstances(3, instance_key='gt_kernels')
results = {}
@ -71,7 +86,6 @@ def test_random_crop_instances(mock_randint, mock_sample):
mock_sample.side_effect = [0.1]
mock_randint.side_effect = [1, 1]
output = rci(results)
print(output['img'])
target = np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]])
assert output['img_shape'] == (3, 3)

View File

@ -2,6 +2,7 @@
"""Tests the utils of evaluation."""
import numpy as np
import pytest
from shapely.geometry import MultiPolygon, Polygon
import mmocr.core.evaluation.utils as utils
@ -85,11 +86,19 @@ def test_points2polygon():
# test np.array
points = np.array([1, 2, 3, 4, 5, 6, 7, 8])
poly = utils.points2polygon(points)
assert poly.nPoints() == 4
i = 0
for coord in poly.exterior.coords[:-1]:
assert coord[0] == points[i]
assert coord[1] == points[i + 1]
i += 2
points = [1, 2, 3, 4, 5, 6, 7, 8]
poly = utils.points2polygon(points)
assert poly.nPoints() == 4
i = 0
for coord in poly.exterior.coords[:-1]:
assert coord[0] == points[i]
assert coord[1] == points[i + 1]
i += 2
def test_poly_intersection():
@ -102,16 +111,34 @@ def test_poly_intersection():
points = [0, 0, 0, 1, 1, 1, 1, 0]
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
points4 = [0.5, 0, 1.5, 0, 1.5, 1, 0.5, 1]
poly = utils.points2polygon(points)
poly1 = utils.points2polygon(points1)
poly2 = utils.points2polygon(points2)
poly3 = utils.points2polygon(points3)
poly4 = utils.points2polygon(points4)
area_inters, _ = utils.poly_intersection(poly, poly1)
area_inters = utils.poly_intersection(poly, poly1)
assert area_inters == 0
# test overlapping polygons
area_inters, _ = utils.poly_intersection(poly, poly)
area_inters = utils.poly_intersection(poly, poly)
assert area_inters == 1
area_inters = utils.poly_intersection(poly, poly4)
assert area_inters == 0.5
# test invalid polygons
assert utils.poly_intersection(poly2, poly2) == 0
assert utils.poly_intersection(poly3, poly3, invalid_ret=1) == 1
# test poly return
_, poly = utils.poly_intersection(poly, poly4, return_poly=True)
assert isinstance(poly, Polygon)
_, poly = utils.poly_intersection(poly2, poly3, return_poly=True)
assert poly is None
def test_poly_union():
@ -124,14 +151,28 @@ def test_poly_union():
points = [0, 0, 0, 1, 1, 1, 1, 0]
points1 = [2, 2, 2, 3, 3, 3, 3, 2]
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
poly = utils.points2polygon(points)
poly1 = utils.points2polygon(points1)
poly2 = utils.points2polygon(points2)
poly3 = utils.points2polygon(points3)
assert utils.poly_union(poly, poly1) == 2
# test overlapping polygons
assert utils.poly_union(poly, poly) == 1
# test invalid polygons
assert utils.poly_union(poly2, poly2) == 0
assert utils.poly_union(poly3, poly3, invalid_ret=1) == 1
# test poly return
_, poly = utils.poly_union(poly, poly1, return_poly=True)
assert isinstance(poly, MultiPolygon)
_, poly = utils.poly_union(poly2, poly3, return_poly=True)
assert poly is None
def test_poly_iou():
@ -141,25 +182,41 @@ def test_poly_iou():
points = [0, 0, 0, 1, 1, 1, 1, 0]
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
poly = utils.points2polygon(points)
poly1 = utils.points2polygon(points1)
poly2 = utils.points2polygon(points2)
poly3 = utils.points2polygon(points3)
assert utils.poly_iou(poly, poly1) == 0
# test overlapping polygons
assert utils.poly_iou(poly, poly) == 1
# test invalid polygons
assert utils.poly_iou(poly2, poly2) == 0
assert utils.poly_iou(poly3, poly3, zero_division=1) == 1
assert utils.poly_iou(poly2, poly3) == 0
def test_boundary_iou():
points = [0, 0, 0, 1, 1, 1, 1, 0]
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
assert utils.boundary_iou(points, points1) == 0
# test overlapping boundaries
assert utils.boundary_iou(points, points) == 1
# test invalid boundaries
assert utils.boundary_iou(points2, points2) == 0
assert utils.boundary_iou(points3, points3, zero_division=1) == 1
assert utils.boundary_iou(points2, points3) == 0
def test_points_center():