diff --git a/mmocr/core/evaluation/__init__.py b/mmocr/core/evaluation/__init__.py
index 96f8b8b8..466d8981 100644
--- a/mmocr/core/evaluation/__init__.py
+++ b/mmocr/core/evaluation/__init__.py
@@ -3,9 +3,5 @@ from .hmean import eval_hmean
 from .hmean_ic13 import eval_hmean_ic13
 from .kie_metric import compute_f1_score
 from .ner_metric import eval_ner_f1
-from .ocr_metric import eval_ocr_metric
 
-__all__ = [
-    'eval_hmean_ic13', 'eval_ocr_metric', 'eval_hmean', 'compute_f1_score',
-    'eval_ner_f1'
-]
+__all__ = ['eval_hmean_ic13', 'eval_hmean', 'compute_f1_score', 'eval_ner_f1']
diff --git a/mmocr/core/evaluation/ocr_metric.py b/mmocr/core/evaluation/ocr_metric.py
deleted file mode 100644
index 7896b1b0..00000000
--- a/mmocr/core/evaluation/ocr_metric.py
+++ /dev/null
@@ -1,165 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import re
-from difflib import SequenceMatcher
-
-from rapidfuzz.distance import Levenshtein
-
-from mmocr.utils import is_type_list
-
-
-def cal_true_positive_char(pred, gt):
-    """Calculate correct character number in prediction.
-
-    Args:
-        pred (str): Prediction text.
-        gt (str): Ground truth text.
-
-    Returns:
-        true_positive_char_num (int): The true positive number.
-    """
-
-    all_opt = SequenceMatcher(None, pred, gt)
-    true_positive_char_num = 0
-    for opt, _, _, s2, e2 in all_opt.get_opcodes():
-        if opt == 'equal':
-            true_positive_char_num += (e2 - s2)
-        else:
-            pass
-    return true_positive_char_num
-
-
-def count_matches(pred_texts, gt_texts):
-    """Count the various match number for metric calculation.
-
-    Args:
-        pred_texts (list[str]): Predicted text string.
-        gt_texts (list[str]): Ground truth text string.
-
-    Returns:
-        match_res: (dict[str: int]): Match number used for
-            metric calculation.
-    """
-    match_res = {
-        'gt_char_num': 0,
-        'pred_char_num': 0,
-        'true_positive_char_num': 0,
-        'gt_word_num': 0,
-        'match_word_num': 0,
-        'match_word_ignore_case': 0,
-        'match_word_ignore_case_symbol': 0
-    }
-    comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
-    norm_ed_sum = 0.0
-    for pred_text, gt_text in zip(pred_texts, gt_texts):
-        if gt_text == pred_text:
-            match_res['match_word_num'] += 1
-        gt_text_lower = gt_text.lower()
-        pred_text_lower = pred_text.lower()
-        if gt_text_lower == pred_text_lower:
-            match_res['match_word_ignore_case'] += 1
-        gt_text_lower_ignore = comp.sub('', gt_text_lower)
-        pred_text_lower_ignore = comp.sub('', pred_text_lower)
-        if gt_text_lower_ignore == pred_text_lower_ignore:
-            match_res['match_word_ignore_case_symbol'] += 1
-        match_res['gt_word_num'] += 1
-
-        norm_ed_sum += Levenshtein.normalized_distance(pred_text_lower_ignore,
-                                                       gt_text_lower_ignore)
-
-        # number to calculate char level recall & precision
-        match_res['gt_char_num'] += len(gt_text_lower_ignore)
-        match_res['pred_char_num'] += len(pred_text_lower_ignore)
-        true_positive_char_num = cal_true_positive_char(
-            pred_text_lower_ignore, gt_text_lower_ignore)
-        match_res['true_positive_char_num'] += true_positive_char_num
-
-    normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts))
-    match_res['ned'] = normalized_edit_distance
-
-    return match_res
-
-
-def eval_ocr_metric(pred_texts, gt_texts, metric='acc'):
-    """Evaluate the text recognition performance with metric: word accuracy and
-    1-N.E.D. See https://rrc.cvc.uab.es/?ch=14&com=tasks for details.
-
-    Args:
-        pred_texts (list[str]): Text strings of prediction.
-        gt_texts (list[str]): Text strings of ground truth.
-        metric (str | list[str]): Metric(s) to be evaluated. Options are:
-
-        - 'word_acc': Accuracy at word level.
-        - 'word_acc_ignore_case': Accuracy at word level, ignoring letter
-          case.
-        - 'word_acc_ignore_case_symbol': Accuracy at word level, ignoring
-          letter case and symbol. (Default metric for academic evaluation)
-        - 'char_recall': Recall at character level, ignoring
-          letter case and symbol.
-        - 'char_precision': Precision at character level, ignoring
-          letter case and symbol.
-        - 'one_minus_ned': 1 - normalized_edit_distance
-
-        In particular, if ``metric == 'acc'``, results on all metrics above
-        will be reported.
-
-    Returns:
-        dict{str: float}: Result dict for text recognition, keys could be some
-        of the following: ['word_acc', 'word_acc_ignore_case',
-        'word_acc_ignore_case_symbol', 'char_recall', 'char_precision',
-        '1-N.E.D'].
-    """
-    assert isinstance(pred_texts, list)
-    assert isinstance(gt_texts, list)
-    assert len(pred_texts) == len(gt_texts)
-
-    assert isinstance(metric, str) or is_type_list(metric, str)
-    if metric == 'acc' or metric == ['acc']:
-        metric = [
-            'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
-            'char_recall', 'char_precision', 'one_minus_ned'
-        ]
-    metric = {metric} if isinstance(metric, str) else set(metric)
-
-    supported_metrics = {
-        'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
-        'char_recall', 'char_precision', 'one_minus_ned'
-    }
-    assert metric.issubset(supported_metrics)
-
-    match_res = count_matches(pred_texts, gt_texts)
-    eps = 1e-8
-    eval_res = {}
-
-    if 'char_recall' in metric:
-        char_recall = 1.0 * match_res['true_positive_char_num'] / (
-            eps + match_res['gt_char_num'])
-        eval_res['char_recall'] = char_recall
-
-    if 'char_precision' in metric:
-        char_precision = 1.0 * match_res['true_positive_char_num'] / (
-            eps + match_res['pred_char_num'])
-        eval_res['char_precision'] = char_precision
-
-    if 'word_acc' in metric:
-        word_acc = 1.0 * match_res['match_word_num'] / (
-            eps + match_res['gt_word_num'])
-        eval_res['word_acc'] = word_acc
-
-    if 'word_acc_ignore_case' in metric:
-        word_acc_ignore_case = 1.0 * match_res['match_word_ignore_case'] / (
-            eps + match_res['gt_word_num'])
-        eval_res['word_acc_ignore_case'] = word_acc_ignore_case
-
-    if 'word_acc_ignore_case_symbol' in metric:
-        word_acc_ignore_case_symbol = 1.0 * match_res[
-            'match_word_ignore_case_symbol'] / (
-                eps + match_res['gt_word_num'])
-        eval_res['word_acc_ignore_case_symbol'] = word_acc_ignore_case_symbol
-
-    if 'one_minus_ned' in metric:
-        eval_res['1-N.E.D'] = 1.0 - match_res['ned']
-
-    for key, value in eval_res.items():
-        eval_res[key] = float(f'{value:.4f}')
-
-    return eval_res
diff --git a/mmocr/metrics/__init__.py b/mmocr/metrics/__init__.py
index 2268a7d6..bc67fda7 100644
--- a/mmocr/metrics/__init__.py
+++ b/mmocr/metrics/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .hmean_iou_metric import HmeanIOUMetric
+from .recog_metric import CharMetric, OneMinusNEDMetric, WordMetric
 
-__all__ = ['HmeanIOUMetric']
+__all__ = ['WordMetric', 'CharMetric', 'OneMinusNEDMetric', 'HmeanIOUMetric']
diff --git a/mmocr/metrics/recog_metric.py b/mmocr/metrics/recog_metric.py
new file mode 100644
index 00000000..3f68f825
--- /dev/null
+++ b/mmocr/metrics/recog_metric.py
@@ -0,0 +1,292 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import re
+from difflib import SequenceMatcher
+from typing import Dict, Optional, Sequence, Union
+
+import mmcv
+from mmengine.evaluator import BaseMetric
+from rapidfuzz import string_metric
+
+from mmocr.registry import METRICS
+
+
+@METRICS.register_module()
+class WordMetric(BaseMetric):
+    """Word metrics for text recognition task.
+
+    Args:
+        mode (str or list[str]): Options are:
+            - 'exact': Accuracy at word level.
+            - 'ignore_case': Accuracy at word level, ignoring letter
+              case.
+            - 'ignore_case_symbol': Accuracy at word level, ignoring
+              letter case and symbol. (Default metric for academic evaluation)
+            If mode is a list, then metrics in mode will be calculated
+            separately. Defaults to 'ignore_case_symbol'.
+        valid_symbol (str): Valid characters.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    """
+
+    default_prefix: Optional[str] = 'recog'
+
+    def __init__(self,
+                 mode: Union[str, Sequence[str]] = 'ignore_case_symbol',
+                 valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        super().__init__(collect_device, prefix)
+        self.valid_symbol = re.compile(valid_symbol)
+        if isinstance(mode, str):
+            mode = [mode]
+        assert mmcv.is_seq_of(mode, str)
+        assert set(mode).issubset(
+            {'exact', 'ignore_case', 'ignore_case_symbol'})
+        self.mode = set(mode)
+
+    def process(self, data_batch: Sequence[Dict],
+                predictions: Sequence[Dict]) -> None:
+        """Process one batch of predictions. The processed results should be
+        stored in ``self.results``, which will be used to compute the metrics
+        when all batches have been processed.
+
+        Args:
+            data_batch (Sequence[Dict]): A batch of gts.
+            predictions (Sequence[Dict]): A batch of outputs from the model.
+        """
+        for gt, pred in zip(data_batch, predictions):
+            # record one result per sample so that ``compute_metrics`` can
+            # divide by the total number of samples rather than batches
+            match_num = 0
+            match_ignore_case_num = 0
+            match_ignore_case_symbol_num = 0
+            pred_text = pred.get('pred_text').get('item')
+            gt_text = gt.get('data_sample').get('instances')[0].get('text')
+            if 'ignore_case' in self.mode or 'ignore_case_symbol' in self.mode:
+                pred_text_lower = pred_text.lower()
+                gt_text_lower = gt_text.lower()
+            if 'ignore_case_symbol' in self.mode:
+                gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
+                pred_text_lower_ignore = self.valid_symbol.sub(
+                    '', pred_text_lower)
+                match_ignore_case_symbol_num +=\
+                    gt_text_lower_ignore == pred_text_lower_ignore
+            if 'ignore_case' in self.mode:
+                match_ignore_case_num += pred_text_lower == gt_text_lower
+            if 'exact' in self.mode:
+                match_num += pred_text == gt_text
+            results = dict(
+                match_num=match_num,
+                match_ignore_case_num=match_ignore_case_num,
+                match_ignore_case_symbol_num=match_ignore_case_symbol_num)
+            self.results.append(results)
+
+    def compute_metrics(self, results: Sequence[Dict]) -> Dict:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list[Dict]): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the metrics,
+            and the values are corresponding results.
+        """
+
+        eps = 1e-8
+        eval_res = {}
+        gt_word_num = len(results)
+        if 'exact' in self.mode:
+            match_nums = [result['match_num'] for result in results]
+            match_nums = sum(match_nums)
+            eval_res['word_acc'] = 1.0 * match_nums / (eps + gt_word_num)
+        if 'ignore_case' in self.mode:
+            match_ignore_case_num = [
+                result['match_ignore_case_num'] for result in results
+            ]
+            match_ignore_case_num = sum(match_ignore_case_num)
+            eval_res['word_acc_ignore_case'] = 1.0 *\
+                match_ignore_case_num / (eps + gt_word_num)
+        if 'ignore_case_symbol' in self.mode:
+            match_ignore_case_symbol_num = [
+                result['match_ignore_case_symbol_num'] for result in results
+            ]
+            match_ignore_case_symbol_num = sum(match_ignore_case_symbol_num)
+            eval_res['word_acc_ignore_case_symbol'] = 1.0 *\
+                match_ignore_case_symbol_num / (eps + gt_word_num)
+
+        for key, value in eval_res.items():
+            eval_res[key] = float(f'{value:.4f}')
+        return eval_res
+
+
+@METRICS.register_module()
+class CharMetric(BaseMetric):
+    """Character metrics for text recognition task.
+
+    Args:
+        valid_symbol (str): Valid characters.
+            Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    """
+
+    default_prefix: Optional[str] = 'recog'
+
+    def __init__(self,
+                 valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        super().__init__(collect_device, prefix)
+        self.valid_symbol = re.compile(valid_symbol)
+
+    def process(self, data_batch: Sequence[Dict],
+                predictions: Sequence[Dict]) -> None:
+        """Process one batch of predictions. The processed results should be
+        stored in ``self.results``, which will be used to compute the metrics
+        when all batches have been processed.
+
+        Args:
+            data_batch (Sequence[Dict]): A batch of gts.
+            predictions (Sequence[Dict]): A batch of outputs from the model.
+        """
+        for gt, pred in zip(data_batch, predictions):
+            pred_text = pred.get('pred_text').get('item')
+            gt_text = gt.get('data_sample').get('instances')[0].get('text')
+            gt_text_lower = gt_text.lower()
+            pred_text_lower = pred_text.lower()
+            gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
+            pred_text_lower_ignore = self.valid_symbol.sub('', pred_text_lower)
+            # number to calculate char level recall & precision
+            result = dict(
+                gt_char_num=len(gt_text_lower_ignore),
+                pred_char_num=len(pred_text_lower_ignore),
+                true_positive_char_num=self._cal_true_positive_char(
+                    pred_text_lower_ignore, gt_text_lower_ignore))
+            self.results.append(result)
+
+    def compute_metrics(self, results: Sequence[Dict]) -> Dict:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list[Dict]): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the
+            metrics, and the values are corresponding results.
+        """
+        gt_char_num = [result['gt_char_num'] for result in results]
+        pred_char_num = [result['pred_char_num'] for result in results]
+        true_positive_char_num = [
+            result['true_positive_char_num'] for result in results
+        ]
+        gt_char_num = sum(gt_char_num)
+        pred_char_num = sum(pred_char_num)
+        true_positive_char_num = sum(true_positive_char_num)
+
+        eps = 1e-8
+        char_recall = 1.0 * true_positive_char_num / (eps + gt_char_num)
+        char_precision = 1.0 * true_positive_char_num / (eps + pred_char_num)
+        eval_res = {}
+        eval_res['char_recall'] = char_recall
+        eval_res['char_precision'] = char_precision
+
+        for key, value in eval_res.items():
+            eval_res[key] = float(f'{value:.4f}')
+        return eval_res
+
+    def _cal_true_positive_char(self, pred: str, gt: str) -> int:
+        """Calculate correct character number in prediction.
+
+        Args:
+            pred (str): Prediction text.
+            gt (str): Ground truth text.
+
+        Returns:
+            true_positive_char_num (int): The true positive number.
+        """
+
+        all_opt = SequenceMatcher(None, pred, gt)
+        true_positive_char_num = 0
+        for opt, _, _, s2, e2 in all_opt.get_opcodes():
+            if opt == 'equal':
+                true_positive_char_num += (e2 - s2)
+            else:
+                pass
+        return true_positive_char_num
+
+
+@METRICS.register_module()
+class OneMinusNEDMetric(BaseMetric):
+    """One minus NED metric for text recognition task.
+
+    Args:
+        valid_symbol (str): Valid characters
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None
+    """
+    default_prefix: Optional[str] = 'recog'
+
+    def __init__(self,
+                 valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        super().__init__(collect_device, prefix)
+        self.valid_symbol = re.compile(valid_symbol)
+
+    def process(self, data_batch: Sequence[Dict],
+                predictions: Sequence[Dict]) -> None:
+        """Process one batch of predictions. The processed results should be
+        stored in ``self.results``, which will be used to compute the metrics
+        when all batches have been processed.
+
+        Args:
+            data_batch (Sequence[Dict]): A batch of gts.
+            predictions (Sequence[Dict]): A batch of outputs from the model.
+        """
+        for gt, pred in zip(data_batch, predictions):
+            pred_text = pred.get('pred_text').get('item')
+            gt_text = gt.get('data_sample').get('instances')[0].get('text')
+            gt_text_lower = gt_text.lower()
+            pred_text_lower = pred_text.lower()
+            gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
+            pred_text_lower_ignore = self.valid_symbol.sub('', pred_text_lower)
+            edit_dist = string_metric.levenshtein(pred_text_lower_ignore,
+                                                  gt_text_lower_ignore)
+            norm_ed = float(edit_dist) / max(1, len(gt_text_lower_ignore),
+                                             len(pred_text_lower_ignore))
+            result = dict(norm_ed=norm_ed)
+            self.results.append(result)
+
+    def compute_metrics(self, results: Sequence[Dict]) -> Dict:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list[Dict]): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the
+            metrics, and the values are corresponding results.
+        """
+
+        gt_word_num = len(results)
+        norm_ed = [result['norm_ed'] for result in results]
+        norm_ed_sum = sum(norm_ed)
+        normalized_edit_distance = norm_ed_sum / max(1, gt_word_num)
+        eval_res = {}
+        eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance
+        for key, value in eval_res.items():
+            eval_res[key] = float(f'{value:.4f}')
+        return eval_res
diff --git a/tests/test_metrics/test_recog_metric.py b/tests/test_metrics/test_recog_metric.py
new file mode 100644
index 00000000..b430e9d2
--- /dev/null
+++ b/tests/test_metrics/test_recog_metric.py
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import unittest
+
+from mmengine.data import LabelData
+
+from mmocr.core import TextRecogDataSample
+from mmocr.metrics import CharMetric, OneMinusNEDMetric, WordMetric
+
+
+class TestWordMetric(unittest.TestCase):
+
+    def setUp(self):
+        # prepare gt: hello, HELLO, $HELLO$
+        gt1 = {
+            'data_sample': {
+                'height': 32,
+                'width': 100,
+                'instances': [{
+                    'text': 'hello'
+                }]
+            }
+        }
+        gt2 = {
+            'data_sample': {
+                'height': 32,
+                'width': 100,
+                'instances': [{
+                    'text': 'HELLO'
+                }]
+            }
+        }
+        gt3 = {
+            'data_sample': {
+                'height': 32,
+                'width': 100,
+                'instances': [{
+                    'text': '$HELLO$'
+                }]
+            }
+        }
+        self.gt = [gt1, gt2, gt3]
+        # prepare pred
+        pred_data_sample = TextRecogDataSample()
+        pred_text = LabelData()
+        pred_text.item = 'hello'
+        pred_data_sample.pred_text = pred_text
+
+        self.pred = [
+            pred_data_sample,
+            copy.deepcopy(pred_data_sample),
+            copy.deepcopy(pred_data_sample),
+        ]
+
+    def test_word_acc_metric(self):
+        metric = WordMetric(mode='exact')
+        metric.process(self.gt, self.pred)
+        eval_res = metric.evaluate(size=3)
+        self.assertAlmostEqual(eval_res['recog/word_acc'], 1. / 3, 4)
+
+    def test_word_acc_ignore_case_metric(self):
+        metric = WordMetric(mode='ignore_case')
+        metric.process(self.gt, self.pred)
+        eval_res = metric.evaluate(size=3)
+        self.assertAlmostEqual(eval_res['recog/word_acc_ignore_case'],
+                               2. / 3, 4)
+
+    def test_word_acc_ignore_case_symbol_metric(self):
+        metric = WordMetric(mode='ignore_case_symbol')
+        metric.process(self.gt, self.pred)
+        eval_res = metric.evaluate(size=3)
+        self.assertAlmostEqual(
+            eval_res['recog/word_acc_ignore_case_symbol'], 1.0)
+
+    def test_all_metric(self):
+        metric = WordMetric(
+            mode=['exact', 'ignore_case', 'ignore_case_symbol'])
+        metric.process(self.gt, self.pred)
+        eval_res = metric.evaluate(size=3)
+        self.assertAlmostEqual(eval_res['recog/word_acc'], 1. / 3, 4)
+        self.assertAlmostEqual(eval_res['recog/word_acc_ignore_case'],
+                               2. / 3, 4)
+        self.assertAlmostEqual(
+            eval_res['recog/word_acc_ignore_case_symbol'], 1.0)
+
+
+class TestCharMetric(unittest.TestCase):
+
+    def setUp(self):
+        # prepare gt
+        gt1 = {
+            'data_sample': {
+                'height': 32,
+                'width': 100,
+                'instances': [{
+                    'text': 'hello'
+                }]
+            }
+        }
+        gt2 = {
+            'data_sample': {
+                'height': 32,
+                'width': 100,
+                'instances': [{
+                    'text': 'HELLO'
+                }]
+            }
+        }
+        self.gt = [gt1, gt2]
+        # prepare pred
+        pred_data_sample1 = TextRecogDataSample()
+        pred_text = LabelData()
+        pred_text.item = 'helL'
+        pred_data_sample1.pred_text = pred_text
+
+        pred_data_sample2 = TextRecogDataSample()
+        pred_text = LabelData()
+        pred_text.item = 'HEL'
+        pred_data_sample2.pred_text = pred_text
+
+        self.pred = [pred_data_sample1, pred_data_sample2]
+
+    def test_char_recall_precision_metric(self):
+        metric = CharMetric()
+        metric.process(self.gt, self.pred)
+        eval_res = metric.evaluate(size=2)
+        # 7 of the 10 gt characters are matched -> recall 0.7; all 7
+        # predicted characters are correct -> precision 1.0
+        self.assertAlmostEqual(eval_res['recog/char_recall'], 0.7)
+        self.assertAlmostEqual(eval_res['recog/char_precision'], 1.0)
+
+
+class TestOneMinusNED(unittest.TestCase):
+
+    def setUp(self):
+        # prepare gt
+        gt1 = {
+            'data_sample': {
+                'height': 32,
+                'width': 100,
+                'instances': [{
+                    'text': 'hello'
+                }]
+            }
+        }
+        gt2 = {
+            'data_sample': {
+                'height': 32,
+                'width': 100,
+                'instances': [{
+                    'text': 'HELLO'
+                }]
+            }
+        }
+        self.gt = [gt1, gt2]
+        # prepare pred
+        pred_data_sample1 = TextRecogDataSample()
+        pred_text = LabelData()
+        pred_text.item = 'pred_helL'
+        pred_data_sample1.pred_text = pred_text
+
+        pred_data_sample2 = TextRecogDataSample()
+        pred_text = LabelData()
+        pred_text.item = 'HEL'
+        pred_data_sample2.pred_text = pred_text
+
+        self.pred = [pred_data_sample1, pred_data_sample2]
+
+    def test_one_minus_ned_metric(self):
+        metric = OneMinusNEDMetric()
+        metric.process(self.gt, self.pred)
+        eval_res = metric.evaluate(size=2)
+        # normalized edit distances are 5/8 and 2/5 -> 1 - 0.5125 = 0.4875
+        self.assertAlmostEqual(eval_res['recog/1-N.E.D'], 0.4875)
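
Usage sketch (not part of the patch): the snippet below is a minimal, illustrative example of how the three new metrics could be driven by hand, mirroring the gt/pred layout used in tests/test_metrics/test_recog_metric.py, and how they might be referenced from a config through the METRICS registry. The `val_evaluator` variable and the trimmed-down gt dict are assumptions for illustration, not APIs introduced by this diff.

# Illustrative sketch only; assumes this patch is installed. The gt layout
# follows the unit tests above, and `val_evaluator` is a hypothetical
# config fragment rather than something added by this change.
from mmengine.data import LabelData

from mmocr.core import TextRecogDataSample
from mmocr.metrics import WordMetric

# Config-style usage: the registered metric names can be combined in one
# evaluator (CharMetric and OneMinusNEDMetric are resolved by name).
val_evaluator = [
    dict(
        type='WordMetric',
        mode=['exact', 'ignore_case', 'ignore_case_symbol']),
    dict(type='CharMetric'),
    dict(type='OneMinusNEDMetric'),
]

# Standalone usage, mirroring the unit tests in this diff.
gt = [{'data_sample': {'instances': [{'text': 'hello'}]}}]
pred_sample = TextRecogDataSample()
pred_text = LabelData()
pred_text.item = 'hello'
pred_sample.pred_text = pred_text

metric = WordMetric(mode='ignore_case_symbol')
metric.process(gt, [pred_sample])
# Expected output: {'recog/word_acc_ignore_case_symbol': 1.0}
print(metric.evaluate(size=1))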