Support Infographic VQA dataset and ANLS metric. (#1667)
parent 4f2f3752d9
commit 340d187765
@@ -41,6 +41,7 @@ if WITH_MULTIMODAL:
     from .flickr30k_caption import Flickr30kCaption
     from .flickr30k_retrieval import Flickr30kRetrieval
     from .gqa_dataset import GQA
+    from .infographic_vqa import InfographicVQA
     from .iconqa import IconQA
     from .nocaps import NoCaps
     from .ocr_vqa import OCRVQA
@@ -55,5 +56,5 @@ if WITH_MULTIMODAL:
         'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
         'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
         'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
-        'VSR', 'VizWiz', 'OCRVQA', 'IconQA'
+        'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA'
     ])
@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class InfographicVQA(BaseDataset):
    """Infographic VQA dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load the list of QA samples from the annotation file."""
        annotations = mmengine.load(self.ann_file)
        annotations = annotations['data']

        data_list = []
        for ann in annotations:
            # Example annotation entry:
            # {
            #     "questionId": 98313,
            #     "question": "Which social platform has heavy female audience?",
            #     "image_local_name": "37313.jpeg",
            #     "image_url": "https://xxx.png",
            #     "ocr_output_file": "37313.json",
            #     "answers": ["pinterest"],
            #     "data_split": "val"
            # }
            data_info = dict()
            data_info['question'] = ann['question']
            data_info['img_path'] = mmengine.join_path(
                self.data_prefix['img_path'], ann['image_local_name'])
            if 'answers' in ann:  # the test split does not include gt answers
                data_info['gt_answer'] = ann['answers']
            data_list.append(data_info)

        return data_list
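
A quick way to sanity-check the new dataset class is to instantiate it directly. The sketch below is an editor's example, not part of the commit: the data_root, data_prefix, and ann_file paths are placeholders assuming a local copy of the official InfographicVQA release, the annotation JSON must carry a top-level "data" list as in the example entry above, and the import assumes mmpretrain was installed with the multimodal extras (the class lives behind WITH_MULTIMODAL).

    # Minimal sketch with hypothetical local paths.
    from mmpretrain.datasets import InfographicVQA

    dataset = InfographicVQA(
        data_root='data/infographicvqa',                 # placeholder root
        data_prefix='images',                            # folder holding the .jpeg files
        ann_file='annotations/infographicvqa_val.json')  # JSON with a 'data' list
    print(len(dataset))             # number of QA samples
    print(dataset[0]['question'])   # fields prepared by load_data_list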

@@ -0,0 +1,103 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS


@METRICS.register_module()
class ANLS(BaseMetric):
    """ANLS metric.

    Compute the Average Normalized Levenshtein Similarity (ANLS).

    Args:
        threshold (float): Score threshold below which a prediction is
            treated as wrong and scored as 0. It distinguishes answers that
            were correctly selected but imperfectly recognized from outputs
            that are simply the wrong text. Defaults to 0.5.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different
            evaluators. If prefix is not provided in the argument,
            ``self.default_prefix`` ('ANLS') will be used instead.
            Defaults to None.
    """
    default_prefix = 'ANLS'

    def __init__(self,
                 threshold: float = 0.5,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.threshold = threshold

    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for sample in data_samples:
            result = {
                'pred_answer': sample.get('pred_answer'),
                'gt_answer': sample.get('gt_answer'),
            }
            self.results.append(result)

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        total_score = 0.
        for result in results:
            # Normalize case and whitespace before comparison.
            pred = ' '.join(result['pred_answer'].strip().lower().split())

            sample_score_list = []
            for gt in result['gt_answer']:
                gt = ' '.join(gt.strip().lower().split())
                dist = levenshtein_distance(gt, pred)
                length = max(len(gt), len(pred))
                sample_score_list.append(
                    0.0 if length == 0 else float(dist) / float(length))

            # Similarity to the closest ground-truth answer; scores below
            # the threshold are treated as wrong answers.
            per_sample_score = 1. - min(sample_score_list)
            if per_sample_score < self.threshold:
                per_sample_score = 0.

            total_score += per_sample_score

        total_score = total_score / len(results)
        return {'ANLS': total_score}


def levenshtein_distance(s1, s2):
    """Compute the Levenshtein (edit) distance between two strings."""
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    # Single-row dynamic programming over the shorter string.
    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1],
                                           distances_[-1])))
        distances = distances_
    return distances[-1]
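
To make the scoring concrete, here is an editor's example, not part of the commit, that feeds the metric two toy samples. Worked by hand: 'pinterest' vs. ground truth 'pinterests' has edit distance 1 over a max length of 10, so its score is 1 - 0.1 = 0.9, which survives the 0.5 threshold; 'twitter' vs. 'pinterest' has edit distance 6 over length 9, a similarity of about 0.33, which is zeroed out. The expected ANLS is therefore (0.9 + 0.0) / 2 = 0.45.

    # Editor's sketch: exercising the metric on toy predictions.
    metric = ANLS(threshold=0.5)
    metric.process(None, [
        dict(pred_answer='pinterest', gt_answer=['pinterests']),
        dict(pred_answer='twitter', gt_answer=['pinterest']),
    ])
    print(metric.compute_metrics(metric.results))  # {'ANLS': 0.45}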

@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .ANLS import ANLS
 from .caption import COCOCaption
 from .gqa import GQAAcc
 from .multi_label import AveragePrecision, MultiLabelMetric
@@ -17,5 +18,5 @@ __all__ = [
     'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
     'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
     'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
-    'RetrievalAveragePrecision', 'ShapeBiasMetric'
+    'RetrievalAveragePrecision', 'ShapeBiasMetric', 'ANLS'
 ]
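
Editor's note: because both classes go through the DATASETS and METRICS registries, configs can refer to them by name. A hypothetical config fragment (paths and batch size are placeholders, and a real config also needs a model and data pipeline) might wire the two together like this:

    # Hypothetical config fragment with placeholder values.
    val_dataloader = dict(
        batch_size=8,
        dataset=dict(
            type='InfographicVQA',
            data_root='data/infographicvqa',
            data_prefix='images',
            ann_file='annotations/infographicvqa_val.json'))
    val_evaluator = dict(type='ANLS', threshold=0.5)

For process() to pick up scores, each evaluated data sample must expose a pred_answer string from the model and the gt_answer list prepared by load_data_list.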