[Utils] Migrate core/evaluation/utils.py

pull/1178/head
wangxinyu 2022-07-04 06:37:12 +00:00 committed by gaotongxiao
parent 298ea312c0
commit ef98df8052
17 changed files with 341 additions and 577 deletions

View File

@ -22,13 +22,18 @@ mmocr/models/textdet/detectors/single_stage_text_detector.py
# It will be removed after all utils are moved to mmocr.utils
mmocr/core/evaluation/utils.py
mmocr/models/textdet/postprocessors/utils.py
# It will be removed after all models have been refactored
mmocr/utils/api_utils.py
# It will be removed after all models have been refactored
mmocr/utils/ocr.py
mmocr/core/evaluation/hmean.py
mmocr/core/evaluation/hmean_ic13.py
mmocr/core/evaluation/hmean_iou.py
mmocr/utils/evaluation_utils.py
mmocr/utils/point_utils.py
mmocr/apis/inference.py
# It will be deleted

View File

@ -8,9 +8,8 @@ from mmcv.utils import print_log
import mmocr.utils as utils
from mmocr.core.evaluation import hmean_ic13, hmean_iou
from mmocr.core.evaluation.utils import (filter_2dlist_result,
select_top_boundary)
from mmocr.core.mask import extract_boundary
from mmocr.utils import filter_2dlist_result, select_top_boundary
def output_ranklist(img_results, img_infos, out_file):

View File

@ -2,7 +2,6 @@
import numpy as np
import mmocr.utils as utils
from . import utils as eval_utils
def compute_recall_precision(gt_polys, pred_polys):
@ -33,7 +32,7 @@ def compute_recall_precision(gt_polys, pred_polys):
gt = gt_polys[gt_id]
det = pred_polys[pred_id]
inter_area = eval_utils.poly_intersection(det, gt)
inter_area = utils.poly_intersection(det, gt)
gt_area = gt.area
det_area = det.area
if gt_area != 0:
@ -111,11 +110,11 @@ def eval_hmean_ic13(det_boxes,
accum_precision = 0.
gt_points = gt + gt_ignored
gt_polys = [eval_utils.points2polygon(p) for p in gt_points]
gt_polys = [utils.poly2shapely(p) for p in gt_points]
gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))]
gt_num = len(gt_polys)
pred_polys, pred_points, pred_ignored_index = eval_utils.ignore_pred(
pred_polys, pred_points, pred_ignored_index = utils.ignore_pred(
pred, gt_ignored_index, gt_polys, precision_thr)
if pred_num > 0 and gt_num > 0:
@ -135,18 +134,18 @@ def eval_hmean_ic13(det_boxes,
or gt_id in gt_ignored_index
or pred_id in pred_ignored_index):
continue
match = eval_utils.one2one_match_ic13(
gt_id, pred_id, recall_mat, precision_mat, recall_thr,
precision_thr)
match = utils.one2one_match_ic13(gt_id, pred_id,
recall_mat, precision_mat,
recall_thr, precision_thr)
if match:
gt_point = np.array(gt_points[gt_id])
det_point = np.array(pred_points[pred_id])
norm_dist = eval_utils.box_center_distance(
norm_dist = utils.box_center_distance(
det_point, gt_point)
norm_dist /= eval_utils.box_diag(
det_point) + eval_utils.box_diag(gt_point)
norm_dist /= utils.box_diag(
det_point) + utils.box_diag(gt_point)
norm_dist *= 2.0
if norm_dist < center_dist_thr:
@ -159,7 +158,7 @@ def eval_hmean_ic13(det_boxes,
for gt_id in range(gt_num):
if gt_id in gt_ignored_index:
continue
match, match_det_set = eval_utils.one2many_match_ic13(
match, match_det_set = utils.one2many_match_ic13(
gt_id, recall_mat, precision_mat, recall_thr,
precision_thr, gt_hit, pred_hit, pred_ignored_index)
@ -177,7 +176,7 @@ def eval_hmean_ic13(det_boxes,
if pred_id in pred_ignored_index:
continue
match, match_gt_set = eval_utils.many2one_match_ic13(
match, match_gt_set = utils.many2one_match_ic13(
pred_id, recall_mat, precision_mat, recall_thr,
precision_thr, gt_hit, pred_hit, gt_ignored_index)
@ -191,8 +190,8 @@ def eval_hmean_ic13(det_boxes,
gt_care_number = gt_num - ignored_num
pred_care_number = pred_num - len(pred_ignored_index)
r, p, h = eval_utils.compute_hmean(accum_recall, accum_precision,
gt_care_number, pred_care_number)
r, p, h = utils.compute_hmean(accum_recall, accum_precision,
gt_care_number, pred_care_number)
img_results.append({'recall': r, 'precision': p, 'hmean': h})
@ -201,8 +200,10 @@ def eval_hmean_ic13(det_boxes,
dataset_hit_recall += accum_recall
dataset_hit_prec += accum_precision
total_r, total_p, total_h = eval_utils.compute_hmean(
dataset_hit_recall, dataset_hit_prec, dataset_gt_num, dataset_pred_num)
total_r, total_p, total_h = utils.compute_hmean(dataset_hit_recall,
dataset_hit_prec,
dataset_gt_num,
dataset_pred_num)
dataset_results = {
'num_gts': dataset_gt_num,

View File

@ -2,7 +2,6 @@
import numpy as np
import mmocr.utils as utils
from . import utils as eval_utils
def eval_hmean_iou(pred_boxes,
@ -57,10 +56,10 @@ def eval_hmean_iou(pred_boxes,
# get gt polygons.
gt_all = gt + gt_ignored
gt_polys = [eval_utils.points2polygon(p) for p in gt_all]
gt_polys = [utils.poly2shapely(p) for p in gt_all]
gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))]
gt_num = len(gt_polys)
pred_polys, _, pred_ignored_index = eval_utils.ignore_pred(
pred_polys, _, pred_ignored_index = utils.ignore_pred(
pred, gt_ignored_index, gt_polys, precision_thr)
# match.
@ -76,8 +75,7 @@ def eval_hmean_iou(pred_boxes,
gt_pol = gt_polys[gt_id]
det_pol = pred_polys[pred_id]
iou_mat[gt_id,
pred_id] = eval_utils.poly_iou(det_pol, gt_pol)
iou_mat[gt_id, pred_id] = utils.poly_iou(det_pol, gt_pol)
for gt_id in range(gt_num):
for pred_id in range(pred_num):
@ -93,8 +91,8 @@ def eval_hmean_iou(pred_boxes,
gt_care_number = gt_num - gt_ignored_num
pred_care_number = pred_num - len(pred_ignored_index)
r, p, h = eval_utils.compute_hmean(hit_num, hit_num, gt_care_number,
pred_care_number)
r, p, h = utils.compute_hmean(hit_num, hit_num, gt_care_number,
pred_care_number)
img_results.append({'recall': r, 'precision': p, 'hmean': h})
@ -102,7 +100,7 @@ def eval_hmean_iou(pred_boxes,
dataset_gt_num += gt_care_number
dataset_pred_num += pred_care_number
dataset_r, dataset_p, dataset_h = eval_utils.compute_hmean(
dataset_r, dataset_p, dataset_h = utils.compute_hmean(
dataset_hit_num, dataset_hit_num, dataset_gt_num, dataset_pred_num)
dataset_results = {

View File

@ -1,504 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from shapely.geometry import Polygon as plg
import mmocr.utils as utils
def compute_hmean(accum_hit_recall, accum_hit_prec, gt_num, pred_num):
"""Compute hmean given hit number, ground truth number and prediction
number.
Args:
accum_hit_recall (int|float): Accumulated hits for computing recall.
accum_hit_prec (int|float): Accumulated hits for computing precision.
gt_num (int): Ground truth number.
pred_num (int): Prediction number.
Returns:
recall (float): The recall value.
precision (float): The precision value.
hmean (float): The hmean value.
"""
assert isinstance(accum_hit_recall, (float, int))
assert isinstance(accum_hit_prec, (float, int))
assert isinstance(gt_num, int)
assert isinstance(pred_num, int)
assert accum_hit_recall >= 0.0
assert accum_hit_prec >= 0.0
assert gt_num >= 0.0
assert pred_num >= 0.0
if gt_num == 0:
recall = 1.0
precision = 0.0 if pred_num > 0 else 1.0
else:
recall = float(accum_hit_recall) / gt_num
precision = 0.0 if pred_num == 0 else float(accum_hit_prec) / pred_num
denom = recall + precision
hmean = 0.0 if denom == 0 else (2.0 * precision * recall / denom)
return recall, precision, hmean
def box2polygon(box):
# TODO This has been moved to mmocr.utils. Delete this later
"""Convert box to polygon.
Args:
box (ndarray or list): A ndarray or a list of shape (4)
that indicates 2 points.
Returns:
polygon (Polygon): A polygon object.
"""
if isinstance(box, list):
box = np.array(box)
assert isinstance(box, np.ndarray)
assert box.size == 4
boundary = np.array(
[box[0], box[1], box[2], box[1], box[2], box[3], box[0], box[3]])
point_mat = boundary.reshape([-1, 2])
return plg(point_mat)
def points2polygon(points):
# TODO This has been moved to mmocr.utils. Delete this later
"""Convert k points to 1 polygon.
Args:
points (ndarray or list): A ndarray or a list of shape (2k)
that indicates k points.
Returns:
polygon (Polygon): A polygon object.
"""
if isinstance(points, list):
points = np.array(points)
assert isinstance(points, np.ndarray)
assert (points.size % 2 == 0) and (points.size >= 8)
point_mat = points.reshape([-1, 2])
return plg(point_mat)
def poly_make_valid(poly):
# TODO This has been moved to mmocr.utils. Delete this later
"""Convert a potentially invalid polygon to a valid one by eliminating
self-crossing or self-touching parts.
Args:
poly (Polygon): A polygon needed to be converted.
Returns:
A valid polygon.
"""
return poly if poly.is_valid else poly.buffer(0)
def poly_intersection(poly_det, poly_gt, invalid_ret=None, return_poly=False):
# TODO This has been moved to mmocr.utils. Delete this later
"""Calculate the intersection area between two polygon.
Args:
poly_det (Polygon): A polygon predicted by detector.
poly_gt (Polygon): A gt polygon.
invalid_ret (None|float|int): The return value when the invalid polygon
exists. If it is not specified, the function allows the computation
to proceed with invalid polygons by cleaning the their
self-touching or self-crossing parts.
return_poly (bool): Whether to return the polygon of the intersection
area.
Returns:
intersection_area (float): The intersection area between two polygons.
poly_obj (Polygon, optional): The Polygon object of the intersection
area. Set as `None` if the input is invalid.
"""
assert isinstance(poly_det, plg)
assert isinstance(poly_gt, plg)
assert invalid_ret is None or isinstance(invalid_ret, float) or \
isinstance(invalid_ret, int)
if invalid_ret is None:
poly_det = poly_make_valid(poly_det)
poly_gt = poly_make_valid(poly_gt)
poly_obj = None
area = invalid_ret
if poly_det.is_valid and poly_gt.is_valid:
poly_obj = poly_det.intersection(poly_gt)
area = poly_obj.area
return (area, poly_obj) if return_poly else area
def poly_union(poly_det, poly_gt, invalid_ret=None, return_poly=False):
# TODO This has been moved to mmocr.utils. Delete this later
"""Calculate the union area between two polygon.
Args:
poly_det (Polygon): A polygon predicted by detector.
poly_gt (Polygon): A gt polygon.
invalid_ret (None|float|int): The return value when the invalid polygon
exists. If it is not specified, the function allows the computation
to proceed with invalid polygons by cleaning the their
self-touching or self-crossing parts.
return_poly (bool): Whether to return the polygon of the intersection
area.
Returns:
union_area (float): The union area between two polygons.
poly_obj (Polygon|MultiPolygon, optional): The Polygon or MultiPolygon
object of the union of the inputs. The type of object depends on
whether they intersect or not. Set as `None` if the input is
invalid.
"""
assert isinstance(poly_det, plg)
assert isinstance(poly_gt, plg)
assert invalid_ret is None or isinstance(invalid_ret, float) or \
isinstance(invalid_ret, int)
if invalid_ret is None:
poly_det = poly_make_valid(poly_det)
poly_gt = poly_make_valid(poly_gt)
poly_obj = None
area = invalid_ret
if poly_det.is_valid and poly_gt.is_valid:
poly_obj = poly_det.union(poly_gt)
area = poly_obj.area
return (area, poly_obj) if return_poly else area
def boundary_iou(src, target, zero_division=0):
"""Calculate the IOU between two boundaries.
Args:
src (list): Source boundary.
target (list): Target boundary.
zero_division (int|float): The return value when invalid
boundary exists.
Returns:
iou (float): The iou between two boundaries.
"""
assert utils.valid_boundary(src, False)
assert utils.valid_boundary(target, False)
src_poly = points2polygon(src)
target_poly = points2polygon(target)
return poly_iou(src_poly, target_poly, zero_division=zero_division)
def poly_iou(poly_det, poly_gt, zero_division=0):
# TODO This has been moved to mmocr.utils. Delete this later
"""Calculate the IOU between two polygons.
Args:
poly_det (Polygon): A polygon predicted by detector.
poly_gt (Polygon): A gt polygon.
zero_division (int|float): The return value when invalid
polygon exists.
Returns:
iou (float): The IOU between two polygons.
"""
assert isinstance(poly_det, plg)
assert isinstance(poly_gt, plg)
area_inters = poly_intersection(poly_det, poly_gt)
area_union = poly_union(poly_det, poly_gt)
return area_inters / area_union if area_union != 0 else zero_division
def one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, recall_thr,
precision_thr):
"""One-to-One match gt and det with icdar2013 standards.
Args:
gt_id (int): The ground truth id index.
det_id (int): The detection result id index.
recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the recall ratio of gt i to det j.
precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the precision ratio of gt i to det j.
recall_thr (float): The recall threshold.
precision_thr (float): The precision threshold.
Returns:
True|False: Whether the gt and det are matched.
"""
assert isinstance(gt_id, int)
assert isinstance(det_id, int)
assert isinstance(recall_mat, np.ndarray)
assert isinstance(precision_mat, np.ndarray)
assert 0 <= recall_thr <= 1
assert 0 <= precision_thr <= 1
cont = 0
for i in range(recall_mat.shape[1]):
if recall_mat[gt_id,
i] > recall_thr and precision_mat[gt_id,
i] > precision_thr:
cont += 1
if cont != 1:
return False
cont = 0
for i in range(recall_mat.shape[0]):
if recall_mat[i, det_id] > recall_thr and precision_mat[
i, det_id] > precision_thr:
cont += 1
if cont != 1:
return False
if recall_mat[gt_id, det_id] > recall_thr and precision_mat[
gt_id, det_id] > precision_thr:
return True
return False
def one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr,
precision_thr, gt_match_flag, det_match_flag,
det_ignored_index):
"""One-to-Many match gt and detections with icdar2013 standards.
Args:
gt_id (int): gt index.
recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the recall ratio of gt i to det j.
precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the precision ratio of gt i to det j.
recall_thr (float): The recall threshold.
precision_thr (float): The precision threshold.
gt_match_flag (ndarray): An array indicates each gt matched already.
det_match_flag (ndarray): An array indicates each box has been
matched already or not.
det_ignored_index (list): A list indicates each detection box can be
ignored or not.
Returns:
tuple (True|False, list): The first indicates the gt is matched or not;
the second is the matched detection ids.
"""
assert isinstance(gt_id, int)
assert isinstance(recall_mat, np.ndarray)
assert isinstance(precision_mat, np.ndarray)
assert 0 <= recall_thr <= 1
assert 0 <= precision_thr <= 1
assert isinstance(gt_match_flag, list)
assert isinstance(det_match_flag, list)
assert isinstance(det_ignored_index, list)
many_sum = 0.
det_ids = []
for det_id in range(recall_mat.shape[1]):
if gt_match_flag[gt_id] == 0 and det_match_flag[
det_id] == 0 and det_id not in det_ignored_index:
if precision_mat[gt_id, det_id] >= precision_thr:
many_sum += recall_mat[gt_id, det_id]
det_ids.append(det_id)
if many_sum >= recall_thr:
return True, det_ids
return False, []
def many2one_match_ic13(det_id, recall_mat, precision_mat, recall_thr,
precision_thr, gt_match_flag, det_match_flag,
gt_ignored_index):
"""Many-to-One match gt and detections with icdar2013 standards.
Args:
det_id (int): Detection index.
recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the recall ratio of gt i to det j.
precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the precision ratio of gt i to det j.
recall_thr (float): The recall threshold.
precision_thr (float): The precision threshold.
gt_match_flag (ndarray): An array indicates each gt has been matched
already.
det_match_flag (ndarray): An array indicates each detection box has
been matched already or not.
gt_ignored_index (list): A list indicates each gt box can be ignored
or not.
Returns:
tuple (True|False, list): The first indicates the detection is matched
or not; the second is the matched gt ids.
"""
assert isinstance(det_id, int)
assert isinstance(recall_mat, np.ndarray)
assert isinstance(precision_mat, np.ndarray)
assert 0 <= recall_thr <= 1
assert 0 <= precision_thr <= 1
assert isinstance(gt_match_flag, list)
assert isinstance(det_match_flag, list)
assert isinstance(gt_ignored_index, list)
many_sum = 0.
gt_ids = []
for gt_id in range(recall_mat.shape[0]):
if gt_match_flag[gt_id] == 0 and det_match_flag[
det_id] == 0 and gt_id not in gt_ignored_index:
if recall_mat[gt_id, det_id] >= recall_thr:
many_sum += precision_mat[gt_id, det_id]
gt_ids.append(gt_id)
if many_sum >= precision_thr:
return True, gt_ids
return False, []
def points_center(points):
assert isinstance(points, np.ndarray)
assert points.size % 2 == 0
points = points.reshape([-1, 2])
return np.mean(points, axis=0)
def point_distance(p1, p2):
assert isinstance(p1, np.ndarray)
assert isinstance(p2, np.ndarray)
assert p1.size == 2
assert p2.size == 2
dist = np.square(p2 - p1)
dist = np.sum(dist)
dist = np.sqrt(dist)
return dist
def box_center_distance(b1, b2):
assert isinstance(b1, np.ndarray)
assert isinstance(b2, np.ndarray)
return point_distance(points_center(b1), points_center(b2))
def box_diag(box):
assert isinstance(box, np.ndarray)
assert box.size == 8
return point_distance(box[0:2], box[4:6])
def filter_2dlist_result(results, scores, score_thr):
"""Find out detected results whose score > score_thr.
Args:
results (list[list[float]]): The result list.
score (list): The score list.
score_thr (float): The score threshold.
Returns:
valid_results (list[list[float]]): The valid results.
valid_score (list[float]): The scores which correspond to the valid
results.
"""
assert isinstance(results, list)
assert len(results) == len(scores)
assert isinstance(score_thr, float)
assert 0 <= score_thr <= 1
inds = np.array(scores) > score_thr
valid_results = [results[idx] for idx in np.where(inds)[0].tolist()]
valid_scores = [scores[idx] for idx in np.where(inds)[0].tolist()]
return valid_results, valid_scores
def filter_result(results, scores, score_thr):
"""Find out detected results whose score > score_thr.
Args:
results (ndarray): The results matrix of shape (n, k).
score (ndarray): The score vector of shape (n,).
score_thr (float): The score threshold.
Returns:
valid_results (ndarray): The valid results of shape (m,k) with m<=n.
valid_score (ndarray): The scores which correspond to the
valid results.
"""
assert results.ndim == 2
assert scores.shape[0] == results.shape[0]
assert isinstance(score_thr, float)
assert 0 <= score_thr <= 1
inds = scores > score_thr
valid_results = results[inds, :]
valid_scores = scores[inds]
return valid_results, valid_scores
def select_top_boundary(boundaries_list, scores_list, score_thr):
"""Select poly boundaries with scores >= score_thr.
Args:
boundaries_list (list[list[list[float]]]): List of boundaries.
The 1st, 2nd, and 3rd indices are for image, text and
vertice, respectively.
scores_list (list(list[float])): List of lists of scores.
score_thr (float): The score threshold to filter out bboxes.
Returns:
selected_bboxes (list[list[list[float]]]): List of boundaries.
The 1st, 2nd, and 3rd indices are for image, text and vertice,
respectively.
"""
assert isinstance(boundaries_list, list)
assert isinstance(scores_list, list)
assert isinstance(score_thr, float)
assert len(boundaries_list) == len(scores_list)
assert 0 <= score_thr <= 1
selected_boundaries = []
for boundary, scores in zip(boundaries_list, scores_list):
if len(scores) > 0:
assert len(scores) == len(boundary)
inds = [
iter for iter in range(len(scores))
if scores[iter] >= score_thr
]
selected_boundaries.append([boundary[i] for i in inds])
else:
selected_boundaries.append(boundary)
return selected_boundaries
def select_bboxes_via_score(bboxes_list, scores_list, score_thr):
"""Select bboxes with scores >= score_thr.
Args:
bboxes_list (list[ndarray]): List of bboxes. Each element is ndarray of
shape (n,8)
scores_list (list(list[float])): List of lists of scores.
score_thr (float): The score threshold to filter out bboxes.
Returns:
selected_bboxes (list[ndarray]): List of bboxes. Each element is
ndarray of shape (m,8) with m<=n.
"""
assert isinstance(bboxes_list, list)
assert isinstance(scores_list, list)
assert isinstance(score_thr, float)
assert len(bboxes_list) == len(scores_list)
assert 0 <= score_thr <= 1
selected_bboxes = []
for bboxes, scores in zip(bboxes_list, scores_list):
if len(scores) > 0:
assert len(scores) == bboxes.shape[0]
inds = [
iter for iter in range(len(scores))
if scores[iter] >= score_thr
]
selected_bboxes.append(bboxes[inds, :])
else:
selected_bboxes.append(bboxes)
return selected_bboxes

View File

@ -12,10 +12,9 @@ from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import avoid_cache_randomness, cache_randomness
from shapely.geometry import Polygon as plg
import mmocr.core.evaluation.utils as eval_utils
from mmocr.registry import TRANSFORMS
from mmocr.utils import (bbox2poly, crop_polygon, is_poly_inside_rect,
poly2bbox, rescale_polygon)
poly2bbox, poly_intersection, rescale_polygon)
from .wrappers import ImgAug
@ -623,8 +622,7 @@ class TextDetRandomCropFlip(BaseTransform):
success_flag = True
for poly_idx, polygon in enumerate(polygons):
ppi = plg(polygon.reshape(-1, 2))
# TODO Move this eval_utils to point_utils?
ppiou = eval_utils.poly_intersection(ppi, pp)
ppiou = poly_intersection(ppi, pp)
if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
np.abs(ppiou) > self.epsilon:
success_flag = False

View File

@ -9,9 +9,9 @@ from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import maximum_bipartite_matching
from shapely.geometry import Polygon
from mmocr.core.evaluation.utils import compute_hmean
from mmocr.registry import METRICS
from mmocr.utils import poly_intersection, poly_iou, polys2shapely
from mmocr.utils import (compute_hmean, poly_intersection, poly_iou,
polys2shapely)
@METRICS.register_module()

View File

@ -5,8 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
from mmocr.core import TextDetDataSample
from mmocr.core.evaluation.utils import boundary_iou
from mmocr.utils import is_type_list, rescale_polygons
from mmocr.utils import boundary_iou, is_type_list, rescale_polygons
class BasePostprocessor:

View File

@ -2,22 +2,27 @@
from mmcv.utils import Registry, build_from_cfg
from .api_utils import disable_text_recog_aug_test
from .bbox_utils import bbox2poly, rescale_bboxes
from .bbox_utils import (bbox2poly, box_center_distance, box_diag,
rescale_bboxes)
from .box_util import (bezier_to_polygon, is_on_same_line, sort_points,
stitch_boxes_into_lines)
from .check_argument import (equal_len, is_2dlist, is_3dlist, is_none_or_type,
is_type_list, valid_boundary)
from .collect_env import collect_env
from .data_convert_util import dump_ocr_data, recog_anno_to_imginfo
from .evaluation_utils import (compute_hmean, filter_2dlist_result,
many2one_match_ic13, one2one_match_ic13,
select_top_boundary)
from .fileio import list_from_file, list_to_file
from .lmdb_util import recog2lmdb
from .logger import get_root_logger
from .model import revert_sync_batchnorm
from .point_utils import dist_points2line
from .polygon_utils import (crop_polygon, is_poly_inside_rect, offset_polygon,
poly2bbox, poly2shapely, poly_intersection,
poly_iou, poly_make_valid, poly_union,
polys2shapely, rescale_polygon, rescale_polygons)
from .point_utils import dist_points2line, point_distance, points_center
from .polygon_utils import (boundary_iou, crop_polygon, is_poly_inside_rect,
offset_polygon, poly2bbox, poly2shapely,
poly_intersection, poly_iou, poly_make_valid,
poly_union, polys2shapely, rescale_polygon,
rescale_polygons)
from .setup_env import register_all_modules
from .string_util import StringStrip
@ -31,5 +36,8 @@ __all__ = [
'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_inside_rect',
'poly2bbox', 'poly_intersection', 'poly_iou', 'poly_make_valid',
'poly_union', 'poly2shapely', 'polys2shapely', 'register_all_modules',
'dist_points2line', 'offset_polygon', 'disable_text_recog_aug_test'
'dist_points2line', 'offset_polygon', 'disable_text_recog_aug_test',
'box_center_distance', 'box_diag', 'compute_hmean', 'filter_2dlist_result',
'many2one_match_ic13', 'one2one_match_ic13', 'select_top_boundary',
'point_distance', 'points_center', 'boundary_iou'
]

View File

@ -4,6 +4,8 @@ from typing import Tuple
import numpy as np
from numpy.typing import ArrayLike
from mmocr.utils.point_utils import point_distance, points_center
def rescale_bbox(bbox: np.ndarray,
scale_factor: Tuple[int, int],
@ -74,3 +76,18 @@ def bbox2poly(bbox: ArrayLike) -> np.array:
return np.array([
bbox[0], bbox[1], bbox[2], bbox[1], bbox[2], bbox[3], bbox[0], bbox[3]
])
def box_center_distance(b1, b2):
# TODO typehints & docstring & test
assert isinstance(b1, np.ndarray)
assert isinstance(b2, np.ndarray)
return point_distance(points_center(b1), points_center(b2))
def box_diag(box):
# TODO typehints & docstring & test
assert isinstance(box, np.ndarray)
assert box.size == 8
return point_distance(box[0:2], box[4:6])

View File

@ -0,0 +1,199 @@
# Copyright (c) OpenMMLab. All rights reserved.
# TODO check whether to keep these utils after refactoring ic13 metrics
import numpy as np
def compute_hmean(accum_hit_recall, accum_hit_prec, gt_num, pred_num):
# TODO Add typehints & Test
"""Compute hmean given hit number, ground truth number and prediction
number.
Args:
accum_hit_recall (int|float): Accumulated hits for computing recall.
accum_hit_prec (int|float): Accumulated hits for computing precision.
gt_num (int): Ground truth number.
pred_num (int): Prediction number.
Returns:
recall (float): The recall value.
precision (float): The precision value.
hmean (float): The hmean value.
"""
assert isinstance(accum_hit_recall, (float, int))
assert isinstance(accum_hit_prec, (float, int))
assert isinstance(gt_num, int)
assert isinstance(pred_num, int)
assert accum_hit_recall >= 0.0
assert accum_hit_prec >= 0.0
assert gt_num >= 0.0
assert pred_num >= 0.0
if gt_num == 0:
recall = 1.0
precision = 0.0 if pred_num > 0 else 1.0
else:
recall = float(accum_hit_recall) / gt_num
precision = 0.0 if pred_num == 0 else float(accum_hit_prec) / pred_num
denom = recall + precision
hmean = 0.0 if denom == 0 else (2.0 * precision * recall / denom)
return recall, precision, hmean
def one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, recall_thr,
precision_thr):
# TODO Add typehints & Test
"""One-to-One match gt and det with icdar2013 standards.
Args:
gt_id (int): The ground truth id index.
det_id (int): The detection result id index.
recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the recall ratio of gt i to det j.
precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the precision ratio of gt i to det j.
recall_thr (float): The recall threshold.
precision_thr (float): The precision threshold.
Returns:
True|False: Whether the gt and det are matched.
"""
assert isinstance(gt_id, int)
assert isinstance(det_id, int)
assert isinstance(recall_mat, np.ndarray)
assert isinstance(precision_mat, np.ndarray)
assert 0 <= recall_thr <= 1
assert 0 <= precision_thr <= 1
cont = 0
for i in range(recall_mat.shape[1]):
if recall_mat[gt_id,
i] > recall_thr and precision_mat[gt_id,
i] > precision_thr:
cont += 1
if cont != 1:
return False
cont = 0
for i in range(recall_mat.shape[0]):
if recall_mat[i, det_id] > recall_thr and precision_mat[
i, det_id] > precision_thr:
cont += 1
if cont != 1:
return False
if recall_mat[gt_id, det_id] > recall_thr and precision_mat[
gt_id, det_id] > precision_thr:
return True
return False
def many2one_match_ic13(det_id, recall_mat, precision_mat, recall_thr,
precision_thr, gt_match_flag, det_match_flag,
gt_ignored_index):
# TODO Add typehints & Test
"""Many-to-One match gt and detections with icdar2013 standards.
Args:
det_id (int): Detection index.
recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the recall ratio of gt i to det j.
precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the precision ratio of gt i to det j.
recall_thr (float): The recall threshold.
precision_thr (float): The precision threshold.
gt_match_flag (ndarray): An array indicates each gt has been matched
already.
det_match_flag (ndarray): An array indicates each detection box has
been matched already or not.
gt_ignored_index (list): A list indicates each gt box can be ignored
or not.
Returns:
tuple (True|False, list): The first indicates the detection is matched
or not; the second is the matched gt ids.
"""
assert isinstance(det_id, int)
assert isinstance(recall_mat, np.ndarray)
assert isinstance(precision_mat, np.ndarray)
assert 0 <= recall_thr <= 1
assert 0 <= precision_thr <= 1
assert isinstance(gt_match_flag, list)
assert isinstance(det_match_flag, list)
assert isinstance(gt_ignored_index, list)
many_sum = 0.
gt_ids = []
for gt_id in range(recall_mat.shape[0]):
if gt_match_flag[gt_id] == 0 and det_match_flag[
det_id] == 0 and gt_id not in gt_ignored_index:
if recall_mat[gt_id, det_id] >= recall_thr:
many_sum += precision_mat[gt_id, det_id]
gt_ids.append(gt_id)
if many_sum >= precision_thr:
return True, gt_ids
return False, []
def filter_2dlist_result(results, scores, score_thr):
# TODO Add typehints & Test
"""Find out detected results whose score > score_thr.
Args:
results (list[list[float]]): The result list.
score (list): The score list.
score_thr (float): The score threshold.
Returns:
valid_results (list[list[float]]): The valid results.
valid_score (list[float]): The scores which correspond to the valid
results.
"""
assert isinstance(results, list)
assert len(results) == len(scores)
assert isinstance(score_thr, float)
assert 0 <= score_thr <= 1
inds = np.array(scores) > score_thr
valid_results = [results[idx] for idx in np.where(inds)[0].tolist()]
valid_scores = [scores[idx] for idx in np.where(inds)[0].tolist()]
return valid_results, valid_scores
def select_top_boundary(boundaries_list, scores_list, score_thr):
# TODO Add typehints & Test
"""Select poly boundaries with scores >= score_thr.
Args:
boundaries_list (list[list[list[float]]]): List of boundaries.
The 1st, 2nd, and 3rd indices are for image, text and
vertice, respectively.
scores_list (list(list[float])): List of lists of scores.
score_thr (float): The score threshold to filter out bboxes.
Returns:
selected_bboxes (list[list[list[float]]]): List of boundaries.
The 1st, 2nd, and 3rd indices are for image, text and vertice,
respectively.
"""
assert isinstance(boundaries_list, list)
assert isinstance(scores_list, list)
assert isinstance(score_thr, float)
assert len(boundaries_list) == len(scores_list)
assert 0 <= score_thr <= 1
selected_boundaries = []
for boundary, scores in zip(boundaries_list, scores_list):
if len(scores) > 0:
assert len(scores) == len(boundary)
inds = [
iter for iter in range(len(scores))
if scores[iter] >= score_thr
]
selected_boundaries.append([boundary[i] for i in inds])
else:
selected_boundaries.append(boundary)
return selected_boundaries

View File

@ -34,3 +34,26 @@ def dist_points2line(xs, ys, pt1, pt2):
# set result to minimum edge if C<pi/2
result[neg_cos_c < 0] = np.sqrt(np.fmin(a_square, b_square))[neg_cos_c < 0]
return result
def points_center(points):
# TODO typehints & docstring
assert isinstance(points, np.ndarray)
assert points.size % 2 == 0
points = points.reshape([-1, 2])
return np.mean(points, axis=0)
def point_distance(p1, p2):
# TODO typehints & docstring
assert isinstance(p1, np.ndarray)
assert isinstance(p2, np.ndarray)
assert p1.size == 2
assert p2.size == 2
dist = np.square(p2 - p1)
dist = np.sum(dist)
dist = np.sqrt(dist)
return dist

View File

@ -1,12 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import pyclipper
from numpy.typing import ArrayLike
from shapely.geometry import MultiPolygon, Polygon
from mmocr.utils import bbox2poly
from mmocr.utils import bbox2poly, valid_boundary
def rescale_polygon(polygon: ArrayLike,
@ -314,3 +314,25 @@ def offset_polygon(poly: ArrayLike, distance: float) -> ArrayLike:
# But when the resulting polygon is invalid, return the empty array
# as it is
return result if len(result) == 0 else result[0].flatten()
def boundary_iou(src: List,
target: List,
zero_division: Union[int, float] = 0) -> float:
"""Calculate the IOU between two boundaries.
Args:
src (list): Source boundary.
target (list): Target boundary.
zero_division (int or float): The return value when invalid
boundary exists.
Returns:
float: The iou between two boundaries.
"""
assert valid_boundary(src, False)
assert valid_boundary(target, False)
src_poly = poly2shapely(src)
target_poly = poly2shapely(target)
return poly_iou(src_poly, target_poly, zero_division=zero_division)

View File

@ -25,23 +25,6 @@ def test_compute_hmean():
assert hmean == 0
def test_boundary_iou():
points = [0, 0, 0, 1, 1, 1, 1, 0]
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
assert utils.boundary_iou(points, points1) == 0
# test overlapping boundaries
assert utils.boundary_iou(points, points) == 1
# test invalid boundaries
assert utils.boundary_iou(points2, points2) == 0
assert utils.boundary_iou(points3, points3, zero_division=1) == 1
assert utils.boundary_iou(points2, points3) == 0
def test_points_center():
# test unsupported type

View File

@ -5,7 +5,7 @@ import math
import pytest
import mmocr.core.evaluation.hmean_ic13 as hmean_ic13
import mmocr.core.evaluation.utils as utils
import mmocr.utils as utils
def test_compute_recall_precision():
@ -21,8 +21,8 @@ def test_compute_recall_precision():
box2 = [0, 0, 10, 0, 10, 1, 0, 1]
gt_polys = [utils.points2polygon(box1)]
det_polys = [utils.points2polygon(box2)]
gt_polys = [utils.poly2shapely(box1)]
det_polys = [utils.poly2shapely(box2)]
recall, precision = hmean_ic13.compute_recall_precision(
gt_polys, det_polys)
assert recall == 1

View File

@ -6,8 +6,8 @@ import torch
from nose_parameterized import parameterized
from mmocr.core import TextDetDataSample
from mmocr.core.evaluation.utils import points2polygon, poly_iou
from mmocr.models.textdet.postprocessors import PANPostprocessor
from mmocr.utils import poly2shapely, poly_iou
class TestPANPostprocessor(unittest.TestCase):
@ -44,8 +44,8 @@ class TestPANPostprocessor(unittest.TestCase):
postprocessor = PANPostprocessor(text_repr_type='quad')
result = postprocessor._points2boundary(points)
pred_poly = points2polygon(result)
target_poly = points2polygon([2, 2, 0, 2, 0, 0, 2, 0])
pred_poly = poly2shapely(result)
target_poly = poly2shapely([2, 2, 0, 2, 0, 0, 2, 0])
self.assertEqual(poly_iou(pred_poly, target_poly), 1)
result = postprocessor._points2boundary(points, min_width=3)
@ -54,6 +54,6 @@ class TestPANPostprocessor(unittest.TestCase):
# test poly
postprocessor = PANPostprocessor(text_repr_type='poly')
result = postprocessor._points2boundary(points)
pred_poly = points2polygon(result)
target_poly = points2polygon([0, 0, 0, 2, 2, 2, 2, 0])
pred_poly = poly2shapely(result)
target_poly = poly2shapely([0, 0, 0, 2, 2, 2, 2, 0])
assert poly_iou(pred_poly, target_poly) == 1

View File

@ -5,10 +5,10 @@ import numpy as np
import torch
from shapely.geometry import MultiPolygon, Polygon
from mmocr.utils import (crop_polygon, offset_polygon, poly2bbox, poly2shapely,
poly_intersection, poly_iou, poly_make_valid,
poly_union, polys2shapely, rescale_polygon,
rescale_polygons)
from mmocr.utils import (boundary_iou, crop_polygon, offset_polygon, poly2bbox,
poly2shapely, poly_intersection, poly_iou,
poly_make_valid, poly_union, polys2shapely,
rescale_polygon, rescale_polygons)
class TestCropPolygon(unittest.TestCase):
@ -281,3 +281,19 @@ class TestPolygonUtils(unittest.TestCase):
dtype=np.float32)
shrunk = offset_polygon(polygons, -1)
self.assertEqual(len(shrunk), 0)
def test_boundary_iou(self):
points = [0, 0, 0, 1, 1, 1, 1, 0]
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
self.assertEqual(boundary_iou(points, points1), 0)
# test overlapping boundaries
self.assertEqual(boundary_iou(points, points), 1)
# test invalid boundaries
self.assertEqual(boundary_iou(points2, points2), 0)
self.assertEqual(boundary_iou(points3, points3, zero_division=1), 1)
self.assertEqual(boundary_iou(points2, points3), 0)