mirror of https://github.com/open-mmlab/mmocr.git
316 lines
14 KiB
Python
316 lines
14 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import glob
|
|
import os.path as osp
|
|
import re
|
|
from typing import Dict, List, Optional, Sequence, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
from mmengine.evaluator import BaseMetric
|
|
from mmengine.logging import MMLogger
|
|
from rapidfuzz.distance import Levenshtein
|
|
from shapely.geometry import Point
|
|
|
|
from mmocr.registry import METRICS
|
|
|
|
# TODO: CTW1500 read pair
|
|
|
|
|
|
@METRICS.register_module()
|
|
class E2EPointMetric(BaseMetric):
|
|
"""Point metric for textspotting. Proposed in SPTS.
|
|
|
|
Args:
|
|
text_score_thrs (dict): Best text score threshold searching
|
|
space. Defaults to dict(start=0.8, stop=1, step=0.01).
|
|
word_spotting (bool): Whether to work in word spotting mode. Defaults
|
|
to False.
|
|
lexicon_path (str, optional): Lexicon path for word spotting, which
|
|
points to a lexicon file or a directory. Defaults to None.
|
|
lexicon_mapping (tuple, optional): The rule to map test image name to
|
|
its corresponding lexicon file. Only effective when lexicon path
|
|
is a directory. Defaults to ('(.*).jpg', r'\1.txt').
|
|
pair_path (str, optional): Pair path for word spotting, which points
|
|
to a pair file or a directory. Defaults to None.
|
|
pair_mapping (tuple, optional): The rule to map test image name to
|
|
its corresponding pair file. Only effective when pair path is a
|
|
directory. Defaults to ('(.*).jpg', r'\1.txt').
|
|
match_dist_thr (float, optional): Matching distance threshold for
|
|
word spotting. Defaults to None.
|
|
collect_device (str): Device name used for collecting results from
|
|
different ranks during distributed training. Must be 'cpu' or
|
|
'gpu'. Defaults to 'cpu'.
|
|
prefix (str, optional): The prefix that will be added in the metric
|
|
names to disambiguate homonymous metrics of different evaluators.
|
|
If prefix is not provided in the argument, self.default_prefix
|
|
will be used instead. Defaults to None
|
|
"""
|
|
default_prefix: Optional[str] = 'e2e_icdar'
|
|
|
|
def __init__(self,
|
|
text_score_thrs: Dict = dict(start=0.8, stop=1, step=0.01),
|
|
word_spotting: bool = False,
|
|
lexicon_path: Optional[str] = None,
|
|
lexicon_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'),
|
|
pair_path: Optional[str] = None,
|
|
pair_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'),
|
|
match_dist_thr: Optional[float] = None,
|
|
collect_device: str = 'cpu',
|
|
prefix: Optional[str] = None) -> None:
|
|
super().__init__(collect_device=collect_device, prefix=prefix)
|
|
self.text_score_thrs = np.arange(**text_score_thrs)
|
|
self.word_spotting = word_spotting
|
|
self.match_dist_thr = match_dist_thr
|
|
if lexicon_path:
|
|
self.lexicon_mapping = lexicon_mapping
|
|
self.pair_mapping = pair_mapping
|
|
self.lexicons = self._read_lexicon(lexicon_path)
|
|
self.pairs = self._read_pair(pair_path)
|
|
|
|
def _read_lexicon(self, lexicon_path: str) -> List[str]:
|
|
if lexicon_path.endswith('.txt'):
|
|
lexicon = open(lexicon_path, 'r').read().splitlines()
|
|
lexicon = [ele.strip() for ele in lexicon]
|
|
else:
|
|
lexicon = {}
|
|
for file in glob.glob(osp.join(lexicon_path, '*.txt')):
|
|
basename = osp.basename(file)
|
|
lexicon[basename] = self._read_lexicon(file)
|
|
return lexicon
|
|
|
|
def _read_pair(self, pair_path: str) -> Dict[str, str]:
|
|
pairs = {}
|
|
if pair_path.endswith('.txt'):
|
|
pair_lines = open(pair_path, 'r').read().splitlines()
|
|
for line in pair_lines:
|
|
line = line.strip()
|
|
word = line.split(' ')[0].upper()
|
|
word_gt = line[len(word) + 1:]
|
|
pairs[word] = word_gt
|
|
else:
|
|
for file in glob.glob(osp.join(pair_path, '*.txt')):
|
|
basename = osp.basename(file)
|
|
pairs[basename] = self._read_pair(file)
|
|
return pairs
|
|
|
|
def poly_center(self, poly_pts):
|
|
poly_pts = np.array(poly_pts).reshape(-1, 2)
|
|
return poly_pts.mean(0)
|
|
|
|
def process(self, data_batch: Sequence[Dict],
|
|
data_samples: Sequence[Dict]) -> None:
|
|
"""Process one batch of data samples and predictions. The processed
|
|
results should be stored in ``self.results``, which will be used to
|
|
compute the metrics when all batches have been processed.
|
|
|
|
Args:
|
|
data_batch (Sequence[Dict]): A batch of data from dataloader.
|
|
data_samples (Sequence[Dict]): A batch of outputs from
|
|
the model.
|
|
"""
|
|
for data_sample in data_samples:
|
|
|
|
pred_instances = data_sample.get('pred_instances')
|
|
pred_points = pred_instances.get('points')
|
|
text_scores = pred_instances.get('text_scores')
|
|
if isinstance(text_scores, torch.Tensor):
|
|
text_scores = text_scores.cpu().numpy()
|
|
text_scores = np.array(text_scores, dtype=np.float32)
|
|
pred_texts = pred_instances.get('texts')
|
|
|
|
gt_instances = data_sample.get('gt_instances')
|
|
gt_polys = gt_instances.get('polygons')
|
|
gt_ignore_flags = gt_instances.get('ignored')
|
|
gt_texts = gt_instances.get('texts')
|
|
if isinstance(gt_ignore_flags, torch.Tensor):
|
|
gt_ignore_flags = gt_ignore_flags.cpu().numpy()
|
|
|
|
gt_points = [self.poly_center(poly) for poly in gt_polys]
|
|
if self.word_spotting:
|
|
gt_ignore_flags, gt_texts = self._word_spotting_filter(
|
|
gt_ignore_flags, gt_texts)
|
|
|
|
pred_ignore_flags = text_scores < self.text_score_thrs.min()
|
|
text_scores = text_scores[~pred_ignore_flags]
|
|
pred_texts = self._get_true_elements(pred_texts,
|
|
~pred_ignore_flags)
|
|
pred_points = self._get_true_elements(pred_points,
|
|
~pred_ignore_flags)
|
|
|
|
result = dict(
|
|
# reserved for image-level lexcions
|
|
gt_img_name=osp.basename(data_sample.get('img_path', '')),
|
|
text_scores=text_scores,
|
|
pred_points=pred_points,
|
|
gt_points=gt_points,
|
|
pred_texts=pred_texts,
|
|
gt_texts=gt_texts,
|
|
gt_ignore_flags=gt_ignore_flags)
|
|
self.results.append(result)
|
|
|
|
def _get_true_elements(self, array: List, flags: np.ndarray) -> List:
|
|
return [array[i] for i in self._true_indexes(flags)]
|
|
|
|
def compute_metrics(self, results: List[Dict]) -> Dict:
|
|
"""Compute the metrics from processed results.
|
|
|
|
Args:
|
|
results (list[dict]): The processed results of each batch.
|
|
|
|
Returns:
|
|
dict: The computed metrics. The keys are the names of the metrics,
|
|
and the values are corresponding results.
|
|
"""
|
|
logger: MMLogger = MMLogger.get_current_instance()
|
|
|
|
best_eval_results = dict(hmean=-1)
|
|
|
|
num_thres = len(self.text_score_thrs)
|
|
num_preds = np.zeros(
|
|
num_thres, dtype=int) # the number of points actually predicted
|
|
num_tp = np.zeros(num_thres, dtype=int) # number of true positives
|
|
num_gts = np.zeros(num_thres, dtype=int) # number of valid gts
|
|
|
|
for result in results:
|
|
text_scores = result['text_scores']
|
|
pred_points = result['pred_points']
|
|
gt_points = result['gt_points']
|
|
gt_texts = result['gt_texts']
|
|
pred_texts = result['pred_texts']
|
|
gt_ignore_flags = result['gt_ignore_flags']
|
|
gt_img_name = result['gt_img_name']
|
|
|
|
# Correct the words with lexicon
|
|
pred_dist_flags = np.zeros(len(pred_texts), dtype=bool)
|
|
if hasattr(self, 'lexicons'):
|
|
for i, pred_text in enumerate(pred_texts):
|
|
# If it's an image-level lexicon
|
|
if isinstance(self.lexicons, dict):
|
|
lexicon_name = self._map_img_name(
|
|
gt_img_name, self.lexicon_mapping)
|
|
pair_name = self._map_img_name(gt_img_name,
|
|
self.pair_mapping)
|
|
pred_texts[i], match_dist = self._match_word(
|
|
pred_text, self.lexicons[lexicon_name],
|
|
self.pairs[pair_name])
|
|
else:
|
|
pred_texts[i], match_dist = self._match_word(
|
|
pred_text, self.lexicons, self.pairs)
|
|
if (self.match_dist_thr
|
|
and match_dist >= self.match_dist_thr):
|
|
# won't even count this as a prediction
|
|
pred_dist_flags[i] = True
|
|
|
|
# Filter out predictions by IoU threshold
|
|
for i, text_score_thr in enumerate(self.text_score_thrs):
|
|
pred_ignore_flags = pred_dist_flags | (
|
|
text_scores < text_score_thr)
|
|
filtered_pred_texts = self._get_true_elements(
|
|
pred_texts, ~pred_ignore_flags)
|
|
filtered_pred_points = self._get_true_elements(
|
|
pred_points, ~pred_ignore_flags)
|
|
gt_matched = np.zeros(len(gt_texts), dtype=bool)
|
|
num_gt = len(gt_texts) - np.sum(gt_ignore_flags)
|
|
if num_gt == 0:
|
|
continue
|
|
num_gts[i] += num_gt
|
|
|
|
for pred_text, pred_point in zip(filtered_pred_texts,
|
|
filtered_pred_points):
|
|
dists = [
|
|
Point(pred_point).distance(Point(gt_point))
|
|
for gt_point in gt_points
|
|
]
|
|
min_idx = np.argmin(dists)
|
|
if gt_texts[min_idx] == '###' or gt_ignore_flags[min_idx]:
|
|
continue
|
|
if not gt_matched[min_idx] and (
|
|
pred_text.upper() == gt_texts[min_idx].upper()):
|
|
gt_matched[min_idx] = True
|
|
num_tp[i] += 1
|
|
num_preds[i] += 1
|
|
|
|
for i, text_score_thr in enumerate(self.text_score_thrs):
|
|
if num_preds[i] == 0 or num_tp[i] == 0:
|
|
recall, precision, hmean = 0, 0, 0
|
|
else:
|
|
recall = num_tp[i] / num_gts[i]
|
|
precision = num_tp[i] / num_preds[i]
|
|
hmean = 2 * recall * precision / (recall + precision)
|
|
eval_results = dict(
|
|
precision=precision, recall=recall, hmean=hmean)
|
|
logger.info(f'text score threshold: {text_score_thr:.2f}, '
|
|
f'recall: {eval_results["recall"]:.4f}, '
|
|
f'precision: {eval_results["precision"]:.4f}, '
|
|
f'hmean: {eval_results["hmean"]:.4f}\n')
|
|
if eval_results['hmean'] > best_eval_results['hmean']:
|
|
best_eval_results = eval_results
|
|
return best_eval_results
|
|
|
|
def _map_img_name(self, img_name: str, mapping: Tuple[str, str]) -> str:
|
|
"""Map the image name to the another one based on mapping."""
|
|
return re.sub(mapping[0], mapping[1], img_name)
|
|
|
|
def _true_indexes(self, array: np.ndarray) -> np.ndarray:
|
|
"""Get indexes of True elements from a 1D boolean array."""
|
|
return np.where(array)[0]
|
|
|
|
def _word_spotting_filter(self, gt_ignore_flags: np.ndarray,
|
|
gt_texts: List[str]
|
|
) -> Tuple[np.ndarray, List[str]]:
|
|
"""Filter out gt instances that cannot be in a valid dictionary, and do
|
|
some simple preprocessing to texts."""
|
|
|
|
for i in range(len(gt_texts)):
|
|
if gt_ignore_flags[i]:
|
|
continue
|
|
text = gt_texts[i]
|
|
if text[-2:] in ["'s", "'S"]:
|
|
text = text[:-2]
|
|
text = text.strip('-')
|
|
for char in "'!?.:,*\"()·[]/":
|
|
text = text.replace(char, ' ')
|
|
text = text.strip()
|
|
gt_ignore_flags[i] = not self._include_in_dict(text)
|
|
if not gt_ignore_flags[i]:
|
|
gt_texts[i] = text
|
|
|
|
return gt_ignore_flags, gt_texts
|
|
|
|
def _include_in_dict(self, text: str) -> bool:
|
|
"""Check if the text could be in a valid dictionary."""
|
|
if len(text) != len(text.replace(' ', '')) or len(text) < 3:
|
|
return False
|
|
not_allowed = '×÷·'
|
|
valid_ranges = [(ord(u'a'), ord(u'z')), (ord(u'A'), ord(u'Z')),
|
|
(ord(u'À'), ord(u'ƿ')), (ord(u'DŽ'), ord(u'ɿ')),
|
|
(ord(u'Ά'), ord(u'Ͽ')), (ord(u'-'), ord(u'-'))]
|
|
for char in text:
|
|
code = ord(char)
|
|
if (not_allowed.find(char) != -1):
|
|
return False
|
|
valid = any(code >= r[0] and code <= r[1] for r in valid_ranges)
|
|
if not valid:
|
|
return False
|
|
return True
|
|
|
|
def _match_word(self,
|
|
text: str,
|
|
lexicons: List[str],
|
|
pairs: Optional[Dict[str, str]] = None) -> Tuple[str, int]:
|
|
"""Match the text with the lexicons and pairs."""
|
|
text = text.upper()
|
|
matched_word = ''
|
|
matched_dist = 100
|
|
for lexicon in lexicons:
|
|
lexicon = lexicon.upper()
|
|
norm_dist = Levenshtein.distance(text, lexicon)
|
|
norm_dist = Levenshtein.normalized_distance(text, lexicon)
|
|
if norm_dist < matched_dist:
|
|
matched_dist = norm_dist
|
|
if pairs:
|
|
matched_word = pairs[lexicon]
|
|
else:
|
|
matched_word = lexicon
|
|
return matched_word, matched_dist
|