mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] Refactor WordMetric and CharMetric
parent 4c9d14a6e7
commit f173cd3543
@@ -3,9 +3,5 @@ from .hmean import eval_hmean
 from .hmean_ic13 import eval_hmean_ic13
 from .kie_metric import compute_f1_score
 from .ner_metric import eval_ner_f1
-from .ocr_metric import eval_ocr_metric

-__all__ = [
-    'eval_hmean_ic13', 'eval_ocr_metric', 'eval_hmean', 'compute_f1_score',
-    'eval_ner_f1'
-]
+__all__ = ['eval_hmean_ic13', 'eval_hmean', 'compute_f1_score', 'eval_ner_f1']
@@ -1,165 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from difflib import SequenceMatcher

from rapidfuzz.distance import Levenshtein

from mmocr.utils import is_type_list


def cal_true_positive_char(pred, gt):
    """Calculate correct character number in prediction.

    Args:
        pred (str): Prediction text.
        gt (str): Ground truth text.

    Returns:
        true_positive_char_num (int): The true positive number.
    """

    all_opt = SequenceMatcher(None, pred, gt)
    true_positive_char_num = 0
    for opt, _, _, s2, e2 in all_opt.get_opcodes():
        if opt == 'equal':
            true_positive_char_num += (e2 - s2)
        else:
            pass
    return true_positive_char_num


def count_matches(pred_texts, gt_texts):
    """Count the various match number for metric calculation.

    Args:
        pred_texts (list[str]): Predicted text string.
        gt_texts (list[str]): Ground truth text string.

    Returns:
        match_res (dict[str: int]): Match number used for
            metric calculation.
    """
    match_res = {
        'gt_char_num': 0,
        'pred_char_num': 0,
        'true_positive_char_num': 0,
        'gt_word_num': 0,
        'match_word_num': 0,
        'match_word_ignore_case': 0,
        'match_word_ignore_case_symbol': 0
    }
    comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
    norm_ed_sum = 0.0
    for pred_text, gt_text in zip(pred_texts, gt_texts):
        if gt_text == pred_text:
            match_res['match_word_num'] += 1
        gt_text_lower = gt_text.lower()
        pred_text_lower = pred_text.lower()
        if gt_text_lower == pred_text_lower:
            match_res['match_word_ignore_case'] += 1
        gt_text_lower_ignore = comp.sub('', gt_text_lower)
        pred_text_lower_ignore = comp.sub('', pred_text_lower)
        if gt_text_lower_ignore == pred_text_lower_ignore:
            match_res['match_word_ignore_case_symbol'] += 1
        match_res['gt_word_num'] += 1

        norm_ed_sum += Levenshtein.normalized_distance(pred_text_lower_ignore,
                                                       gt_text_lower_ignore)

        # number to calculate char level recall & precision
        match_res['gt_char_num'] += len(gt_text_lower_ignore)
        match_res['pred_char_num'] += len(pred_text_lower_ignore)
        true_positive_char_num = cal_true_positive_char(
            pred_text_lower_ignore, gt_text_lower_ignore)
        match_res['true_positive_char_num'] += true_positive_char_num

    normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts))
    match_res['ned'] = normalized_edit_distance

    return match_res


def eval_ocr_metric(pred_texts, gt_texts, metric='acc'):
    """Evaluate the text recognition performance with metric: word accuracy
    and 1-N.E.D. See https://rrc.cvc.uab.es/?ch=14&com=tasks for details.

    Args:
        pred_texts (list[str]): Text strings of prediction.
        gt_texts (list[str]): Text strings of ground truth.
        metric (str | list[str]): Metric(s) to be evaluated. Options are:

            - 'word_acc': Accuracy at word level.
            - 'word_acc_ignore_case': Accuracy at word level, ignoring letter
              case.
            - 'word_acc_ignore_case_symbol': Accuracy at word level, ignoring
              letter case and symbol. (Default metric for academic evaluation)
            - 'char_recall': Recall at character level, ignoring
              letter case and symbol.
            - 'char_precision': Precision at character level, ignoring
              letter case and symbol.
            - 'one_minus_ned': 1 - normalized_edit_distance

            In particular, if ``metric == 'acc'``, results on all metrics
            above will be reported.

    Returns:
        dict{str: float}: Result dict for text recognition, keys could be some
        of the following: ['word_acc', 'word_acc_ignore_case',
        'word_acc_ignore_case_symbol', 'char_recall', 'char_precision',
        '1-N.E.D'].
    """
    assert isinstance(pred_texts, list)
    assert isinstance(gt_texts, list)
    assert len(pred_texts) == len(gt_texts)

    assert isinstance(metric, str) or is_type_list(metric, str)
    if metric == 'acc' or metric == ['acc']:
        metric = [
            'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
            'char_recall', 'char_precision', 'one_minus_ned'
        ]
    metric = {metric} if isinstance(metric, str) else set(metric)

    supported_metrics = {
        'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
        'char_recall', 'char_precision', 'one_minus_ned'
    }
    assert metric.issubset(supported_metrics)

    match_res = count_matches(pred_texts, gt_texts)
    eps = 1e-8
    eval_res = {}

    if 'char_recall' in metric:
        char_recall = 1.0 * match_res['true_positive_char_num'] / (
            eps + match_res['gt_char_num'])
        eval_res['char_recall'] = char_recall

    if 'char_precision' in metric:
        char_precision = 1.0 * match_res['true_positive_char_num'] / (
            eps + match_res['pred_char_num'])
        eval_res['char_precision'] = char_precision

    if 'word_acc' in metric:
        word_acc = 1.0 * match_res['match_word_num'] / (
            eps + match_res['gt_word_num'])
        eval_res['word_acc'] = word_acc

    if 'word_acc_ignore_case' in metric:
        word_acc_ignore_case = 1.0 * match_res['match_word_ignore_case'] / (
            eps + match_res['gt_word_num'])
        eval_res['word_acc_ignore_case'] = word_acc_ignore_case

    if 'word_acc_ignore_case_symbol' in metric:
        word_acc_ignore_case_symbol = 1.0 * match_res[
            'match_word_ignore_case_symbol'] / (
                eps + match_res['gt_word_num'])
        eval_res['word_acc_ignore_case_symbol'] = word_acc_ignore_case_symbol

    if 'one_minus_ned' in metric:
        eval_res['1-N.E.D'] = 1.0 - match_res['ned']

    for key, value in eval_res.items():
        eval_res[key] = float(f'{value:.4f}')

    return eval_res
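For reference, the functional API removed above was called directly on paired lists of predicted and ground-truth strings. A minimal usage sketch follows; the import path is assumed from the old package layout implied by the first hunk, and the input texts are made-up examples:

# Illustrative only: exercising the removed eval_ocr_metric function.
# The import path below is an assumption based on the old __init__ shown above.
from mmocr.core.evaluation import eval_ocr_metric

pred_texts = ['hello', 'world!']
gt_texts = ['Hello', 'world']

# metric='acc' expands to all six supported metrics.
print(eval_ocr_metric(pred_texts, gt_texts, metric='acc'))
# -> dict with keys 'word_acc', 'word_acc_ignore_case',
#    'word_acc_ignore_case_symbol', 'char_recall', 'char_precision', '1-N.E.D'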
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .hmean_iou_metric import HmeanIOUMetric
+from .recog_metric import CharMetric, OneMinusNEDMetric, WordMetric

-__all__ = ['HmeanIOUMetric']
+__all__ = ['WordMetric', 'CharMetric', 'OneMinusNEDMetric', 'HmeanIOUMetric']
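Because the new classes are registered with METRICS (see the file below), they are normally referenced by type name from an MMEngine-style evaluator config rather than instantiated by hand. A sketch, assuming the usual val_evaluator/test_evaluator config fields; the mode values are only examples:

# Sketch of an evaluator config using the refactored, registry-based metrics.
val_evaluator = [
    dict(type='WordMetric', mode=['exact', 'ignore_case', 'ignore_case_symbol']),
    dict(type='CharMetric'),
    dict(type='OneMinusNEDMetric'),
]
test_evaluator = val_evaluator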
@@ -0,0 +1,292 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from difflib import SequenceMatcher
from typing import Dict, Optional, Sequence, Union

import mmcv
from mmengine.evaluator import BaseMetric
from rapidfuzz import string_metric

from mmocr.registry import METRICS


@METRICS.register_module()
class WordMetric(BaseMetric):
    """Word metrics for text recognition task.

    Args:
        mode (str or list[str]): Options are:

            - 'exact': Accuracy at word level.
            - 'ignore_case': Accuracy at word level, ignoring letter
              case.
            - 'ignore_case_symbol': Accuracy at word level, ignoring
              letter case and symbol. (Default metric for academic evaluation)

            If mode is a list, then metrics in mode will be calculated
            separately. Defaults to 'ignore_case_symbol'.
        valid_symbol (str): Valid characters.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    default_prefix: Optional[str] = 'recog'

    def __init__(self,
                 mode: Union[str, Sequence[str]] = 'ignore_case_symbol',
                 valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device, prefix)
        self.valid_symbol = re.compile(valid_symbol)
        if isinstance(mode, str):
            mode = [mode]
        assert mmcv.is_seq_of(mode, str)
        assert set(mode).issubset(
            {'exact', 'ignore_case', 'ignore_case_symbol'})
        self.mode = set(mode)

    def process(self, data_batch: Sequence[Dict],
                predictions: Sequence[Dict]) -> None:
        """Process one batch of predictions. The processed results should be
        stored in ``self.results``, which will be used to compute the metrics
        when all batches have been processed.

        Args:
            data_batch (Sequence[Dict]): A batch of gts.
            predictions (Sequence[Dict]): A batch of outputs from the model.
        """
        match_num = 0
        match_ignore_case_num = 0
        match_ignore_case_symbol_num = 0
        for gt, pred in zip(data_batch, predictions):
            pred_text = pred.get('pred_text').get('item')
            gt_text = gt.get('data_sample').get('instances')[0].get('text')
            if 'ignore_case' in self.mode or 'ignore_case_symbol' in self.mode:
                pred_text_lower = pred_text.lower()
                gt_text_lower = gt_text.lower()
            if 'ignore_case_symbol' in self.mode:
                gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
                pred_text_lower_ignore = self.valid_symbol.sub(
                    '', pred_text_lower)
                match_ignore_case_symbol_num += \
                    gt_text_lower_ignore == pred_text_lower_ignore
            if 'ignore_case' in self.mode:
                match_ignore_case_num += pred_text_lower == gt_text_lower
            if 'exact' in self.mode:
                match_num += pred_text == gt_text
        results = dict(
            match_num=match_num,
            match_ignore_case_num=match_ignore_case_num,
            match_ignore_case_symbol_num=match_ignore_case_symbol_num)
        self.results.append(results)

    def compute_metrics(self, results: Sequence[Dict]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[Dict]): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """

        eps = 1e-8
        eval_res = {}
        gt_word_num = len(results)
        if 'exact' in self.mode:
            match_nums = [result['match_num'] for result in results]
            match_nums = sum(match_nums)
            eval_res['word_acc'] = 1.0 * match_nums / (eps + gt_word_num)
        if 'ignore_case' in self.mode:
            match_ignore_case_num = [
                result['match_ignore_case_num'] for result in results
            ]
            match_ignore_case_num = sum(match_ignore_case_num)
            eval_res['word_acc_ignore_case'] = 1.0 * \
                match_ignore_case_num / (eps + gt_word_num)
        if 'ignore_case_symbol' in self.mode:
            match_ignore_case_symbol_num = [
                result['match_ignore_case_symbol_num'] for result in results
            ]
            match_ignore_case_symbol_num = sum(match_ignore_case_symbol_num)
            eval_res['word_acc_ignore_case_symbol'] = 1.0 * \
                match_ignore_case_symbol_num / (eps + gt_word_num)

        for key, value in eval_res.items():
            eval_res[key] = float(f'{value:.4f}')
        return eval_res


@METRICS.register_module()
class CharMetric(BaseMetric):
    """Character metrics for text recognition task.

    Args:
        valid_symbol (str): Valid characters.
            Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    default_prefix: Optional[str] = 'recog'

    def __init__(self,
                 valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device, prefix)
        self.valid_symbol = re.compile(valid_symbol)

    def process(self, data_batch: Sequence[Dict],
                predictions: Sequence[Dict]) -> None:
        """Process one batch of predictions. The processed results should be
        stored in ``self.results``, which will be used to compute the metrics
        when all batches have been processed.

        Args:
            data_batch (Sequence[Dict]): A batch of gts.
            predictions (Sequence[Dict]): A batch of outputs from the model.
        """
        for gt, pred in zip(data_batch, predictions):
            pred_text = pred.get('pred_text').get('item')
            gt_text = gt.get('data_sample').get('instances')[0].get('text')
            gt_text_lower = gt_text.lower()
            pred_text_lower = pred_text.lower()
            gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
            pred_text_lower_ignore = self.valid_symbol.sub('', pred_text_lower)
            # number to calculate char level recall & precision
            result = dict(
                gt_char_num=len(gt_text_lower_ignore),
                pred_char_num=len(pred_text_lower_ignore),
                true_positive_char_num=self._cal_true_positive_char(
                    pred_text_lower_ignore, gt_text_lower_ignore))
            self.results.append(result)

    def compute_metrics(self, results: Sequence[Dict]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[Dict]): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        gt_char_num = [result['gt_char_num'] for result in results]
        pred_char_num = [result['pred_char_num'] for result in results]
        true_positive_char_num = [
            result['true_positive_char_num'] for result in results
        ]
        gt_char_num = sum(gt_char_num)
        pred_char_num = sum(pred_char_num)
        true_positive_char_num = sum(true_positive_char_num)

        eps = 1e-8
        char_recall = 1.0 * true_positive_char_num / (eps + gt_char_num)
        char_precision = 1.0 * true_positive_char_num / (eps + pred_char_num)
        eval_res = {}
        eval_res['char_recall'] = char_recall
        eval_res['char_precision'] = char_precision

        for key, value in eval_res.items():
            eval_res[key] = float(f'{value:.4f}')
        return eval_res

    def _cal_true_positive_char(self, pred: str, gt: str) -> int:
        """Calculate correct character number in prediction.

        Args:
            pred (str): Prediction text.
            gt (str): Ground truth text.

        Returns:
            true_positive_char_num (int): The true positive number.
        """

        all_opt = SequenceMatcher(None, pred, gt)
        true_positive_char_num = 0
        for opt, _, _, s2, e2 in all_opt.get_opcodes():
            if opt == 'equal':
                true_positive_char_num += (e2 - s2)
            else:
                pass
        return true_positive_char_num


@METRICS.register_module()
class OneMinusNEDMetric(BaseMetric):
    """One minus NED metric for text recognition task.

    Args:
        valid_symbol (str): Valid characters.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """
    default_prefix: Optional[str] = 'recog'

    def __init__(self,
                 valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device, prefix)
        self.valid_symbol = re.compile(valid_symbol)

    def process(self, data_batch: Sequence[Dict],
                predictions: Sequence[Dict]) -> None:
        """Process one batch of predictions. The processed results should be
        stored in ``self.results``, which will be used to compute the metrics
        when all batches have been processed.

        Args:
            data_batch (Sequence[Dict]): A batch of gts.
            predictions (Sequence[Dict]): A batch of outputs from the model.
        """
        for gt, pred in zip(data_batch, predictions):
            pred_text = pred.get('pred_text').get('item')
            gt_text = gt.get('data_sample').get('instances')[0].get('text')
            gt_text_lower = gt_text.lower()
            pred_text_lower = pred_text.lower()
            gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
            pred_text_lower_ignore = self.valid_symbol.sub('', pred_text_lower)
            edit_dist = string_metric.levenshtein(pred_text_lower_ignore,
                                                  gt_text_lower_ignore)
            norm_ed = float(edit_dist) / max(1, len(gt_text_lower_ignore),
                                             len(pred_text_lower_ignore))
            result = dict(norm_ed=norm_ed)
            self.results.append(result)

    def compute_metrics(self, results: Sequence[Dict]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[Dict]): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """

        gt_word_num = len(results)
        norm_ed = [result['norm_ed'] for result in results]
        norm_ed_sum = sum(norm_ed)
        normalized_edit_distance = norm_ed_sum / max(1, gt_word_num)
        eval_res = {}
        eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance
        for key, value in eval_res.items():
            eval_res[key] = float(f'{value:.4f}')
        return eval_res
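The default valid_symbol pattern above drives all of the "ignore symbol" behaviour: anything outside A-Z, a-z, 0-9 and the CJK range U+4E00 to U+9FA5 is stripped before comparison. Note that inside the character class only the first '^' negates; the later ones simply make a literal caret count as a valid character. A quick standalone check of that behaviour (input string is just an example):

# Standalone check of the default valid_symbol pattern used by these metrics.
import re

valid_symbol = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
print(valid_symbol.sub('', '$Hello, 世界! 123^'))  # -> 'Hello世界123^'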
@@ -0,0 +1,167 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import unittest

from mmengine.data import LabelData

from mmocr.core import TextRecogDataSample
from mmocr.metrics import CharMetric, OneMinusNEDMetric, WordMetric


class TestWordMetric(unittest.TestCase):

    def setUp(self):
        # prepare gt: 'hello', 'HELLO', '$HELLO$'
        gt1 = {
            'data_sample': {
                'height': 32,
                'width': 100,
                'instances': [{
                    'text': 'hello'
                }]
            }
        }
        gt2 = {
            'data_sample': {
                'height': 32,
                'width': 100,
                'instances': [{
                    'text': 'HELLO'
                }]
            }
        }
        gt3 = {
            'data_sample': {
                'height': 32,
                'width': 100,
                'instances': [{
                    'text': '$HELLO$'
                }]
            }
        }
        self.gt = [gt1, gt2, gt3]
        # prepare pred
        pred_data_sample = TextRecogDataSample()
        pred_text = LabelData()
        pred_text.item = 'hello'
        pred_data_sample.pred_text = pred_text

        self.pred = [
            pred_data_sample,
            copy.deepcopy(pred_data_sample),
            copy.deepcopy(pred_data_sample),
        ]

    def test_word_acc_metric(self):
        metric = WordMetric(mode='exact')
        metric.process(self.gt, self.pred)
        eval_res = metric.evaluate(size=3)
        self.assertTrue(eval_res['recog/word_acc'], 1. / 3)

    def test_word_acc_ignore_case_metric(self):
        metric = WordMetric(mode='ignore_case')
        metric.process(self.gt, self.pred)
        eval_res = metric.evaluate(size=3)
        self.assertTrue(eval_res['recog/word_acc_ignore_case'], 2. / 3)

    def test_word_acc_ignore_case_symbol_metric(self):
        metric = WordMetric(mode='ignore_case_symbol')
        metric.process(self.gt, self.pred)
        eval_res = metric.evaluate(size=3)
        self.assertTrue(eval_res['recog/word_acc_ignore_case_symbol'], 1.0)

    def test_all_metric(self):
        metric = WordMetric(
            mode=['exact', 'ignore_case', 'ignore_case_symbol'])
        metric.process(self.gt, self.pred)
        eval_res = metric.evaluate(size=3)
        self.assertTrue(eval_res['recog/word_acc'], 1. / 3)
        self.assertTrue(eval_res['recog/word_acc_ignore_case'], 2. / 3)
        self.assertTrue(eval_res['recog/word_acc_ignore_case_symbol'], 1.0)


class TestCharMetric(unittest.TestCase):

    def setUp(self):
        # prepare gt
        gt1 = {
            'data_sample': {
                'height': 32,
                'width': 100,
                'instances': [{
                    'text': 'hello'
                }]
            }
        }
        gt2 = {
            'data_sample': {
                'height': 32,
                'width': 100,
                'instances': [{
                    'text': 'HELLO'
                }]
            }
        }
        self.gt = [gt1, gt2]
        # prepare pred
        pred_data_sample1 = TextRecogDataSample()
        pred_text = LabelData()
        pred_text.item = 'helL'
        pred_data_sample1.pred_text = pred_text

        pred_data_sample2 = TextRecogDataSample()
        pred_text = LabelData()
        pred_text.item = 'HEL'
        pred_data_sample2.pred_text = pred_text

        self.pred = [pred_data_sample1, pred_data_sample2]

    def test_char_recall_precision_metric(self):
        metric = CharMetric()
        metric.process(self.gt, self.pred)
        eval_res = metric.evaluate(size=2)
        self.assertTrue(eval_res['recog/char_recall'], 0.8)
        self.assertTrue(eval_res['recog/char_precision'], 0.7)


class TestOneMinusNED(unittest.TestCase):

    def setUp(self):
        # prepare gt
        gt1 = {
            'data_sample': {
                'height': 32,
                'width': 100,
                'instances': [{
                    'text': 'hello'
                }]
            }
        }
        gt2 = {
            'data_sample': {
                'height': 32,
                'width': 100,
                'instances': [{
                    'text': 'HELLO'
                }]
            }
        }
        self.gt = [gt1, gt2]
        # prepare pred
        pred_data_sample1 = TextRecogDataSample()
        pred_text = LabelData()
        pred_text.item = 'pred_helL'
        pred_data_sample1.pred_text = pred_text

        pred_data_sample2 = TextRecogDataSample()
        pred_text = LabelData()
        pred_text.item = 'HEL'
        pred_data_sample2.pred_text = pred_text

        self.pred = [pred_data_sample1, pred_data_sample2]

    def test_one_minus_ned_metric(self):
        metric = OneMinusNEDMetric()
        metric.process(self.gt, self.pred)
        eval_res = metric.evaluate(size=2)
        self.assertTrue(eval_res['recog/1-N.E.D'], 0.7)
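The character-level numbers exercised by TestCharMetric come from the SequenceMatcher-based counting in _cal_true_positive_char. A standalone illustration using the first test pair, 'helL' vs 'hello', after lower-casing and symbol stripping:

# Counting true-positive characters the same way _cal_true_positive_char does.
from difflib import SequenceMatcher

pred, gt = 'hell', 'hello'  # 'helL' lower-cased; symbols already stripped
tp = sum(e2 - s2
         for opt, _, _, s2, e2 in SequenceMatcher(None, pred, gt).get_opcodes()
         if opt == 'equal')
print(tp)  # 4 of the 5 ground-truth characters are matched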