[Feature]: Add scienceqa metric
parent 1b8e86dca6
commit 87f849cbb6
mmpretrain/evaluation/metrics/__init__.py

@@ -3,6 +3,7 @@ from .caption import COCOCaption
 from .multi_label import AveragePrecision, MultiLabelMetric
 from .multi_task import MultiTasksMetric
 from .retrieval import RetrievalRecall
+from .scienceqa import ScienceQAMetric
 from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
 from .visual_grounding_eval import VisualGroundingMetric
 from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
@@ -12,5 +13,5 @@ __all__ = [
     'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
     'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
     'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
-    'VisualGroundingMetric'
+    'VisualGroundingMetric', 'ScienceQAMetric'
 ]
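Once exported here, the metric is available through the `METRICS` registry, so it can be plugged into an evaluator config. A minimal sketch (the surrounding config fields such as dataloaders and model are assumed, not part of this commit):

# Hypothetical config snippet using the newly registered metric.
val_evaluator = dict(type='ScienceQAMetric')
test_evaluator = dict(type='ScienceQAMetric')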
mmpretrain/evaluation/metrics/scienceqa.py

@@ -0,0 +1,170 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS


def get_pred_idx(prediction: str, choices: List[str],
                 options: List[str]) -> int:  # noqa
    """Get the index (e.g. 2) of the prediction (e.g. 'C') in the options.

    Args:
        prediction (str): The prediction from the model,
            from ['A', 'B', 'C', 'D', 'E'].
        choices (List[str]): The candidate answers for the question;
            only its length is used here.
        options (List[str]): The option letters, e.g.
            ['A', 'B', 'C', 'D', 'E'].

    Returns:
        int: The index of the prediction, from [0, 1, 2, 3, 4].
    """
    if prediction in options[:len(choices)]:
        return options.index(prediction)
    else:
        # The model produced no valid option letter; fall back to a
        # random guess among the available choices.
        return random.choice(range(len(choices)))
@METRICS.register_module()
class ScienceQAMetric(BaseMetric):
    """Evaluation Metric for ScienceQA.

    Args:
        options (List[str]): Option letters for each question. Defaults to
            ["A", "B", "C", "D", "E"].
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    def __init__(self,
                 options: List[str] = ['A', 'B', 'C', 'D', 'E'],
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.options = options
    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        Each data sample should contain the following keys:

        1. prediction (str): The prediction of the model,
           from ['A', 'B', 'C', 'D', 'E'].
        2. choices (List[str]): The candidate answers for the question.
        3. grade (str): The grade of the question, from 'grade1' to
           'grade12'.
        4. subject (str): The subject of the question, from
           ['natural science', 'social science', 'language science'].
        5. answer (int): The index of the ground-truth answer.
        6. hint (str, optional): The hint for the question.
        7. image (torch.Tensor, optional): The image for the question.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()
            choices = data_sample.get('choices')
            result['prediction'] = get_pred_idx(
                data_sample.get('prediction'), choices, self.options)
            result['grade'] = data_sample.get('grade')
            result['subject'] = data_sample.get('subject')
            # Keep the ground-truth answer as an index so that it is
            # directly comparable with the index from `get_pred_idx`.
            result['answer'] = data_sample.get('answer')
            image = data_sample.get('image')
            hint = data_sample.get('hint')
            result['no_context'] = image is None and hint is None
            result['has_text'] = hint is not None
            result['has_image'] = image is not None

            # Save the result to `self.results`.
            self.results.append(result)
    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        # NOTICE: don't access `self.results` from the method.
        metrics = dict()

        all_acc = []
        acc_natural = []
        acc_social = []
        acc_language = []
        acc_has_text = []
        acc_has_image = []
        acc_no_context = []
        acc_grade_1_6 = []
        acc_grade_7_12 = []

        for result in results:
            correct = result['prediction'] == result['answer']
            all_acc.append(correct)
            # different subjects
            if result['subject'] == 'natural science':
                acc_natural.append(correct)
            elif result['subject'] == 'social science':
                acc_social.append(correct)
            elif result['subject'] == 'language science':
                acc_language.append(correct)

            # different contexts
            if result['has_text']:
                acc_has_text.append(correct)
            elif result['has_image']:
                acc_has_image.append(correct)
            elif result['no_context']:
                acc_no_context.append(correct)

            # different grades
            if result['grade'] in [
                    'grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6'
            ]:
                acc_grade_1_6.append(correct)
            elif result['grade'] in [
                    'grade7', 'grade8', 'grade9', 'grade10', 'grade11',
                    'grade12'
            ]:
                acc_grade_7_12.append(correct)

        metrics['all_acc'] = sum(all_acc) / len(all_acc)
        if len(acc_natural) > 0:
            metrics['acc_natural'] = sum(acc_natural) / len(acc_natural)
        if len(acc_social) > 0:
            metrics['acc_social'] = sum(acc_social) / len(acc_social)
        if len(acc_language) > 0:
            metrics['acc_language'] = sum(acc_language) / len(acc_language)
        if len(acc_has_text) > 0:
            metrics['acc_has_text'] = sum(acc_has_text) / len(acc_has_text)
        if len(acc_has_image) > 0:
            metrics['acc_has_image'] = sum(acc_has_image) / len(acc_has_image)
        if len(acc_no_context) > 0:
            metrics['acc_no_context'] = sum(acc_no_context) / len(
                acc_no_context)
        if len(acc_grade_1_6) > 0:
            metrics['acc_grade_1_6'] = sum(acc_grade_1_6) / len(acc_grade_1_6)
        if len(acc_grade_7_12) > 0:
            metrics['acc_grade_7_12'] = sum(acc_grade_7_12) / len(
                acc_grade_7_12)

        return metrics
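For a quick sanity check outside the test suite, the metric can also be driven directly with plain dicts, since `process` only calls `.get()` on each sample. A minimal sketch (the field values below are illustrative, not from this commit):

from mmpretrain.evaluation.metrics import ScienceQAMetric

metric = ScienceQAMetric()
metric.process(None, [{
    'choices': ['cat', 'dog', 'bird'],
    'prediction': 'B',  # maps to index 1 via get_pred_idx
    'grade': 'grade3',
    'subject': 'natural science',
    'answer': 1,  # ground-truth index, matches the prediction
    'hint': 'Look at the picture.',
    'image': None,
}])
print(metric.compute_metrics(metric.results))
# -> {'all_acc': 1.0, 'acc_natural': 1.0, 'acc_has_text': 1.0, 'acc_grade_1_6': 1.0}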
tests/test_evaluation/test_metrics/test_scienceqa.py

@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.evaluator import Evaluator

from mmpretrain.structures import DataSample


class TestScienceQAMetric:

    def test_evaluate(self):
        meta_info = {
            'choices': ['A', 'B', 'C', 'D'],
            'prediction': 'A',
            'grade': 'grade1',
            'subject': 'language science',
            'answer': 1,
            'hint': 'hint',
            'image': torch.ones((3, 224, 224))
        }
        data_sample = DataSample(metainfo=meta_info)
        data_samples = [data_sample for _ in range(10)]
        evaluator = Evaluator(dict(type='mmpretrain.ScienceQAMetric'))
        evaluator.process(data_samples)
        res = evaluator.evaluate(4)
        assert res['acc_grade_1_6'] == 0.0
        assert res['acc_language'] == 0.0
        assert res['all_acc'] == 0.0
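The assertions above only cover the all-wrong case ('A' is predicted while the ground-truth index is 1). A complementary all-correct check could extend the same fixture; a sketch, continuing inside `test_evaluate`:

        # Hypothetical extension: switch the answer to index 0 ('A') so the
        # fixed prediction becomes correct and the accuracies reach 1.0.
        meta_info['answer'] = 0
        data_samples = [DataSample(metainfo=meta_info) for _ in range(10)]
        evaluator = Evaluator(dict(type='mmpretrain.ScienceQAMetric'))
        evaluator.process(data_samples)
        res = evaluator.evaluate(4)
        assert res['all_acc'] == 1.0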