[Refactor] Refactor WordMetric and CharMetric

pull/1178/head
jiangqing.vendor 2022-05-26 15:42:16 +00:00 committed by gaotongxiao
parent 4c9d14a6e7
commit f173cd3543
5 changed files with 462 additions and 171 deletions


@@ -3,9 +3,5 @@ from .hmean import eval_hmean
from .hmean_ic13 import eval_hmean_ic13
from .kie_metric import compute_f1_score
from .ner_metric import eval_ner_f1
from .ocr_metric import eval_ocr_metric
__all__ = [
'eval_hmean_ic13', 'eval_ocr_metric', 'eval_hmean', 'compute_f1_score',
'eval_ner_f1'
]
__all__ = ['eval_hmean_ic13', 'eval_hmean', 'compute_f1_score', 'eval_ner_f1']


@@ -1,165 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from difflib import SequenceMatcher
from rapidfuzz.distance import Levenshtein
from mmocr.utils import is_type_list
def cal_true_positive_char(pred, gt):
"""Calculate correct character number in prediction.
Args:
pred (str): Prediction text.
gt (str): Ground truth text.
Returns:
true_positive_char_num (int): The true positive number.
"""
all_opt = SequenceMatcher(None, pred, gt)
true_positive_char_num = 0
for opt, _, _, s2, e2 in all_opt.get_opcodes():
if opt == 'equal':
true_positive_char_num += (e2 - s2)
else:
pass
return true_positive_char_num
def count_matches(pred_texts, gt_texts):
"""Count the various match number for metric calculation.
Args:
pred_texts (list[str]): Predicted text string.
gt_texts (list[str]): Ground truth text string.
Returns:
match_res (dict[str: int]): Match number used for
metric calculation.
"""
match_res = {
'gt_char_num': 0,
'pred_char_num': 0,
'true_positive_char_num': 0,
'gt_word_num': 0,
'match_word_num': 0,
'match_word_ignore_case': 0,
'match_word_ignore_case_symbol': 0
}
comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
norm_ed_sum = 0.0
for pred_text, gt_text in zip(pred_texts, gt_texts):
if gt_text == pred_text:
match_res['match_word_num'] += 1
gt_text_lower = gt_text.lower()
pred_text_lower = pred_text.lower()
if gt_text_lower == pred_text_lower:
match_res['match_word_ignore_case'] += 1
gt_text_lower_ignore = comp.sub('', gt_text_lower)
pred_text_lower_ignore = comp.sub('', pred_text_lower)
if gt_text_lower_ignore == pred_text_lower_ignore:
match_res['match_word_ignore_case_symbol'] += 1
match_res['gt_word_num'] += 1
norm_ed_sum += Levenshtein.normalized_distance(pred_text_lower_ignore,
gt_text_lower_ignore)
# number to calculate char level recall & precision
match_res['gt_char_num'] += len(gt_text_lower_ignore)
match_res['pred_char_num'] += len(pred_text_lower_ignore)
true_positive_char_num = cal_true_positive_char(
pred_text_lower_ignore, gt_text_lower_ignore)
match_res['true_positive_char_num'] += true_positive_char_num
normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts))
match_res['ned'] = normalized_edit_distance
return match_res
def eval_ocr_metric(pred_texts, gt_texts, metric='acc'):
"""Evaluate the text recognition performance with metric: word accuracy and
1-N.E.D. See https://rrc.cvc.uab.es/?ch=14&com=tasks for details.
Args:
pred_texts (list[str]): Text strings of prediction.
gt_texts (list[str]): Text strings of ground truth.
metric (str | list[str]): Metric(s) to be evaluated. Options are:
- 'word_acc': Accuracy at word level.
- 'word_acc_ignore_case': Accuracy at word level, ignoring letter
case.
- 'word_acc_ignore_case_symbol': Accuracy at word level, ignoring
letter case and symbol. (Default metric for academic evaluation)
- 'char_recall': Recall at character level, ignoring
letter case and symbol.
- 'char_precision': Precision at character level, ignoring
letter case and symbol.
- 'one_minus_ned': 1 - normalized_edit_distance
In particular, if ``metric == 'acc'``, results on all metrics above
will be reported.
Returns:
dict{str: float}: Result dict for text recognition, keys could be some
of the following: ['word_acc', 'word_acc_ignore_case',
'word_acc_ignore_case_symbol', 'char_recall', 'char_precision',
'1-N.E.D'].
"""
assert isinstance(pred_texts, list)
assert isinstance(gt_texts, list)
assert len(pred_texts) == len(gt_texts)
assert isinstance(metric, str) or is_type_list(metric, str)
if metric == 'acc' or metric == ['acc']:
metric = [
'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
'char_recall', 'char_precision', 'one_minus_ned'
]
metric = {metric} if isinstance(metric, str) else set(metric)
supported_metrics = {
'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
'char_recall', 'char_precision', 'one_minus_ned'
}
assert metric.issubset(supported_metrics)
match_res = count_matches(pred_texts, gt_texts)
eps = 1e-8
eval_res = {}
if 'char_recall' in metric:
char_recall = 1.0 * match_res['true_positive_char_num'] / (
eps + match_res['gt_char_num'])
eval_res['char_recall'] = char_recall
if 'char_precision' in metric:
char_precision = 1.0 * match_res['true_positive_char_num'] / (
eps + match_res['pred_char_num'])
eval_res['char_precision'] = char_precision
if 'word_acc' in metric:
word_acc = 1.0 * match_res['match_word_num'] / (
eps + match_res['gt_word_num'])
eval_res['word_acc'] = word_acc
if 'word_acc_ignore_case' in metric:
word_acc_ignore_case = 1.0 * match_res['match_word_ignore_case'] / (
eps + match_res['gt_word_num'])
eval_res['word_acc_ignore_case'] = word_acc_ignore_case
if 'word_acc_ignore_case_symbol' in metric:
word_acc_ignore_case_symbol = 1.0 * match_res[
'match_word_ignore_case_symbol'] / (
eps + match_res['gt_word_num'])
eval_res['word_acc_ignore_case_symbol'] = word_acc_ignore_case_symbol
if 'one_minus_ned' in metric:
eval_res['1-N.E.D'] = 1.0 - match_res['ned']
for key, value in eval_res.items():
eval_res[key] = float(f'{value:.4f}')
return eval_res
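
For context on what the refactor replaces: the deleted module exposed a single entry point, eval_ocr_metric, which took whole lists of predicted and ground-truth strings and returned one result dict. A minimal sketch of a call against the pre-refactor API (assuming the old mmocr.core.evaluation export removed in the first hunk; the sample strings and printed values are illustrative, not taken from the repo):

from mmocr.core.evaluation import eval_ocr_metric  # pre-refactor export (removed above)

pred_texts = ['hello', 'world!']
gt_texts = ['Hello', 'world']

# metric='acc' expands to all six supported metrics; a single name or a list also works.
print(eval_ocr_metric(pred_texts, gt_texts, metric='acc'))
# -> {'char_recall': 1.0, 'char_precision': 1.0, 'word_acc': 0.0,
#     'word_acc_ignore_case': 0.5, 'word_acc_ignore_case_symbol': 1.0, '1-N.E.D': 1.0}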


@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hmean_iou_metric import HmeanIOUMetric
from .recog_metric import CharMetric, OneMinusNEDMetric, WordMetric
__all__ = ['HmeanIOUMetric']
__all__ = ['WordMetric', 'CharMetric', 'OneMinusNEDMetric', 'HmeanIOUMetric']


@@ -0,0 +1,292 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from difflib import SequenceMatcher
from typing import Dict, Optional, Sequence, Union
import mmcv
from mmengine.evaluator import BaseMetric
from rapidfuzz import string_metric
from mmocr.registry import METRICS
@METRICS.register_module()
class WordMetric(BaseMetric):
"""Word metrics for text recognition task.
Args:
mode (str or list[str]): Options are:
- 'exact': Accuracy at word level.
- 'ignore_case': Accuracy at word level, ignoring letter
case.
- 'ignore_case_symbol': Accuracy at word level, ignoring
letter case and symbol. (Default metric for academic evaluation)
If mode is a list, then metrics in mode will be calculated
separately. Defaults to 'ignore_case_symbol'.
valid_symbol (str): Valid characters. Defaults to
'[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
"""
default_prefix: Optional[str] = 'recog'
def __init__(self,
mode: Union[str, Sequence[str]] = 'ignore_case_symbol',
valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device, prefix)
self.valid_symbol = re.compile(valid_symbol)
if isinstance(mode, str):
mode = [mode]
assert mmcv.is_seq_of(mode, str)
assert set(mode).issubset(
{'exact', 'ignore_case', 'ignore_case_symbol'})
self.mode = set(mode)
def process(self, data_batch: Sequence[Dict],
predictions: Sequence[Dict]) -> None:
"""Process one batch of predictions. The processed results should be
stored in ``self.results``, which will be used to compute the metrics
when all batches have been processed.
Args:
data_batch (Sequence[Dict]): A batch of gts.
predictions (Sequence[Dict]): A batch of outputs from the model.
"""
match_num = 0
match_ignore_case_num = 0
match_ignore_case_symbol_num = 0
for gt, pred in zip(data_batch, predictions):
pred_text = pred.get('pred_text').get('item')
gt_text = gt.get('data_sample').get('instances')[0].get('text')
if 'ignore_case' in self.mode or 'ignore_case_symbol' in self.mode:
pred_text_lower = pred_text.lower()
gt_text_lower = gt_text.lower()
if 'ignore_case_symbol' in self.mode:
gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
pred_text_lower_ignore = self.valid_symbol.sub(
'', pred_text_lower)
match_ignore_case_symbol_num +=\
gt_text_lower_ignore == pred_text_lower_ignore
if 'ignore_case' in self.mode:
match_ignore_case_num += pred_text_lower == gt_text_lower
if 'exact' in self.mode:
match_num += pred_text == gt_text
results = dict(
match_num=match_num,
match_ignore_case_num=match_ignore_case_num,
match_ignore_case_symbol_num=match_ignore_case_symbol_num)
self.results.append(results)
def compute_metrics(self, results: Sequence[Dict]) -> Dict:
"""Compute the metrics from processed results.
Args:
results (list[Dict]): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
eps = 1e-8
eval_res = {}
gt_word_num = len(results)
if 'exact' in self.mode:
match_nums = [result['match_num'] for result in results]
match_nums = sum(match_nums)
eval_res['word_acc'] = 1.0 * match_nums / (eps + gt_word_num)
if 'ignore_case' in self.mode:
match_ignore_case_num = [
result['match_ignore_case_num'] for result in results
]
match_ignore_case_num = sum(match_ignore_case_num)
eval_res['word_acc_ignore_case'] = 1.0 *\
match_ignore_case_num / (eps + gt_word_num)
if 'ignore_case_symbol' in self.mode:
match_ignore_case_symbol_num = [
result['match_ignore_case_symbol_num'] for result in results
]
match_ignore_case_symbol_num = sum(match_ignore_case_symbol_num)
eval_res['word_acc_ignore_case_symbol'] = 1.0 *\
match_ignore_case_symbol_num / (eps + gt_word_num)
for key, value in eval_res.items():
eval_res[key] = float(f'{value:.4f}')
return eval_res
@METRICS.register_module()
class CharMetric(BaseMetric):
"""Character metrics for text recognition task.
Args:
valid_symbol (str): Valid characters.
Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
"""
default_prefix: Optional[str] = 'recog'
def __init__(self,
valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device, prefix)
self.valid_symbol = re.compile(valid_symbol)
def process(self, data_batch: Sequence[Dict],
predictions: Sequence[Dict]) -> None:
"""Process one batch of predictions. The processed results should be
stored in ``self.results``, which will be used to compute the metrics
when all batches have been processed.
Args:
data_batch (Sequence[Dict]): A batch of gts.
predictions (Sequence[Dict]): A batch of outputs from the model.
"""
for gt, pred in zip(data_batch, predictions):
pred_text = pred.get('pred_text').get('item')
gt_text = gt.get('data_sample').get('instances')[0].get('text')
gt_text_lower = gt_text.lower()
pred_text_lower = pred_text.lower()
gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
pred_text_lower_ignore = self.valid_symbol.sub('', pred_text_lower)
# number to calculate char level recall & precision
result = dict(
gt_char_num=len(gt_text_lower_ignore),
pred_char_num=len(pred_text_lower_ignore),
true_positive_char_num=self._cal_true_positive_char(
pred_text_lower_ignore, gt_text_lower_ignore))
self.results.append(result)
def compute_metrics(self, results: Sequence[Dict]) -> Dict:
"""Compute the metrics from processed results.
Args:
results (list[Dict]): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the
metrics, and the values are corresponding results.
"""
gt_char_num = [result['gt_char_num'] for result in results]
pred_char_num = [result['pred_char_num'] for result in results]
true_positive_char_num = [
result['true_positive_char_num'] for result in results
]
gt_char_num = sum(gt_char_num)
pred_char_num = sum(pred_char_num)
true_positive_char_num = sum(true_positive_char_num)
eps = 1e-8
char_recall = 1.0 * true_positive_char_num / (eps + gt_char_num)
char_precision = 1.0 * true_positive_char_num / (eps + pred_char_num)
eval_res = {}
eval_res['char_recall'] = char_recall
eval_res['char_precision'] = char_precision
for key, value in eval_res.items():
eval_res[key] = float(f'{value:.4f}')
return eval_res
def _cal_true_positive_char(self, pred: str, gt: str) -> int:
"""Calculate correct character number in prediction.
Args:
pred (str): Prediction text.
gt (str): Ground truth text.
Returns:
true_positive_char_num (int): The true positive number.
"""
all_opt = SequenceMatcher(None, pred, gt)
true_positive_char_num = 0
for opt, _, _, s2, e2 in all_opt.get_opcodes():
if opt == 'equal':
true_positive_char_num += (e2 - s2)
else:
pass
return true_positive_char_num
@METRICS.register_module()
class OneMinusNEDMetric(BaseMetric):
"""One minus NED metric for text recognition task.
Args:
valid_symbol (str): Valid characters. Defaults to
'[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
"""
default_prefix: Optional[str] = 'recog'
def __init__(self,
valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device, prefix)
self.valid_symbol = re.compile(valid_symbol)
def process(self, data_batch: Sequence[Dict],
predictions: Sequence[Dict]) -> None:
"""Process one batch of predictions. The processed results should be
stored in ``self.results``, which will be used to compute the metrics
when all batches have been processed.
Args:
data_batch (Sequence[Dict]): A batch of gts.
predictions (Sequence[Dict]): A batch of outputs from the model.
"""
for gt, pred in zip(data_batch, predictions):
pred_text = pred.get('pred_text').get('item')
gt_text = gt.get('data_sample').get('instances')[0].get('text')
gt_text_lower = gt_text.lower()
pred_text_lower = pred_text.lower()
gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
pred_text_lower_ignore = self.valid_symbol.sub('', pred_text_lower)
edit_dist = string_metric.levenshtein(pred_text_lower_ignore,
gt_text_lower_ignore)
norm_ed = float(edit_dist) / max(1, len(gt_text_lower_ignore),
len(pred_text_lower_ignore))
result = dict(norm_ed=norm_ed)
self.results.append(result)
def compute_metrics(self, results: Sequence[Dict]) -> Dict:
"""Compute the metrics from processed results.
Args:
results (list[Dict]): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the
metrics, and the values are corresponding results.
"""
gt_word_num = len(results)
norm_ed = [result['norm_ed'] for result in results]
norm_ed_sum = sum(norm_ed)
normalized_edit_distance = norm_ed_sum / max(1, gt_word_num)
eval_res = {}
eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance
for key, value in eval_res.items():
eval_res[key] = float(f'{value:.4f}')
return eval_res
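
The three new classes follow the MMEngine BaseMetric protocol: process() consumes ground-truth dicts paired with prediction data samples and appends per-sample counts to self.results, and evaluate() aggregates them through compute_metrics(). A minimal sketch of driving them by hand, mirroring the unit tests below (the sample texts are illustrative):

from mmengine.data import LabelData

from mmocr.core import TextRecogDataSample
from mmocr.metrics import CharMetric, OneMinusNEDMetric, WordMetric

# Ground truth: only instances[0]['text'] is read by these metrics.
gt = [{'data_sample': {'instances': [{'text': 'hello'}]}}]

# Prediction: the recognised string lives in pred_text.item.
pred_text = LabelData()
pred_text.item = 'Hello!'
pred_sample = TextRecogDataSample()
pred_sample.pred_text = pred_text
pred = [pred_sample]

for metric in (WordMetric(mode=['exact', 'ignore_case', 'ignore_case_symbol']),
               CharMetric(), OneMinusNEDMetric()):
    metric.process(gt, pred)
    # Result keys carry the default 'recog' prefix, e.g. 'recog/word_acc'.
    print(metric.evaluate(size=len(gt)))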


@@ -0,0 +1,167 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import unittest
from mmengine.data import LabelData
from mmocr.core import TextRecogDataSample
from mmocr.metrics import CharMetric, OneMinusNEDMetric, WordMetric
class TestWordMetric(unittest.TestCase):
def setUp(self):
# prepare gt texts: 'hello', 'HELLO', '$HELLO$'
gt1 = {
'data_sample': {
'height': 32,
'width': 100,
'instances': [{
'text': 'hello'
}]
}
}
gt2 = {
'data_sample': {
'height': 32,
'width': 100,
'instances': [{
'text': 'HELLO'
}]
}
}
gt3 = {
'data_sample': {
'height': 32,
'width': 100,
'instances': [{
'text': '$HELLO$'
}]
}
}
self.gt = [gt1, gt2, gt3]
# prepare pred
pred_data_sample = TextRecogDataSample()
pred_text = LabelData()
pred_text.item = 'hello'
pred_data_sample.pred_text = pred_text
self.pred = [
pred_data_sample,
copy.deepcopy(pred_data_sample),
copy.deepcopy(pred_data_sample),
]
def test_word_acc_metric(self):
metric = WordMetric(mode='exact')
metric.process(self.gt, self.pred)
eval_res = metric.evaluate(size=3)
self.assertAlmostEqual(eval_res['recog/word_acc'], 1. / 3, places=4)
def test_word_acc_ignore_case_metric(self):
metric = WordMetric(mode='ignore_case')
metric.process(self.gt, self.pred)
eval_res = metric.evaluate(size=3)
self.assertAlmostEqual(eval_res['recog/word_acc_ignore_case'], 2. / 3, places=4)
def test_word_acc_ignore_case_symbol_metric(self):
metric = WordMetric(mode='ignore_case_symbol')
metric.process(self.gt, self.pred)
eval_res = metric.evaluate(size=3)
self.assertAlmostEqual(eval_res['recog/word_acc_ignore_case_symbol'], 1.0, places=4)
def test_all_metric(self):
metric = WordMetric(
mode=['exact', 'ignore_case', 'ignore_case_symbol'])
metric.process(self.gt, self.pred)
eval_res = metric.evaluate(size=3)
self.assertAlmostEqual(eval_res['recog/word_acc'], 1. / 3, places=4)
self.assertAlmostEqual(eval_res['recog/word_acc_ignore_case'], 2. / 3, places=4)
self.assertAlmostEqual(eval_res['recog/word_acc_ignore_case_symbol'], 1.0, places=4)
class TestCharMetric(unittest.TestCase):
def setUp(self):
# prepare gt
gt1 = {
'data_sample': {
'height': 32,
'width': 100,
'instances': [{
'text': 'hello'
}]
}
}
gt2 = {
'data_sample': {
'height': 32,
'width': 100,
'instances': [{
'text': 'HELLO'
}]
}
}
self.gt = [gt1, gt2]
# prepare pred
pred_data_sample1 = TextRecogDataSample()
pred_text = LabelData()
pred_text.item = 'helL'
pred_data_sample1.pred_text = pred_text
pred_data_sample2 = TextRecogDataSample()
pred_text = LabelData()
pred_text.item = 'HEL'
pred_data_sample2.pred_text = pred_text
self.pred = [pred_data_sample1, pred_data_sample2]
def test_char_recall_precision_metric(self):
metric = CharMetric()
metric.process(self.gt, self.pred)
eval_res = metric.evaluate(size=2)
# 7 true-positive chars over 10 GT chars and 7 predicted chars
self.assertAlmostEqual(eval_res['recog/char_recall'], 0.7, places=4)
self.assertAlmostEqual(eval_res['recog/char_precision'], 1.0, places=4)
class TestOneMinusNED(unittest.TestCase):
def setUp(self):
# prepare gt
gt1 = {
'data_sample': {
'height': 32,
'width': 100,
'instances': [{
'text': 'hello'
}]
}
}
gt2 = {
'data_sample': {
'height': 32,
'width': 100,
'instances': [{
'text': 'HELLO'
}]
}
}
self.gt = [gt1, gt2]
# prepare pred
pred_data_sample1 = TextRecogDataSample()
pred_text = LabelData()
pred_text.item = 'pred_helL'
pred_data_sample1.pred_text = pred_text
pred_data_sample2 = TextRecogDataSample()
pred_text = LabelData()
pred_text.item = 'HEL'
pred_data_sample2.pred_text = pred_text
self.pred = [pred_data_sample1, pred_data_sample2]
def test_one_minus_ned_metric(self):
metric = OneMinusNEDMetric()
metric.process(self.gt, self.pred)
eval_res = metric.evaluate(size=2)
# norm_ed = (5/8 + 2/5) / 2 = 0.5125, so 1 - N.E.D is 0.4875
self.assertAlmostEqual(eval_res['recog/1-N.E.D'], 0.4875, places=4)
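
Since all three classes are registered in METRICS, they can also be instantiated from a config by type name once a downstream model config wires them in as its evaluator. A hypothetical fragment (the val_evaluator/test_evaluator field names follow the usual MMEngine convention and are not part of this diff):

# Hypothetical config fragment; arguments mirror the defaults documented above.
val_evaluator = [
    dict(type='WordMetric', mode=['exact', 'ignore_case', 'ignore_case_symbol']),
    dict(type='CharMetric'),
    dict(type='OneMinusNEDMetric'),
]
test_evaluator = val_evaluator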