mirror of https://github.com/open-mmlab/mmocr.git
New Hmean-iou metric
parent
f47f3eff03
commit
7e7a526f37
|
@ -21,3 +21,6 @@ mmocr/models/textdet/detectors/text_detector_mixin.py
|
|||
|
||||
# It will be covered by tests of any det model implemented in future
|
||||
mmocr/models/textdet/detectors/single_stage_text_detector.py
|
||||
|
||||
# It will be removed after all utils are moved to mmocr.utils
|
||||
mmocr/core/evaluation/utils.py
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .hmean import eval_hmean
|
||||
from .hmean_ic13 import eval_hmean_ic13
|
||||
from .hmean_iou import eval_hmean_iou
|
||||
from .kie_metric import compute_f1_score
|
||||
from .ner_metric import eval_ner_f1
|
||||
from .ocr_metric import eval_ocr_metric
|
||||
|
||||
__all__ = [
|
||||
'eval_hmean_ic13', 'eval_hmean_iou', 'eval_ocr_metric', 'eval_hmean',
|
||||
'compute_f1_score', 'eval_ner_f1'
|
||||
'eval_hmean_ic13', 'eval_ocr_metric', 'eval_hmean', 'compute_f1_score',
|
||||
'eval_ner_f1'
|
||||
]
|
||||
|
|
|
@ -5,55 +5,6 @@ from shapely.geometry import Polygon as plg
|
|||
import mmocr.utils as utils
|
||||
|
||||
|
||||
def ignore_pred(pred_boxes, gt_ignored_index, gt_polys, precision_thr):
|
||||
"""Ignore the predicted box if it hits any ignored ground truth.
|
||||
|
||||
Args:
|
||||
pred_boxes (list[ndarray or list]): The predicted boxes of one image.
|
||||
gt_ignored_index (list[int]): The ignored ground truth index list.
|
||||
gt_polys (list[Polygon]): The polygon list of one image.
|
||||
precision_thr (float): The precision threshold.
|
||||
|
||||
Returns:
|
||||
pred_polys (list[Polygon]): The predicted polygon list.
|
||||
pred_points (list[list]): The predicted box list represented
|
||||
by point sequences.
|
||||
pred_ignored_index (list[int]): The ignored text index list.
|
||||
"""
|
||||
|
||||
assert isinstance(pred_boxes, list)
|
||||
assert isinstance(gt_ignored_index, list)
|
||||
assert isinstance(gt_polys, list)
|
||||
assert 0 <= precision_thr <= 1
|
||||
|
||||
pred_polys = []
|
||||
pred_points = []
|
||||
pred_ignored_index = []
|
||||
|
||||
gt_ignored_num = len(gt_ignored_index)
|
||||
# get detection polygons
|
||||
for box_id, box in enumerate(pred_boxes):
|
||||
poly = points2polygon(box)
|
||||
pred_polys.append(poly)
|
||||
pred_points.append(box)
|
||||
|
||||
if gt_ignored_num < 1:
|
||||
continue
|
||||
|
||||
# ignore the current detection box
|
||||
# if its overlap with any ignored gt > precision_thr
|
||||
for ignored_box_id in gt_ignored_index:
|
||||
ignored_box = gt_polys[ignored_box_id]
|
||||
inter_area = poly_intersection(poly, ignored_box)
|
||||
area = poly.area
|
||||
precision = 0 if area == 0 else inter_area / area
|
||||
if precision > precision_thr:
|
||||
pred_ignored_index.append(box_id)
|
||||
break
|
||||
|
||||
return pred_polys, pred_points, pred_ignored_index
|
||||
|
||||
|
||||
def compute_hmean(accum_hit_recall, accum_hit_prec, gt_num, pred_num):
|
||||
"""Compute hmean given hit number, ground truth number and prediction
|
||||
number.
|
||||
|
@ -95,6 +46,7 @@ def compute_hmean(accum_hit_recall, accum_hit_prec, gt_num, pred_num):
|
|||
|
||||
|
||||
def box2polygon(box):
|
||||
# TODO This has been moved to mmocr.utils. Delete this later
|
||||
"""Convert box to polygon.
|
||||
|
||||
Args:
|
||||
|
@ -117,6 +69,7 @@ def box2polygon(box):
|
|||
|
||||
|
||||
def points2polygon(points):
|
||||
# TODO This has been moved to mmocr.utils. Delete this later
|
||||
"""Convert k points to 1 polygon.
|
||||
|
||||
Args:
|
||||
|
@ -137,6 +90,7 @@ def points2polygon(points):
|
|||
|
||||
|
||||
def poly_make_valid(poly):
|
||||
# TODO This has been moved to mmocr.utils. Delete this later
|
||||
"""Convert a potentially invalid polygon to a valid one by eliminating
|
||||
self-crossing or self-touching parts.
|
||||
|
||||
|
@ -150,6 +104,7 @@ def poly_make_valid(poly):
|
|||
|
||||
|
||||
def poly_intersection(poly_det, poly_gt, invalid_ret=None, return_poly=False):
|
||||
# TODO This has been moved to mmocr.utils. Delete this later
|
||||
"""Calculate the intersection area between two polygon.
|
||||
|
||||
Args:
|
||||
|
@ -185,6 +140,7 @@ def poly_intersection(poly_det, poly_gt, invalid_ret=None, return_poly=False):
|
|||
|
||||
|
||||
def poly_union(poly_det, poly_gt, invalid_ret=None, return_poly=False):
|
||||
# TODO This has been moved to mmocr.utils. Delete this later
|
||||
"""Calculate the union area between two polygon.
|
||||
Args:
|
||||
poly_det (Polygon): A polygon predicted by detector.
|
||||
|
@ -241,6 +197,7 @@ def boundary_iou(src, target, zero_division=0):
|
|||
|
||||
|
||||
def poly_iou(poly_det, poly_gt, zero_division=0):
|
||||
# TODO This has been moved to mmocr.utils. Delete this later
|
||||
"""Calculate the IOU between two polygons.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .hmean_iou_metric import HmeanIOUMetric
|
||||
|
||||
__all__ = ['HmeanIOUMetric']
|
|
@ -0,0 +1,242 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.logging import MMLogger
|
||||
from scipy.sparse import csr_matrix
|
||||
from scipy.sparse.csgraph import maximum_bipartite_matching
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
from mmocr.core.evaluation.utils import compute_hmean
|
||||
from mmocr.registry import METRICS
|
||||
from mmocr.utils import poly_intersection, poly_iou, polys2shapely
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class HmeanIOUMetric(BaseMetric):
|
||||
"""HmeanIOU metric.
|
||||
|
||||
This method computes the hmean iou metric, which is done in the
|
||||
following steps:
|
||||
|
||||
- Filter the prediction polygon:
|
||||
|
||||
- Scores is smaller than minimum prediction score threshold.
|
||||
- The proportion of the area that intersects with gt ignored polygon is
|
||||
greater than ignore_precision_thr.
|
||||
|
||||
- Computing an M x N IoU matrix, where each element indexing
|
||||
E_mn represents the IoU between the m-th valid GT and n-th valid
|
||||
prediction.
|
||||
- Based on different prediction score threshold:
|
||||
- Obtain the ignored predictions according to prediction score.
|
||||
The filtered predictions will not be involved in the later metric
|
||||
computations.
|
||||
- Based on the IoU matrix, get the match metric according to
|
||||
``match_iou_thr``.
|
||||
- Based on different `strategy`, accumulate the match number.
|
||||
- calculate H-mean under different prediction score threshold.
|
||||
|
||||
Args:
|
||||
match_iou_thr (float): IoU threshold for a match. Defaults to 0.5.
|
||||
ignore_precision_thr (float): Precision threshold when prediction and\
|
||||
gt ignored polygons are matched. Defaults to 0.5.
|
||||
pred_score_thrs (dict): Best prediction score threshold searching
|
||||
space. Defaults to dict(start=0.3, stop=0.9, step=0.1).
|
||||
strategy (str): Polygon matching strategy. Options are 'max_matching'
|
||||
and 'vanilla'. 'max_matching' refers to the optimum strategy that
|
||||
maximizes the number of matches. Vanilla strategy matches gt and
|
||||
pred polygons if both of them are never matched before. It was used
|
||||
in MMOCR 0.x and is not recommended to use now. Defaults to
|
||||
'max_matching'.
|
||||
collect_device (str): Device name used for collecting results from
|
||||
different ranks during distributed training. Must be 'cpu' or
|
||||
'gpu'. Defaults to 'cpu'.
|
||||
prefix (str, optional): The prefix that will be added in the metric
|
||||
names to disambiguate homonymous metrics of different evaluators.
|
||||
If prefix is not provided in the argument, self.default_prefix
|
||||
will be used instead. Defaults to None
|
||||
"""
|
||||
default_prefix: Optional[str] = 'icdar'
|
||||
|
||||
def __init__(self,
|
||||
match_iou_thr: float = 0.5,
|
||||
ignore_precision_thr: float = 0.5,
|
||||
pred_score_thrs: Dict = dict(start=0.3, stop=0.9, step=0.1),
|
||||
strategy: str = 'max_matching',
|
||||
collect_device: str = 'cpu',
|
||||
prefix: Optional[str] = None) -> None:
|
||||
super().__init__(collect_device=collect_device, prefix=prefix)
|
||||
self.match_iou_thr = match_iou_thr
|
||||
self.ignore_precision_thr = ignore_precision_thr
|
||||
self.pred_score_thrs = np.arange(**pred_score_thrs)
|
||||
assert strategy in ['max_matching', 'vanilla']
|
||||
self.strategy = strategy
|
||||
|
||||
def process(self, data_batch: Sequence[Dict],
|
||||
predictions: Sequence[Dict]) -> None:
|
||||
"""Process one batch of data samples and predictions. The processed
|
||||
results should be stored in ``self.results``, which will be used to
|
||||
compute the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Dict]): A batch of data from dataloader.
|
||||
predictions (Sequence[Dict]): A batch of outputs from
|
||||
the model.
|
||||
"""
|
||||
for pred, gt in zip(predictions, data_batch):
|
||||
|
||||
pred_instances = pred.get('pred_instances')
|
||||
pred_polygons = pred_instances.get('polygons')
|
||||
pred_scores = pred_instances.get('scores')
|
||||
if isinstance(pred_scores, torch.Tensor):
|
||||
pred_scores = pred_scores.cpu().numpy()
|
||||
pred_scores = np.array(pred_scores, dtype=np.float32)
|
||||
|
||||
gt_polys, gt_ignore_flags = self._polys_from_ann(
|
||||
gt['data_sample']['instances'])
|
||||
gt_polys = polys2shapely(gt_polys)
|
||||
pred_polys = polys2shapely(pred_polygons)
|
||||
|
||||
pred_ignore_flags = self._filter_preds(pred_polys, gt_polys,
|
||||
pred_scores,
|
||||
gt_ignore_flags)
|
||||
|
||||
gt_num = np.sum(~gt_ignore_flags)
|
||||
pred_num = np.sum(~pred_ignore_flags)
|
||||
iou_metric = np.zeros([gt_num, pred_num])
|
||||
|
||||
# Compute IoU scores amongst kept pred and gt polygons
|
||||
for pred_mat_id, pred_poly_id in enumerate(
|
||||
self._true_indexes(~pred_ignore_flags)):
|
||||
for gt_mat_id, gt_poly_id in enumerate(
|
||||
self._true_indexes(~gt_ignore_flags)):
|
||||
iou_metric[gt_mat_id, pred_mat_id] = poly_iou(
|
||||
gt_polys[gt_poly_id], pred_polys[pred_poly_id])
|
||||
|
||||
result = dict(
|
||||
iou_metric=iou_metric,
|
||||
pred_scores=pred_scores[~pred_ignore_flags])
|
||||
self.results.append(result)
|
||||
|
||||
def compute_metrics(self, results: List[Dict]) -> Dict:
|
||||
"""Compute the metrics from processed results.
|
||||
|
||||
Args:
|
||||
results (list[dict]): The processed results of each batch.
|
||||
|
||||
Returns:
|
||||
dict: The computed metrics. The keys are the names of the metrics,
|
||||
and the values are corresponding results.
|
||||
"""
|
||||
logger: MMLogger = MMLogger.get_current_instance()
|
||||
|
||||
best_eval_results = dict(hmean=-1)
|
||||
logger.info('Evaluating hmean-iou...')
|
||||
|
||||
dataset_pred_num = np.zeros_like(self.pred_score_thrs)
|
||||
dataset_hit_num = np.zeros_like(self.pred_score_thrs)
|
||||
dataset_gt_num = 0
|
||||
|
||||
for result in results:
|
||||
iou_metric = result['iou_metric'] # (gt_num, pred_num)
|
||||
pred_scores = result['pred_scores'] # (pred_num)
|
||||
dataset_gt_num += iou_metric.shape[0]
|
||||
|
||||
# Filter out predictions by IoU threshold
|
||||
for i, pred_score_thr in enumerate(self.pred_score_thrs):
|
||||
pred_ignore_flags = pred_scores < pred_score_thr
|
||||
# get the number of matched boxes
|
||||
matched_metric = iou_metric[:, ~pred_ignore_flags] \
|
||||
> self.match_iou_thr
|
||||
if self.strategy == 'max_matching':
|
||||
csr_matched_metric = csr_matrix(matched_metric)
|
||||
matched_preds = maximum_bipartite_matching(
|
||||
csr_matched_metric, perm_type='row')
|
||||
# -1 denotes unmatched pred polygons
|
||||
dataset_hit_num[i] += np.sum(matched_preds != -1)
|
||||
else:
|
||||
# first come first matched
|
||||
matched_gt_indexes = set()
|
||||
matched_pred_indexes = set()
|
||||
for gt_idx, pred_idx in zip(*np.nonzero(matched_metric)):
|
||||
if gt_idx in matched_gt_indexes or \
|
||||
pred_idx in matched_pred_indexes:
|
||||
continue
|
||||
matched_gt_indexes.add(gt_idx)
|
||||
matched_pred_indexes.add(pred_idx)
|
||||
dataset_hit_num[i] += len(matched_gt_indexes)
|
||||
dataset_pred_num[i] += np.sum(~pred_ignore_flags)
|
||||
|
||||
for i, pred_score_thr in enumerate(self.pred_score_thrs):
|
||||
precision, recall, hmean = compute_hmean(
|
||||
int(dataset_hit_num[i]), int(dataset_hit_num[i]),
|
||||
int(dataset_gt_num), int(dataset_pred_num[i]))
|
||||
eval_results = dict(
|
||||
precision=precision, recall=recall, hmean=hmean)
|
||||
logger.info(f'prediction score threshold: {pred_score_thr}, '
|
||||
f'recall: {eval_results["recall"]:.3f}, '
|
||||
f'precision: {eval_results["precision"]:.3f}, '
|
||||
f'hmean: {eval_results["hmean"]:.3f}\n')
|
||||
if eval_results['hmean'] > best_eval_results['hmean']:
|
||||
best_eval_results = eval_results
|
||||
return best_eval_results
|
||||
|
||||
def _filter_preds(self, pred_polys: List[Polygon], gt_polys: List[Polygon],
|
||||
pred_scores: List[float],
|
||||
gt_ignore_flags: np.ndarray) -> np.ndarray:
|
||||
"""Filter out the predictions by score threshold and whether it
|
||||
overlaps ignored gt polygons.
|
||||
|
||||
Args:
|
||||
pred_polys (list[Polygon]): Pred polygons.
|
||||
gt_polys (list[Polygon]): GT polygons.
|
||||
pred_scores (list[float]): Pred scores of polygons.
|
||||
gt_ignore_flags (np.ndarray): 1D boolean array indicating
|
||||
the positions of ignored gt polygons.
|
||||
|
||||
Returns:
|
||||
np.ndarray: 1D boolean array indicating the positions of ignored
|
||||
pred polygons.
|
||||
"""
|
||||
|
||||
# Filter out predictions based on the minimum score threshold
|
||||
pred_ignore_flags = pred_scores < self.pred_score_thrs.min()
|
||||
|
||||
# Filter out pred polygons which overlaps any ignored gt polygons
|
||||
for pred_id in self._true_indexes(~pred_ignore_flags):
|
||||
for gt_id in self._true_indexes(gt_ignore_flags):
|
||||
# Match pred with ignored gt
|
||||
precision = poly_intersection(
|
||||
gt_polys[gt_id], pred_polys[pred_id]) / (
|
||||
pred_polys[pred_id].area + 1e-5)
|
||||
if precision > self.ignore_precision_thr:
|
||||
pred_ignore_flags[pred_id] = True
|
||||
break
|
||||
|
||||
return pred_ignore_flags
|
||||
|
||||
def _true_indexes(self, array: np.ndarray) -> np.ndarray:
|
||||
"""Get indexes of True elements from a 1D boolean array."""
|
||||
return np.where(array)[0]
|
||||
|
||||
def _polys_from_ann(self, ann: Dict) -> Tuple[List, List]:
|
||||
"""Get GT polygons from annotations.
|
||||
|
||||
Args:
|
||||
ann (dict): The ground-truth annotation.
|
||||
|
||||
Returns:
|
||||
tuple[list[np.array], np.array]: Returns a tuple
|
||||
``(polys, gt_ignore_flags)``, where ``polys`` is the ground-truth
|
||||
polygon instances and ``gt_ignore_flags`` represents whether the
|
||||
corresponding instance should be ignored.
|
||||
"""
|
||||
polys = []
|
||||
gt_ignore_flags = []
|
||||
for instance in ann:
|
||||
gt_ignore_flags.append(instance['ignore'])
|
||||
polys.append(np.array(instance['polygon'], dtype=np.float32))
|
||||
return polys, np.array(gt_ignore_flags, dtype=bool)
|
|
@ -15,6 +15,8 @@ from .lmdb_util import recog2lmdb
|
|||
from .logger import get_root_logger
|
||||
from .model import revert_sync_batchnorm
|
||||
from .polygon_utils import (crop_polygon, is_poly_outside_rect, poly2bbox,
|
||||
poly2shapely, poly_intersection, poly_iou,
|
||||
poly_make_valid, poly_union, polys2shapely,
|
||||
rescale_polygon, rescale_polygons)
|
||||
from .setup_env import setup_multi_processes
|
||||
from .string_util import StringStrip
|
||||
|
@ -22,12 +24,12 @@ from .string_util import StringStrip
|
|||
__all__ = [
|
||||
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
|
||||
'is_3dlist', 'is_type_list', 'is_none_or_type', 'equal_len', 'is_2dlist',
|
||||
'valid_boundary', 'lmdb_converter', 'drop_orientation',
|
||||
'convert_annotations', 'is_not_png', 'list_to_file', 'list_from_file',
|
||||
'is_on_same_line', 'stitch_boxes_into_lines', 'StringStrip',
|
||||
'revert_sync_batchnorm', 'bezier_to_polygon', 'sort_points',
|
||||
'setup_multi_processes', 'recog2lmdb', 'dump_ocr_data',
|
||||
'recog_anno_to_imginfo', 'rescale_polygons', 'rescale_polygon',
|
||||
'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_outside_rect',
|
||||
'poly2bbox'
|
||||
'valid_boundary', 'drop_orientation', 'convert_annotations', 'is_not_png',
|
||||
'list_to_file', 'list_from_file', 'is_on_same_line',
|
||||
'stitch_boxes_into_lines', 'StringStrip', 'revert_sync_batchnorm',
|
||||
'bezier_to_polygon', 'sort_points', 'setup_multi_processes', 'recog2lmdb',
|
||||
'dump_ocr_data', 'recog_anno_to_imginfo', 'rescale_polygons',
|
||||
'rescale_polygon', 'rescale_bboxes', 'bbox2poly', 'crop_polygon',
|
||||
'is_poly_outside_rect', 'poly2bbox', 'poly_intersection', 'poly_iou',
|
||||
'poly_make_valid', 'poly_union', 'poly2shapely', 'polys2shapely'
|
||||
]
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence, Tuple
|
||||
from typing import Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import ArrayLike
|
||||
from shapely.geometry import Polygon
|
||||
from shapely.geometry import MultiPolygon, Polygon
|
||||
|
||||
from mmocr.utils import bbox2poly
|
||||
|
||||
|
@ -91,6 +91,34 @@ def poly2bbox(polygon: ArrayLike) -> np.array:
|
|||
return np.array([min(x), min(y), max(x), max(y)])
|
||||
|
||||
|
||||
def poly2shapely(polygon: ArrayLike) -> Polygon:
|
||||
"""Convert a polygon to shapely.geometry.Polygon.
|
||||
|
||||
Args:
|
||||
polygon (ArrayLike): A set of points of 2k shape.
|
||||
|
||||
Returns:
|
||||
polygon (Polygon): A polygon object.
|
||||
"""
|
||||
assert len(polygon) % 2 == 0 and len(polygon) >= 8
|
||||
polygon = np.array(polygon, dtype=np.float32)
|
||||
|
||||
polygon = polygon.reshape([-1, 2])
|
||||
return Polygon(polygon)
|
||||
|
||||
|
||||
def polys2shapely(polygons: Sequence[ArrayLike]) -> Sequence[Polygon]:
|
||||
"""Convert a nested list of boundaries to a list of Polygons.
|
||||
|
||||
Args:
|
||||
polygons (list): The point coordinates of the instance boundary.
|
||||
|
||||
Returns:
|
||||
list: Converted shapely.Polygon.
|
||||
"""
|
||||
return [poly2shapely(polygon) for polygon in polygons]
|
||||
|
||||
|
||||
def crop_polygon(polygon: ArrayLike, crop_box: np.ndarray) -> np.ndarray:
|
||||
"""Crop polygon to be within a box region.
|
||||
|
||||
|
@ -114,6 +142,124 @@ def crop_polygon(polygon: ArrayLike, crop_box: np.ndarray) -> np.ndarray:
|
|||
return poly_cropped.reshape(-1)
|
||||
|
||||
|
||||
def poly_make_valid(poly: Polygon) -> Polygon:
|
||||
"""Convert a potentially invalid polygon to a valid one by eliminating
|
||||
self-crossing or self-touching parts.
|
||||
|
||||
Args:
|
||||
poly (Polygon): A polygon needed to be converted.
|
||||
|
||||
Returns:
|
||||
Polygon: A valid polygon.
|
||||
"""
|
||||
assert isinstance(poly, Polygon)
|
||||
return poly if poly.is_valid else poly.buffer(0)
|
||||
|
||||
|
||||
def poly_intersection(poly_a: Polygon,
|
||||
poly_b: Polygon,
|
||||
invalid_ret: Optional[Union[float, int]] = None,
|
||||
return_poly: bool = False
|
||||
) -> Tuple[float, Optional[Polygon]]:
|
||||
"""Calculate the intersection area between two polygons.
|
||||
|
||||
Args:
|
||||
poly_a (Polygon): Polygon a.
|
||||
poly_b (Polygon): Polygon b.
|
||||
invalid_ret (float or int, optional): The return value when the
|
||||
invalid polygon exists. If it is not specified, the function
|
||||
allows the computation to proceed with invalid polygons by
|
||||
cleaning the their self-touching or self-crossing parts.
|
||||
Defaults to None.
|
||||
return_poly (bool): Whether to return the polygon of the intersection
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
float or tuple(float, Polygon): Returns the intersection area or
|
||||
a tuple ``(area, Optional[poly_obj])``, where the `area` is the
|
||||
intersection area between two polygons and `poly_obj` is The Polygon
|
||||
object of the intersection area. Set as `None` if the input is invalid.
|
||||
Set as `None` if the input is invalid. `poly_obj` will be returned
|
||||
only if `return_poly` is `True`.
|
||||
"""
|
||||
assert isinstance(poly_a, Polygon)
|
||||
assert isinstance(poly_b, Polygon)
|
||||
assert invalid_ret is None or isinstance(invalid_ret, (float, int))
|
||||
|
||||
if invalid_ret is None:
|
||||
poly_a = poly_make_valid(poly_a)
|
||||
poly_b = poly_make_valid(poly_b)
|
||||
|
||||
poly_obj = None
|
||||
area = invalid_ret
|
||||
if poly_a.is_valid and poly_b.is_valid:
|
||||
poly_obj = poly_a.intersection(poly_b)
|
||||
area = poly_obj.area
|
||||
return (area, poly_obj) if return_poly else area
|
||||
|
||||
|
||||
def poly_union(
|
||||
poly_a: Polygon,
|
||||
poly_b: Polygon,
|
||||
invalid_ret: Optional[Union[float, int]] = None,
|
||||
return_poly: bool = False
|
||||
) -> Tuple[float, Optional[Union[Polygon, MultiPolygon]]]:
|
||||
"""Calculate the union area between two polygons.
|
||||
Args:
|
||||
poly_a (Polygon): Polygon a.
|
||||
poly_b (Polygon): Polygon b.
|
||||
invalid_ret (float or int, optional): The return value when the
|
||||
invalid polygon exists. If it is not specified, the function
|
||||
allows the computation to proceed with invalid polygons by
|
||||
cleaning the their self-touching or self-crossing parts.
|
||||
Defaults to False.
|
||||
return_poly (bool): Whether to return the polygon of the union.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
tuple: Returns a tuple ``(area, Optional[poly_obj])``, where
|
||||
the `area` is the union between two polygons and `poly_obj` is the
|
||||
Polygon or MultiPolygon object of the union of the inputs. The type
|
||||
of object depends on whether they intersect or not. Set as `None`
|
||||
if the input is invalid. `poly_obj` will be returned only if
|
||||
`return_poly` is `True`.
|
||||
"""
|
||||
assert isinstance(poly_a, Polygon)
|
||||
assert isinstance(poly_b, Polygon)
|
||||
assert invalid_ret is None or isinstance(invalid_ret, (float, int))
|
||||
|
||||
if invalid_ret is None:
|
||||
poly_a = poly_make_valid(poly_a)
|
||||
poly_b = poly_make_valid(poly_b)
|
||||
|
||||
poly_obj = None
|
||||
area = invalid_ret
|
||||
if poly_a.is_valid and poly_b.is_valid:
|
||||
poly_obj = poly_a.union(poly_b)
|
||||
area = poly_obj.area
|
||||
return (area, poly_obj) if return_poly else area
|
||||
|
||||
|
||||
def poly_iou(poly_a: Polygon,
|
||||
poly_b: Polygon,
|
||||
zero_division: float = 0.) -> float:
|
||||
"""Calculate the IOU between two polygons.
|
||||
|
||||
Args:
|
||||
poly_a (Polygon): Polygon a.
|
||||
poly_b (Polygon): Polygon b.
|
||||
zero_division (float): The return value when invalid polygon exists.
|
||||
|
||||
Returns:
|
||||
float: The IoU between two polygons.
|
||||
"""
|
||||
assert isinstance(poly_a, Polygon)
|
||||
assert isinstance(poly_b, Polygon)
|
||||
area_inters = poly_intersection(poly_a, poly_b)
|
||||
area_union = poly_union(poly_a, poly_b)
|
||||
return area_inters / area_union if area_union != 0 else zero_division
|
||||
|
||||
|
||||
def is_poly_outside_rect(poly: ArrayLike, rect: np.ndarray) -> bool:
|
||||
"""Check if the polygon is outside the target region.
|
||||
Args:
|
||||
|
|
|
@ -2,53 +2,10 @@
|
|||
"""Tests the utils of evaluation."""
|
||||
import numpy as np
|
||||
import pytest
|
||||
from shapely.geometry import MultiPolygon, Polygon
|
||||
|
||||
import mmocr.core.evaluation.utils as utils
|
||||
|
||||
|
||||
def test_ignore_pred():
|
||||
|
||||
# test invalid arguments
|
||||
box = [0, 0, 1, 0, 1, 1, 0, 1]
|
||||
det_boxes = [box]
|
||||
gt_dont_care_index = [0]
|
||||
gt_polys = [utils.points2polygon(box)]
|
||||
precision_thr = 0.5
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
det_boxes_tmp = 1
|
||||
utils.ignore_pred(det_boxes_tmp, gt_dont_care_index, gt_polys,
|
||||
precision_thr)
|
||||
with pytest.raises(AssertionError):
|
||||
gt_dont_care_index_tmp = 1
|
||||
utils.ignore_pred(det_boxes, gt_dont_care_index_tmp, gt_polys,
|
||||
precision_thr)
|
||||
with pytest.raises(AssertionError):
|
||||
gt_polys_tmp = 1
|
||||
utils.ignore_pred(det_boxes, gt_dont_care_index, gt_polys_tmp,
|
||||
precision_thr)
|
||||
with pytest.raises(AssertionError):
|
||||
precision_thr_tmp = 1.1
|
||||
utils.ignore_pred(det_boxes, gt_dont_care_index, gt_polys,
|
||||
precision_thr_tmp)
|
||||
|
||||
# test ignored cases
|
||||
result = utils.ignore_pred(det_boxes, gt_dont_care_index, gt_polys,
|
||||
precision_thr)
|
||||
assert result[2] == [0]
|
||||
# test unignored cases
|
||||
gt_dont_care_index_tmp = []
|
||||
result = utils.ignore_pred(det_boxes, gt_dont_care_index_tmp, gt_polys,
|
||||
precision_thr)
|
||||
assert result[2] == []
|
||||
|
||||
det_boxes_tmp = [[10, 10, 15, 10, 15, 15, 10, 15]]
|
||||
result = utils.ignore_pred(det_boxes_tmp, gt_dont_care_index, gt_polys,
|
||||
precision_thr)
|
||||
assert result[2] == []
|
||||
|
||||
|
||||
def test_compute_hmean():
|
||||
|
||||
# test invalid arguments
|
||||
|
@ -68,154 +25,6 @@ def test_compute_hmean():
|
|||
assert hmean == 0
|
||||
|
||||
|
||||
def test_points2polygon():
|
||||
|
||||
# test unsupported type
|
||||
with pytest.raises(AssertionError):
|
||||
points = 2
|
||||
utils.points2polygon(points)
|
||||
|
||||
# test unsupported size
|
||||
with pytest.raises(AssertionError):
|
||||
points = [1, 2, 3, 4, 5, 6, 7]
|
||||
utils.points2polygon(points)
|
||||
with pytest.raises(AssertionError):
|
||||
points = [1, 2, 3, 4, 5, 6]
|
||||
utils.points2polygon(points)
|
||||
|
||||
# test np.array
|
||||
points = np.array([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
poly = utils.points2polygon(points)
|
||||
i = 0
|
||||
for coord in poly.exterior.coords[:-1]:
|
||||
assert coord[0] == points[i]
|
||||
assert coord[1] == points[i + 1]
|
||||
i += 2
|
||||
|
||||
points = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
poly = utils.points2polygon(points)
|
||||
i = 0
|
||||
for coord in poly.exterior.coords[:-1]:
|
||||
assert coord[0] == points[i]
|
||||
assert coord[1] == points[i + 1]
|
||||
i += 2
|
||||
|
||||
|
||||
def test_poly_intersection():
|
||||
|
||||
# test unsupported type
|
||||
with pytest.raises(AssertionError):
|
||||
utils.poly_intersection(0, 1)
|
||||
|
||||
# test non-overlapping polygons
|
||||
|
||||
points = [0, 0, 0, 1, 1, 1, 1, 0]
|
||||
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
|
||||
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
|
||||
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
|
||||
points4 = [0.5, 0, 1.5, 0, 1.5, 1, 0.5, 1]
|
||||
poly = utils.points2polygon(points)
|
||||
poly1 = utils.points2polygon(points1)
|
||||
poly2 = utils.points2polygon(points2)
|
||||
poly3 = utils.points2polygon(points3)
|
||||
poly4 = utils.points2polygon(points4)
|
||||
|
||||
area_inters = utils.poly_intersection(poly, poly1)
|
||||
|
||||
assert area_inters == 0
|
||||
|
||||
# test overlapping polygons
|
||||
area_inters = utils.poly_intersection(poly, poly)
|
||||
assert area_inters == 1
|
||||
area_inters = utils.poly_intersection(poly, poly4)
|
||||
assert area_inters == 0.5
|
||||
|
||||
# test invalid polygons
|
||||
assert utils.poly_intersection(poly2, poly2) == 0
|
||||
assert utils.poly_intersection(poly3, poly3, invalid_ret=1) == 1
|
||||
# The return value depends on the implementation of the package
|
||||
assert utils.poly_intersection(poly3, poly3, invalid_ret=None) == 0.25
|
||||
|
||||
# test poly return
|
||||
_, poly = utils.poly_intersection(poly, poly4, return_poly=True)
|
||||
assert isinstance(poly, Polygon)
|
||||
_, poly = utils.poly_intersection(
|
||||
poly3, poly3, invalid_ret=None, return_poly=True)
|
||||
assert isinstance(poly, Polygon)
|
||||
_, poly = utils.poly_intersection(
|
||||
poly2, poly3, invalid_ret=1, return_poly=True)
|
||||
assert poly is None
|
||||
|
||||
|
||||
def test_poly_union():
|
||||
|
||||
# test unsupported type
|
||||
with pytest.raises(AssertionError):
|
||||
utils.poly_union(0, 1)
|
||||
|
||||
# test non-overlapping polygons
|
||||
|
||||
points = [0, 0, 0, 1, 1, 1, 1, 0]
|
||||
points1 = [2, 2, 2, 3, 3, 3, 3, 2]
|
||||
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
|
||||
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
|
||||
points4 = [0.5, 0.5, 1, 0, 1, 1, 0.5, 0.5]
|
||||
poly = utils.points2polygon(points)
|
||||
poly1 = utils.points2polygon(points1)
|
||||
poly2 = utils.points2polygon(points2)
|
||||
poly3 = utils.points2polygon(points3)
|
||||
poly4 = utils.points2polygon(points4)
|
||||
|
||||
assert utils.poly_union(poly, poly1) == 2
|
||||
|
||||
# test overlapping polygons
|
||||
assert utils.poly_union(poly, poly) == 1
|
||||
|
||||
# test invalid polygons
|
||||
assert utils.poly_union(poly2, poly2) == 0
|
||||
assert utils.poly_union(poly3, poly3, invalid_ret=1) == 1
|
||||
|
||||
# The return value depends on the implementation of the package
|
||||
assert utils.poly_union(poly3, poly3, invalid_ret=None) == 0.25
|
||||
assert utils.poly_union(poly2, poly3) == 0.25
|
||||
assert utils.poly_union(poly3, poly4) == 0.5
|
||||
|
||||
# test poly return
|
||||
_, poly = utils.poly_union(poly, poly1, return_poly=True)
|
||||
assert isinstance(poly, MultiPolygon)
|
||||
_, poly = utils.poly_union(poly3, poly3, return_poly=True)
|
||||
assert isinstance(poly, Polygon)
|
||||
_, poly = utils.poly_union(poly2, poly3, invalid_ret=0, return_poly=True)
|
||||
assert poly is None
|
||||
|
||||
|
||||
def test_poly_iou():
|
||||
|
||||
# test unsupported type
|
||||
with pytest.raises(AssertionError):
|
||||
utils.poly_iou([1], [2])
|
||||
|
||||
points = [0, 0, 0, 1, 1, 1, 1, 0]
|
||||
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
|
||||
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
|
||||
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
|
||||
|
||||
poly = utils.points2polygon(points)
|
||||
poly1 = utils.points2polygon(points1)
|
||||
poly2 = utils.points2polygon(points2)
|
||||
poly3 = utils.points2polygon(points3)
|
||||
|
||||
assert utils.poly_iou(poly, poly1) == 0
|
||||
|
||||
# test overlapping polygons
|
||||
assert utils.poly_iou(poly, poly) == 1
|
||||
|
||||
# test invalid polygons
|
||||
assert utils.poly_iou(poly2, poly2) == 0
|
||||
assert utils.poly_iou(poly3, poly3, zero_division=1) == 1
|
||||
assert utils.poly_iou(poly2, poly3) == 0
|
||||
|
||||
|
||||
def test_boundary_iou():
|
||||
points = [0, 0, 0, 1, 1, 1, 1, 0]
|
||||
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import InstanceData
|
||||
|
||||
from mmocr.core import TextDetDataSample
|
||||
from mmocr.metrics import HmeanIOUMetric
|
||||
|
||||
|
||||
class TestHmeanIOU(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Create dummy test data.
|
||||
|
||||
We denote the polygons as the following.
|
||||
gt_polys: gt_a, gt_b, gt_c, gt_d_ignored
|
||||
pred_polys: pred_a, pred_b, pred_c, pred_d
|
||||
|
||||
There are two pairs of matches: (gt_a, pred_a) and (gt_b, pred_b),
|
||||
because the IoU > threshold.
|
||||
|
||||
gt_c and pred_c do not match any of the polygons.
|
||||
|
||||
pred_d is ignored in the recall computation since it overlaps
|
||||
gt_d_ignored and the precision > ignore_precision_thr.
|
||||
"""
|
||||
# prepare gt
|
||||
self.gt = [{
|
||||
'data_sample': {
|
||||
'instances': [{
|
||||
'polygon': [0, 0, 1, 0, 1, 1, 0, 1],
|
||||
'ignore': False
|
||||
}, {
|
||||
'polygon': [2, 0, 3, 0, 3, 1, 2, 1],
|
||||
'ignore': False
|
||||
}, {
|
||||
'polygon': [10, 0, 11, 0, 11, 1, 10, 1],
|
||||
'ignore': False
|
||||
}, {
|
||||
'polygon': [1, 0, 2, 0, 2, 1, 1, 1],
|
||||
'ignore': True
|
||||
}]
|
||||
}
|
||||
}, {
|
||||
'data_sample': {
|
||||
'instances': [{
|
||||
'polygon': [0, 0, 1, 0, 1, 1, 0, 1],
|
||||
'ignore': False
|
||||
}],
|
||||
}
|
||||
}]
|
||||
|
||||
# prepare pred
|
||||
pred_data_sample = TextDetDataSample()
|
||||
pred_data_sample.pred_instances = InstanceData()
|
||||
pred_data_sample.pred_instances.polygons = [
|
||||
torch.FloatTensor([0, 0, 1, 0, 1, 1, 0, 1]),
|
||||
torch.FloatTensor([2, 0.1, 3, 0.1, 3, 1.1, 2, 1.1]),
|
||||
torch.FloatTensor([1, 1, 2, 1, 2, 2, 1, 2]),
|
||||
torch.FloatTensor([1, -0.5, 2, -0.5, 2, 0.5, 1, 0.5]),
|
||||
]
|
||||
pred_data_sample.pred_instances.scores = torch.FloatTensor(
|
||||
[1, 1, 1, 0.001])
|
||||
predictions = [pred_data_sample.to_dict()]
|
||||
|
||||
pred_data_sample = TextDetDataSample()
|
||||
pred_data_sample.pred_instances = InstanceData()
|
||||
pred_data_sample.pred_instances.polygons = [
|
||||
torch.FloatTensor([0, 0, 1, 0, 1, 1, 0, 1])
|
||||
]
|
||||
pred_data_sample.pred_instances.scores = torch.FloatTensor([0.8])
|
||||
predictions.append(pred_data_sample.to_dict())
|
||||
|
||||
self.predictions = predictions
|
||||
|
||||
def test_hmean_iou(self):
|
||||
|
||||
metric = HmeanIOUMetric(prefix='mmocr')
|
||||
metric.process(self.gt, self.predictions)
|
||||
eval_results = metric.evaluate(size=2)
|
||||
|
||||
precision = 3 / 4
|
||||
recall = 3 / 4
|
||||
hmean = 2 * precision * recall / (precision + recall)
|
||||
target_result = {
|
||||
'mmocr/precision': precision,
|
||||
'mmocr/recall': recall,
|
||||
'mmocr/hmean': hmean
|
||||
}
|
||||
self.assertDictEqual(target_result, eval_results)
|
||||
|
||||
def test_compute_metrics(self):
|
||||
# Test different strategies
|
||||
fake_results = [
|
||||
dict(
|
||||
iou_metric=np.array([[1, 1], [1, 0]]),
|
||||
pred_scores=np.array([1., 1.]))
|
||||
]
|
||||
|
||||
# Vanilla
|
||||
metric = HmeanIOUMetric(strategy='vanilla')
|
||||
eval_results = metric.compute_metrics(fake_results)
|
||||
target_result = {'precision': 0.5, 'recall': 0.5, 'hmean': 0.5}
|
||||
self.assertDictEqual(target_result, eval_results)
|
||||
|
||||
# Max matching
|
||||
metric = HmeanIOUMetric(strategy='max_matching')
|
||||
eval_results = metric.compute_metrics(fake_results)
|
||||
target_result = {'precision': 1, 'recall': 1, 'hmean': 1}
|
||||
self.assertDictEqual(target_result, eval_results)
|
|
@ -3,8 +3,11 @@ import unittest
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from shapely.geometry import MultiPolygon, Polygon
|
||||
|
||||
from mmocr.utils import (crop_polygon, poly2bbox, rescale_polygon,
|
||||
from mmocr.utils import (crop_polygon, poly2bbox, poly2shapely,
|
||||
poly_intersection, poly_iou, poly_make_valid,
|
||||
poly_union, polys2shapely, rescale_polygon,
|
||||
rescale_polygons)
|
||||
|
||||
|
||||
|
@ -93,3 +96,163 @@ class TestPolygonUtils(unittest.TestCase):
|
|||
# test tensor
|
||||
polygon = torch.Tensor([0, 0, 1, 0, 1, 1, 0, 1])
|
||||
self.assertTrue(np.all(poly2bbox(polygon) == np.array([0, 0, 1, 1])))
|
||||
|
||||
def test_poly2shapely(self):
|
||||
polygon = Polygon([[0, 0], [1, 0], [1, 1], [0, 1]])
|
||||
# test np.array
|
||||
poly = np.array([0, 0, 1, 0, 1, 1, 0, 1])
|
||||
self.assertEqual(poly2shapely(poly), polygon)
|
||||
# test list
|
||||
poly = [0, 0, 1, 0, 1, 1, 0, 1]
|
||||
self.assertEqual(poly2shapely(poly), polygon)
|
||||
# test tensor
|
||||
poly = torch.Tensor([0, 0, 1, 0, 1, 1, 0, 1])
|
||||
self.assertEqual(poly2shapely(poly), polygon)
|
||||
# test invalid
|
||||
poly = [0, 0, 1]
|
||||
with self.assertRaises(AssertionError):
|
||||
poly2shapely(poly)
|
||||
poly = [0, 0, 1, 0, 1, 1, 0, 1, 1]
|
||||
with self.assertRaises(AssertionError):
|
||||
poly2shapely(poly)
|
||||
|
||||
def test_polys2shapely(self):
|
||||
polygons = [
|
||||
Polygon([[0, 0], [1, 0], [1, 1], [0, 1]]),
|
||||
Polygon([[1, 0], [1, 1], [0, 1], [0, 0]])
|
||||
]
|
||||
# test np.array
|
||||
polys = np.array([[0, 0, 1, 0, 1, 1, 0, 1], [1, 0, 1, 1, 0, 1, 0, 0]])
|
||||
self.assertEqual(polys2shapely(polys), polygons)
|
||||
# test list
|
||||
polys = [[0, 0, 1, 0, 1, 1, 0, 1], [1, 0, 1, 1, 0, 1, 0, 0]]
|
||||
self.assertEqual(polys2shapely(polys), polygons)
|
||||
# test tensor
|
||||
polys = torch.Tensor([[0, 0, 1, 0, 1, 1, 0, 1],
|
||||
[1, 0, 1, 1, 0, 1, 0, 0]])
|
||||
self.assertEqual(polys2shapely(polys), polygons)
|
||||
# test invalid
|
||||
polys = [0, 0, 1]
|
||||
with self.assertRaises(TypeError):
|
||||
polys2shapely(polys)
|
||||
polys = [0, 0, 1, 0, 1, 1, 0, 1, 1]
|
||||
with self.assertRaises(TypeError):
|
||||
polys2shapely(polys)
|
||||
|
||||
def test_poly_make_valid(self):
|
||||
poly = Polygon([[0, 0], [1, 1], [1, 0], [0, 1]])
|
||||
self.assertFalse(poly.is_valid)
|
||||
poly = poly_make_valid(poly)
|
||||
self.assertTrue(poly.is_valid)
|
||||
# invalid input
|
||||
with self.assertRaises(AssertionError):
|
||||
poly_make_valid([0, 0, 1, 1, 1, 0, 0, 1])
|
||||
|
||||
def test_poly_intersection(self):
|
||||
|
||||
# test unsupported type
|
||||
with self.assertRaises(AssertionError):
|
||||
poly_intersection(0, 1)
|
||||
|
||||
# test non-overlapping polygons
|
||||
points = [0, 0, 0, 1, 1, 1, 1, 0]
|
||||
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
|
||||
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
|
||||
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
|
||||
points4 = [0.5, 0, 1.5, 0, 1.5, 1, 0.5, 1]
|
||||
poly = poly2shapely(points)
|
||||
poly1 = poly2shapely(points1)
|
||||
poly2 = poly2shapely(points2)
|
||||
poly3 = poly2shapely(points3)
|
||||
poly4 = poly2shapely(points4)
|
||||
|
||||
area_inters = poly_intersection(poly, poly1)
|
||||
self.assertEqual(area_inters, 0.)
|
||||
|
||||
# test overlapping polygons
|
||||
area_inters = poly_intersection(poly, poly)
|
||||
self.assertEqual(area_inters, 1)
|
||||
area_inters = poly_intersection(poly, poly4)
|
||||
self.assertEqual(area_inters, 0.5)
|
||||
|
||||
# test invalid polygons
|
||||
self.assertEqual(poly_intersection(poly2, poly2), 0)
|
||||
self.assertEqual(poly_intersection(poly3, poly3, invalid_ret=1), 1)
|
||||
self.assertEqual(
|
||||
poly_intersection(poly3, poly3, invalid_ret=None), 0.25)
|
||||
|
||||
# test poly return
|
||||
_, poly = poly_intersection(poly, poly4, return_poly=True)
|
||||
self.assertTrue(isinstance(poly, Polygon))
|
||||
_, poly = poly_intersection(
|
||||
poly3, poly3, invalid_ret=None, return_poly=True)
|
||||
self.assertTrue(isinstance(poly, Polygon))
|
||||
_, poly = poly_intersection(
|
||||
poly2, poly3, invalid_ret=1, return_poly=True)
|
||||
self.assertTrue(poly is None)
|
||||
|
||||
def test_poly_union(self):
|
||||
|
||||
# test unsupported type
|
||||
with self.assertRaises(AssertionError):
|
||||
poly_union(0, 1)
|
||||
|
||||
# test non-overlapping polygons
|
||||
|
||||
points = [0, 0, 0, 1, 1, 1, 1, 0]
|
||||
points1 = [2, 2, 2, 3, 3, 3, 3, 2]
|
||||
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
|
||||
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
|
||||
points4 = [0.5, 0.5, 1, 0, 1, 1, 0.5, 0.5]
|
||||
poly = poly2shapely(points)
|
||||
poly1 = poly2shapely(points1)
|
||||
poly2 = poly2shapely(points2)
|
||||
poly3 = poly2shapely(points3)
|
||||
poly4 = poly2shapely(points4)
|
||||
|
||||
assert poly_union(poly, poly1) == 2
|
||||
|
||||
# test overlapping polygons
|
||||
assert poly_union(poly, poly) == 1
|
||||
|
||||
# test invalid polygons
|
||||
self.assertEqual(poly_union(poly2, poly2), 0)
|
||||
self.assertEqual(poly_union(poly3, poly3, invalid_ret=1), 1)
|
||||
|
||||
# The return value depends on the implementation of the package
|
||||
self.assertEqual(poly_union(poly3, poly3, invalid_ret=None), 0.25)
|
||||
self.assertEqual(poly_union(poly2, poly3), 0.25)
|
||||
self.assertEqual(poly_union(poly3, poly4), 0.5)
|
||||
|
||||
# test poly return
|
||||
_, poly = poly_union(poly, poly1, return_poly=True)
|
||||
self.assertTrue(isinstance(poly, MultiPolygon))
|
||||
_, poly = poly_union(poly3, poly3, return_poly=True)
|
||||
self.assertTrue(isinstance(poly, Polygon))
|
||||
_, poly = poly_union(poly2, poly3, invalid_ret=0, return_poly=True)
|
||||
self.assertTrue(poly is None)
|
||||
|
||||
def test_poly_iou(self):
|
||||
# test unsupported type
|
||||
with self.assertRaises(AssertionError):
|
||||
poly_iou([1], [2])
|
||||
|
||||
points = [0, 0, 0, 1, 1, 1, 1, 0]
|
||||
points1 = [10, 20, 30, 40, 50, 60, 70, 80]
|
||||
points2 = [0, 0, 0, 0, 0, 0, 0, 0] # Invalid polygon
|
||||
points3 = [0, 0, 0, 1, 1, 0, 1, 1] # Self-intersected polygon
|
||||
|
||||
poly = poly2shapely(points)
|
||||
poly1 = poly2shapely(points1)
|
||||
poly2 = poly2shapely(points2)
|
||||
poly3 = poly2shapely(points3)
|
||||
|
||||
self.assertEqual(poly_iou(poly, poly1), 0)
|
||||
|
||||
# test overlapping polygons
|
||||
self.assertEqual(poly_iou(poly, poly), 1)
|
||||
|
||||
# test invalid polygons
|
||||
self.assertEqual(poly_iou(poly2, poly2), 0)
|
||||
self.assertEqual(poly_iou(poly3, poly3, zero_division=1), 1)
|
||||
self.assertEqual(poly_iou(poly2, poly3), 0)
|
||||
|
|
Loading…
Reference in New Issue