316 lines
10 KiB
Python
316 lines
10 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Partly adopted from https://github.com/GT-Vision-Lab/VQA
|
|
# Copyright (c) 2014, Aishwarya Agrawal
|
|
from typing import List, Optional
|
|
|
|
import mmengine
|
|
from mmengine.evaluator import BaseMetric
|
|
from mmengine.logging import MMLogger
|
|
|
|
from mmpretrain.registry import METRICS
|
|
|
|
|
|
def _process_punctuation(inText):
    """Normalize punctuation in an answer string, following the official VQA
    evaluation script.

    Each mark in ``punct`` is deleted outright when it is adjacent to a space
    in the input, or when the input contains a digit-comma-digit sequence
    (e.g. ``1,000``). Otherwise it is replaced by a space. Finally, periods
    that are not followed by a digit are removed, so decimal points such as
    the one in ``3.5`` survive.

    Args:
        inText (str): The answer string to normalize.

    Returns:
        str: The normalized string.
    """
    import re

    punct = [
        ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
        '>', '<', '@', '`', ',', '?', '!'
    ]
    commaStrip = re.compile(r'(\d)(,)(\d)')
    # NOTE: ``(?!<=\d)`` is a negative *lookahead* for the literal text "<"
    # "=" followed by a digit, not the lookbehind ``(?<!\d)`` that was
    # presumably intended. '<' has already been stripped by the loop below,
    # so it never rejects a match. It is kept verbatim so scores stay
    # comparable with the official VQA evaluation.
    periodStrip = re.compile(r'(?!<=\d)(\.)(?!\d)')

    # The digit-comma-digit test depends only on the input, so run it once.
    has_digit_comma = commaStrip.search(inText) is not None

    outText = inText
    for p in punct:
        # Adjacency is checked against the original input, not outText.
        if has_digit_comma or p + ' ' in inText or ' ' + p in inText:
            outText = outText.replace(p, '')
        else:
            outText = outText.replace(p, ' ')
    # The third positional argument of ``Pattern.sub`` is ``count``. Passing
    # ``re.UNICODE`` there capped the substitutions at 32. ``str`` patterns
    # are Unicode-aware by default, so no flag is needed.
    return periodStrip.sub('', outText)
|
|
|
|
|
|
# Lookup tables for ``_process_digit_article``. They are built once at import
# time instead of on every call, and are only read (never mutated).
_ARTICLES = frozenset(['a', 'an', 'the'])

# Number words mapped to digit strings.
_MANUAL_MAP = {
    'none': '0',
    'zero': '0',
    'one': '1',
    'two': '2',
    'three': '3',
    'four': '4',
    'five': '5',
    'six': '6',
    'seven': '7',
    'eight': '8',
    'nine': '9',
    'ten': '10',
}

# Apostrophe-less spellings mapped to contractions, copied from the official
# VQA script. NOTE: the keys "Id've", "I'dve", 'Im' and 'Ive' contain capitals
# and never match, because the input is lowercased before lookup. The
# "somebody'd" entry maps in the reverse direction. Both quirks are kept
# verbatim for scoring parity.
_CONTRACTIONS = {
    'aint': "ain't",
    'arent': "aren't",
    'cant': "can't",
    'couldve': "could've",
    'couldnt': "couldn't",
    "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've",
    'didnt': "didn't",
    'doesnt': "doesn't",
    'dont': "don't",
    'hadnt': "hadn't",
    "hadnt've": "hadn't've",
    "hadn'tve": "hadn't've",
    'hasnt': "hasn't",
    'havent': "haven't",
    'hed': "he'd",
    "hed've": "he'd've",
    "he'dve": "he'd've",
    'hes': "he's",
    'howd': "how'd",
    'howll': "how'll",
    'hows': "how's",
    "Id've": "I'd've",
    "I'dve": "I'd've",
    'Im': "I'm",
    'Ive': "I've",
    'isnt': "isn't",
    'itd': "it'd",
    "itd've": "it'd've",
    "it'dve": "it'd've",
    'itll': "it'll",
    "let's": "let's",
    'maam': "ma'am",
    'mightnt': "mightn't",
    "mightnt've": "mightn't've",
    "mightn'tve": "mightn't've",
    'mightve': "might've",
    'mustnt': "mustn't",
    'mustve': "must've",
    'neednt': "needn't",
    'notve': "not've",
    'oclock': "o'clock",
    'oughtnt': "oughtn't",
    "ow's'at": "'ow's'at",
    "'ows'at": "'ow's'at",
    "'ow'sat": "'ow's'at",
    'shant': "shan't",
    "shed've": "she'd've",
    "she'dve": "she'd've",
    "she's": "she's",
    'shouldve': "should've",
    'shouldnt': "shouldn't",
    "shouldnt've": "shouldn't've",
    "shouldn'tve": "shouldn't've",
    "somebody'd": 'somebodyd',
    "somebodyd've": "somebody'd've",
    "somebody'dve": "somebody'd've",
    'somebodyll': "somebody'll",
    'somebodys': "somebody's",
    'someoned': "someone'd",
    "someoned've": "someone'd've",
    "someone'dve": "someone'd've",
    'someonell': "someone'll",
    'someones': "someone's",
    'somethingd': "something'd",
    "somethingd've": "something'd've",
    "something'dve": "something'd've",
    'somethingll': "something'll",
    'thats': "that's",
    'thered': "there'd",
    "thered've": "there'd've",
    "there'dve": "there'd've",
    'therere': "there're",
    'theres': "there's",
    'theyd': "they'd",
    "theyd've": "they'd've",
    "they'dve": "they'd've",
    'theyll': "they'll",
    'theyre': "they're",
    'theyve': "they've",
    'twas': "'twas",
    'wasnt': "wasn't",
    "wed've": "we'd've",
    "we'dve": "we'd've",
    'weve': "we've",
    'werent': "weren't",
    'whatll': "what'll",
    'whatre': "what're",
    'whats': "what's",
    'whatve': "what've",
    'whens': "when's",
    'whered': "where'd",
    'wheres': "where's",
    'whereve': "where've",
    'whod': "who'd",
    "whod've": "who'd've",
    "who'dve": "who'd've",
    'wholl': "who'll",
    'whos': "who's",
    'whove': "who've",
    'whyll': "why'll",
    'whyre': "why're",
    'whys': "why's",
    'wont': "won't",
    'wouldve': "would've",
    'wouldnt': "wouldn't",
    "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've",
    'yall': "y'all",
    "yall'll": "y'all'll",
    "y'allll": "y'all'll",
    "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've",
    "y'all'dve": "y'all'd've",
    'youd': "you'd",
    "youd've": "you'd've",
    "you'dve": "you'd've",
    'youll': "you'll",
    'youre': "you're",
    'youve': "you've",
}


def _process_digit_article(inText):
    """Lowercase, map number words to digits, drop articles and restore
    apostrophes in common contractions, following the official VQA script.

    Args:
        inText (str): The answer string to normalize.

    Returns:
        str: The whitespace-split, processed words joined by single spaces.
    """
    outText = []
    for word in inText.lower().split():
        # ``.get`` rather than ``setdefault``: the lookup must not insert
        # every seen word into the shared table.
        word = _MANUAL_MAP.get(word, word)
        if word in _ARTICLES:
            continue
        # Digits produced above are never contraction keys, so applying the
        # contraction map in the same pass matches the original two passes.
        outText.append(_CONTRACTIONS.get(word, word))
    return ' '.join(outText)
|
|
|
|
|
|
@METRICS.register_module()
class VQAAcc(BaseMetric):
    """VQA Acc metric.

    The predicted answer and every ground-truth answer are normalized:

    - newlines and tabs become spaces and surrounding whitespace is stripped;
    - punctuation is cleaned;
    - number words are mapped to digits;
    - articles are dropped;
    - common contractions get their apostrophes back.

    A sample scores ``min(1, matched_weight / full_score_weight)``, where
    ``matched_weight`` is the summed weight of the ground-truth answers equal
    to the prediction. The reported ``acc`` is the mean sample score in
    percent.

    Args:
        full_score_weight (float): Summed weight of matching ground-truth
            answers at which a prediction receives full credit. If a sample
            has no ``gt_answer_weight``, each of its ``k`` ground-truth
            answers is weighted ``1 / k``. With 10 answers, the default 0.3
            therefore means "at least 3 answers match". Defaults to 0.3.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """
    default_prefix = 'VQA'

    def __init__(self,
                 full_score_weight: float = 0.3,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.full_score_weight = full_score_weight

    def process(self, data_batch, data_samples):
        """Process one batch of data samples.

        The processed results are stored in ``self.results``, which will be
        used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader (not used here).
            data_samples (Sequence[dict]): A batch of outputs from the model.
                Each sample is read for ``pred_answer``, ``gt_answer`` (a
                single string or a sequence of answers) and an optional
                ``gt_answer_weight``.
        """
        for sample in data_samples:
            gt_answer = sample.get('gt_answer')
            gt_answer_weight = sample.get('gt_answer_weight')
            # Treat a single string as a one-element answer list.
            if isinstance(gt_answer, str):
                gt_answer = [gt_answer]
            # Without explicit weights, all ground-truth answers count equally.
            if gt_answer_weight is None:
                gt_answer_weight = [1. / (len(gt_answer))] * len(gt_answer)

            result = {
                'pred_answer': sample.get('pred_answer'),
                'gt_answer': gt_answer,
                'gt_answer_weight': gt_answer_weight,
            }

            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (list[dict]): The records collected by :meth:`process`.

        Returns:
            dict: ``{'acc': accuracy}``, the mean VQA accuracy in percent.
            When ``results`` is empty, a warning is logged and the accuracy
            is reported as 0.0.
        """
        if not results:
            # Averaging over zero samples would raise ZeroDivisionError.
            logger = MMLogger.get_current_instance()
            logger.warning(
                'VQAAcc received no results; reporting an accuracy of 0.')
            return {'acc': 0.0}

        acc = []
        for result in results:
            pred_answer = self._process_answer(result['pred_answer'])
            gt_answer = [
                self._process_answer(answer) for answer in result['gt_answer']
            ]
            answer_weight = result['gt_answer_weight']

            # Sum the weights of the ground-truth answers equal to the
            # prediction (looked up by position, as before).
            weight_sum = sum(answer_weight[i]
                             for i, gt in enumerate(gt_answer)
                             if gt == pred_answer)
            acc.append(min(1.0, weight_sum / self.full_score_weight))

        accuracy = sum(acc) / len(acc) * 100
        return {'acc': accuracy}

    def _process_answer(self, answer):
        """Normalize one answer string before comparison.

        Newlines and tabs are turned into spaces and the string is stripped.
        It then passes through ``_process_punctuation`` and
        ``_process_digit_article``.
        """
        answer = answer.replace('\n', ' ')
        answer = answer.replace('\t', ' ')
        answer = answer.strip()
        answer = _process_punctuation(answer)
        answer = _process_digit_article(answer)
        return answer
|
|
|
|
|
|
@METRICS.register_module()
class ReportVQA(BaseMetric):
    """Dump VQA predictions to a JSON result file instead of scoring them.

    Every sample contributes one record of the form
    ``{'question_id': int(question_id), 'answer': pred_answer}``. No metric
    is computed, so :meth:`compute_metrics` returns an empty dict.

    Args:
        file_path (str): Path of the result file to write. It must end with
            ``.json``.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Raises:
        ValueError: If ``file_path`` does not end with ``.json``.
    """
    default_prefix = 'VQA'

    def __init__(self,
                 file_path: str,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        super().__init__(collect_device=collect_device, prefix=prefix)
        is_json = file_path.endswith('.json')
        if not is_json:
            raise ValueError('The output file must be a json file.')
        self.file_path = file_path

    def process(self, data_batch, data_samples) -> None:
        """Collect the question id and predicted answer of every sample.

        Args:
            data_batch: A batch of data from the dataloader (not used here).
            data_samples (Sequence[dict]): Model outputs. Each one must
                provide ``question_id`` (convertible with ``int``) and
                ``pred_answer``.
        """
        for item in data_samples:
            qid, answer = item['question_id'], item['pred_answer']
            self.results.append(dict(question_id=int(qid), answer=answer))

    def compute_metrics(self, results: List):
        """Write the collected records to ``self.file_path``.

        Returns:
            dict: Always empty, since this class only reports results.
        """
        mmengine.dump(results, self.file_path)
        MMLogger.get_current_instance().info(
            f'Results has been saved to {self.file_path}.')
        return {}
|