From 1b8e86dca675707605e7011ebd1ac0b34c8c05f0 Mon Sep 17 00:00:00 2001
From: liuyuan <3463423099@qq.com>
Date: Fri, 19 May 2023 16:11:01 +0800
Subject: [PATCH 1/6] [Feature]: Add caption

---
 mmpretrain/datasets/scienceqa.py | 75 ++++++++++++++++----------------
 1 file changed, 38 insertions(+), 37 deletions(-)

diff --git a/mmpretrain/datasets/scienceqa.py b/mmpretrain/datasets/scienceqa.py
index 391f7e1ac..f6ae7f835 100644
--- a/mmpretrain/datasets/scienceqa.py
+++ b/mmpretrain/datasets/scienceqa.py
@@ -62,42 +62,43 @@ class ScienceQA(BaseDataset):
         data_list = []
         for data_id in current_data_split:
             ann = annotations[data_id]
-            if ann['image'] is not None:
-                data_info = {
-                    'image_id':
-                    data_id,
-                    'question':
-                    ann['question'],
-                    'choices':
-                    ann['choices'],
-                    'answer':
-                    ann['answer'],
-                    'hint':
-                    ann['hint'],
-                    'image_name':
-                    ann['image'],
-                    'task':
-                    ann['task'],
-                    'grade':
-                    ann['grade'],
-                    'subject':
-                    ann['subject'],
-                    'topic':
-                    ann['topic'],
-                    'category':
-                    ann['category'],
-                    'skill':
-                    ann['skill'],
-                    'lecture':
-                    ann['lecture'],
-                    'solution':
-                    ann['solution'],
-                    'split':
-                    ann['split'],
-                    'img_path':
-                    file_backend.join_path(img_prefix, data_id,
-                                           ann['image']),  # noqa
-                }
-                data_list.append(data_info)
+            data_info = {
+                'image_id':
+                data_id,
+                'question':
+                ann['question'],
+                'choices':
+                ann['choices'],
+                'answer':
+                ann['answer'],
+                'hint':
+                ann['hint'],
+                'image_name':
+                ann['image'],
+                'task':
+                ann['task'],
+                'grade':
+                ann['grade'],
+                'subject':
+                ann['subject'],
+                'topic':
+                ann['topic'],
+                'category':
+                ann['category'],
+                'skill':
+                ann['skill'],
+                'lecture':
+                ann['lecture'],
+                'solution':
+                ann['solution'],
+                'split':
+                ann['split'],
+                'img_path':
+                file_backend.join_path(img_prefix, data_id, ann['image'])
+                if ann['image'] is not None else None,
+                'caption':
+                ann['caption'],
+            }
+            data_list.append(data_info)
 
         return data_list
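With this change, text-only ScienceQA questions are no longer dropped: `img_path` simply becomes None when the annotation carries no image, and the new `caption` field is loaded alongside the other keys. A minimal sketch of that branch in isolation (the annotation values and the plain string join standing in for `file_backend.join_path` are illustrative, not part of the patch):

# Illustrative only: the new img_path handling for a text-only sample.
ann = {'image': None, 'caption': 'A food web diagram.'}  # hypothetical annotation
img_prefix, data_id = 'data/scienceqa/val', '42'  # hypothetical paths
img_path = ('/'.join([img_prefix, data_id, ann['image']])
            if ann['image'] is not None else None)
assert img_path is None  # the sample is still kept in data_list
assert ann['caption'] == 'A food web diagram.'  # caption is carried through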
From 87f849cbb6a7138d2dfdd054d9a791898eff8902 Mon Sep 17 00:00:00 2001
From: liuyuan <3463423099@qq.com>
Date: Fri, 19 May 2023 18:35:44 +0800
Subject: [PATCH 2/6] [Feature]: Add scienceqa metric

---
 mmpretrain/evaluation/metrics/__init__.py  |   3 +-
 mmpretrain/evaluation/metrics/scienceqa.py | 170 ++++++++++++++++++
 .../test_metrics/test_scienceqa.py         |  27 +++
 3 files changed, 199 insertions(+), 1 deletion(-)
 create mode 100644 mmpretrain/evaluation/metrics/scienceqa.py
 create mode 100644 tests/test_evaluation/test_metrics/test_scienceqa.py

diff --git a/mmpretrain/evaluation/metrics/__init__.py b/mmpretrain/evaluation/metrics/__init__.py
index 683cf72be..186cdd9f2 100644
--- a/mmpretrain/evaluation/metrics/__init__.py
+++ b/mmpretrain/evaluation/metrics/__init__.py
@@ -3,6 +3,7 @@ from .caption import COCOCaption
 from .multi_label import AveragePrecision, MultiLabelMetric
 from .multi_task import MultiTasksMetric
 from .retrieval import RetrievalRecall
+from .scienceqa import ScienceQAMetric
 from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
 from .visual_grounding_eval import VisualGroundingMetric
 from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
@@ -12,5 +13,5 @@ __all__ = [
     'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
     'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
     'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
-    'VisualGroundingMetric'
+    'VisualGroundingMetric', 'ScienceQAMetric'
 ]
diff --git a/mmpretrain/evaluation/metrics/scienceqa.py b/mmpretrain/evaluation/metrics/scienceqa.py
new file mode 100644
index 000000000..f81d4ec16
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/scienceqa.py
@@ -0,0 +1,170 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+from typing import List, Optional
+
+from mmengine.evaluator import BaseMetric
+
+from mmpretrain.registry import METRICS
+
+
+def get_pred_idx(prediction: str, choices: List[str],
+                 options: List[str]) -> int:  # noqa
+    """Get the index (e.g. 2) from the prediction (e.g. 'C').
+
+    Args:
+        prediction (str): The prediction from the model,
+            from ['A', 'B', 'C', 'D', 'E']
+        choices (List(str)): The choices for the question,
+            from ['A', 'B', 'C', 'D', 'E']
+        options (List(str)): The options for the question,
+            from ['A', 'B', 'C', 'D', 'E']
+
+    Returns:
+        int: The index of the prediction, from [0, 1, 2, 3, 4]
+    """
+    if prediction in options[:len(choices)]:
+        return options.index(prediction)
+    else:
+        return random.choice(range(len(choices)))
+
+
+@METRICS.register_module()
+class ScienceQAMetric(BaseMetric):
+    """Evaluation Metric for ScienceQA.
+
+    Args:
+        options (List(str)): Options for each question. Defaults to
+            ["A", "B", "C", "D", "E"].
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different
+            evaluators. If prefix is not provided in the argument,
+            self.default_prefix will be used instead.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 options: List[str] = ['A', 'B', 'C', 'D', 'E'],
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        self.options = options
+
+    def process(self, data_batch, data_samples) -> None:
+        """Process one batch of data samples.
+
+        data_samples should contain the following keys:
+        1. prediction (str): The prediction from the model,
+            from ['A', 'B', 'C', 'D', 'E']
+        2. choices (List(str)): The choices for the question,
+            from ['A', 'B', 'C', 'D', 'E']
+        3. grade (int): The grade for the question, from grade1 to grade12
+        4. subject (str): The subject for the question, from
+            ['natural science', 'social science', 'language science']
+        5. answer (str): The answer for the question, from
+            ['A', 'B', 'C', 'D', 'E']
+        6. hint (str): The hint for the question
+        7. image (torch.Tensor): The image for the question
+
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch: A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+        """
+        for data_sample in data_samples:
+            result = dict()
+            choices = data_sample.get('choices')
+            result['prediction'] = get_pred_idx(
+                data_sample.get('prediction'), choices, self.options)
+            result['grade'] = data_sample.get('grade')
+            result['subject'] = data_sample.get('subject')
+            result['answer'] = data_sample.get('answer')
+            image = data_sample.get('image')
+            hint = data_sample.get('hint')
+            result[
+                'no_context'] = True if image is None and hint is None else False  # noqa
+            result['has_text'] = True if hint is not None else False
+            result['has_image'] = True if image is not None else False
+
+            # Save the result to `self.results`.
+            self.results.append(result)
+
+    def compute_metrics(self, results: List) -> dict:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (List): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the metrics,
+            and the values are corresponding results.
+        """
+        # NOTICE: don't access `self.results` from the method.
+        metrics = dict()
+
+        all_acc = []
+        acc_natural = []
+        acc_social = []
+        acc_language = []
+        acc_has_text = []
+        acc_has_image = []
+        acc_no_context = []
+        acc_grade_1_6 = []
+        acc_grade_7_12 = []
+
+        for result in results:
+            correct = result['prediction'] == result['answer']
+            all_acc.append(correct)
+            # different subjects
+            if result['subject'] == 'natural science':
+                acc_natural.append(correct)
+            elif result['subject'] == 'social science':
+                acc_social.append(correct)
+            elif result['subject'] == 'language science':
+                acc_language.append(correct)
+
+            # different context
+            if result['has_text']:
+                acc_has_text.append(correct)
+            elif result['has_image']:
+                acc_has_image.append(correct)
+            elif result['no_context']:
+                acc_no_context.append(correct)
+
+            # different grade
+            if result['grade'] in [
+                    'grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6'
+            ]:
+                acc_grade_1_6.append(correct)
+            elif result['grade'] in [
+                    'grade7', 'grade8', 'grade9', 'grade10', 'grade11',
+                    'grade12'
+            ]:
+                acc_grade_7_12.append(correct)
+
+        metrics['all_acc'] = sum(all_acc) / len(all_acc)
+        if len(acc_natural) > 0:
+            metrics['acc_natural'] = sum(acc_natural) / len(acc_natural)
+        if len(acc_social) > 0:
+            metrics['acc_social'] = sum(acc_social) / len(acc_social)
+        if len(acc_language) > 0:
+            metrics['acc_language'] = sum(acc_language) / len(acc_language)
+        if len(acc_has_text) > 0:
+            metrics['acc_has_text'] = sum(acc_has_text) / len(acc_has_text)
+        if len(acc_has_image) > 0:
+            metrics['acc_has_image'] = sum(acc_has_image) / len(acc_has_image)
+        if len(acc_no_context) > 0:
+            metrics['acc_no_context'] = sum(acc_no_context) / len(
+                acc_no_context)
+        if len(acc_grade_1_6) > 0:
+            metrics['acc_grade_1_6'] = sum(acc_grade_1_6) / len(acc_grade_1_6)
+        if len(acc_grade_7_12) > 0:
+            metrics['acc_grade_7_12'] = sum(acc_grade_7_12) / len(
+                acc_grade_7_12)
+
+        return metrics
diff --git a/tests/test_evaluation/test_metrics/test_scienceqa.py b/tests/test_evaluation/test_metrics/test_scienceqa.py
new file mode 100644
index 000000000..7e97d0e7b
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_scienceqa.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmengine.evaluator import Evaluator
+
+from mmpretrain.structures import DataSample
+
+
+class TestScienceQAMetric:
+
+    def test_evaluate(self):
+        meta_info = {
+            'choices': ['A', 'B', 'C', 'D'],
+            'prediction': 'A',
+            'grade': 'grade1',
+            'subject': 'language science',
+            'answer': 1,
+            'hint': 'hint',
+            'image': torch.ones((3, 224, 224))
+        }
+        data_sample = DataSample(metainfo=meta_info)
+        data_samples = [data_sample for _ in range(10)]
+        evaluator = Evaluator(dict(type='mmpretrain.ScienceQAMetric'))
+        evaluator.process(data_samples)
+        res = evaluator.evaluate(4)
+        assert res['acc_grade_1_6'] == 0.0
+        assert res['acc_language'] == 0.0
+        assert res['all_acc'] == 0.0
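Before the follow-up fixes land, it is worth noting what `get_pred_idx` does: it maps a predicted option letter to an index into `choices`, and falls back to a random in-range index when the model emits a letter that is not valid for the question. A quick sketch of that behaviour, assuming the module is importable once this patch is applied (the choice strings are invented for illustration):

import random

from mmpretrain.evaluation.metrics.scienceqa import get_pred_idx

random.seed(0)
options = ['A', 'B', 'C', 'D', 'E']
choices = ['the Netherlands', 'Belgium', 'France']  # three candidate answers
assert get_pred_idx('C', choices, options) == 2  # valid letter -> its index
assert get_pred_idx('E', choices, options) in (0, 1, 2)  # out of range -> random fallback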
From 1537d4659626299065b4043df0b713c9e6e655bd Mon Sep 17 00:00:00 2001
From: liuyuan <3463423099@qq.com>
Date: Mon, 22 May 2023 11:28:21 +0800
Subject: [PATCH 3/6] [Feature]: Update scienceqa

---
 mmpretrain/datasets/scienceqa.py           | 4 +++-
 mmpretrain/evaluation/metrics/scienceqa.py | 8 ++++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/mmpretrain/datasets/scienceqa.py b/mmpretrain/datasets/scienceqa.py
index f6ae7f835..968264486 100644
--- a/mmpretrain/datasets/scienceqa.py
+++ b/mmpretrain/datasets/scienceqa.py
@@ -69,7 +69,7 @@ class ScienceQA(BaseDataset):
                 ann['question'],
                 'choices':
                 ann['choices'],
-                'answer':
+                'gt_answer':
                 ann['answer'],
                 'hint':
                 ann['hint'],
@@ -98,6 +98,8 @@ class ScienceQA(BaseDataset):
                 if ann['image'] is not None else None,
                 'caption':
                 ann['caption'],
+                'has_image':
+                True if ann['image'] is not None else False,
             }
             data_list.append(data_info)
 
diff --git a/mmpretrain/evaluation/metrics/scienceqa.py b/mmpretrain/evaluation/metrics/scienceqa.py
index f81d4ec16..fab5c97c3 100644
--- a/mmpretrain/evaluation/metrics/scienceqa.py
+++ b/mmpretrain/evaluation/metrics/scienceqa.py
@@ -56,7 +56,7 @@ class ScienceQAMetric(BaseMetric):
         """Process one batch of data samples.
 
         data_samples should contain the following keys:
-        1. prediction (str): The prediction from the model,
+        1. pred_answer (str): The prediction from the model,
             from ['A', 'B', 'C', 'D', 'E']
         2. choices (List(str)): The choices for the question,
             from ['A', 'B', 'C', 'D', 'E']
@@ -66,7 +66,7 @@ class ScienceQAMetric(BaseMetric):
         5. answer (str): The answer for the question, from
             ['A', 'B', 'C', 'D', 'E']
         6. hint (str): The hint for the question
-        7. image (torch.Tensor): The image for the question
+        7. has_image (bool): Whether or not the question has an image
 
 
         The processed results should be stored in ``self.results``, which will
@@ -80,7 +80,7 @@ class ScienceQAMetric(BaseMetric):
             result = dict()
             choices = data_sample.get('choices')
             result['prediction'] = get_pred_idx(
-                data_sample.get('prediction'), choices, self.options)
+                data_sample.get('pred_answer'), choices, self.options)
             result['grade'] = data_sample.get('grade')
             result['subject'] = data_sample.get('subject')
             result['answer'] = data_sample.get('answer')
@@ -89,7 +89,7 @@ class ScienceQAMetric(BaseMetric):
             result[
                 'no_context'] = True if image is None and hint is None else False  # noqa
             result['has_text'] = True if hint is not None else False
-            result['has_image'] = True if image is not None else False
+            result['has_image'] = data_sample.get('has_image', False)
 
             # Save the result to `self.results`.
             self.results.append(result)
From b0ad99afb9a27b2ab3a746057c42cdefcc089126 Mon Sep 17 00:00:00 2001
From: liuyuan <3463423099@qq.com>
Date: Mon, 22 May 2023 11:38:34 +0800
Subject: [PATCH 4/6] [Fix]: Fix bug

---
 mmpretrain/evaluation/metrics/scienceqa.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mmpretrain/evaluation/metrics/scienceqa.py b/mmpretrain/evaluation/metrics/scienceqa.py
index fab5c97c3..90a392042 100644
--- a/mmpretrain/evaluation/metrics/scienceqa.py
+++ b/mmpretrain/evaluation/metrics/scienceqa.py
@@ -83,13 +83,13 @@ class ScienceQAMetric(BaseMetric):
                 data_sample.get('pred_answer'), choices, self.options)
             result['grade'] = data_sample.get('grade')
             result['subject'] = data_sample.get('subject')
-            result['answer'] = data_sample.get('answer')
-            image = data_sample.get('image')
+            result['answer'] = data_sample.get('gt_answer')
             hint = data_sample.get('hint')
+            has_image = data_sample.get('has_image', False)
             result[
-                'no_context'] = True if image is None and hint is None else False  # noqa
+                'no_context'] = True if not has_image and hint is None else False  # noqa
             result['has_text'] = True if hint is not None else False
-            result['has_image'] = data_sample.get('has_image', False)
+            result['has_image'] = has_image
 
             # Save the result to `self.results`.
             self.results.append(result)
From 13e4d6c512a483166f28ab353e222fd3dcf90807 Mon Sep 17 00:00:00 2001
From: liuyuan <3463423099@qq.com>
Date: Mon, 22 May 2023 11:55:08 +0800
Subject: [PATCH 5/6] [Fix]: Fix UT

---
 mmpretrain/datasets/gqa_dataset.py                    | 1 +
 tests/test_evaluation/test_metrics/test_scienceqa.py | 7 +++----
 2 files changed, 4 insertions(+), 4 deletions(-)
 create mode 100644 mmpretrain/datasets/gqa_dataset.py

diff --git a/mmpretrain/datasets/gqa_dataset.py b/mmpretrain/datasets/gqa_dataset.py
new file mode 100644
index 000000000..ef101fec6
--- /dev/null
+++ b/mmpretrain/datasets/gqa_dataset.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/tests/test_evaluation/test_metrics/test_scienceqa.py b/tests/test_evaluation/test_metrics/test_scienceqa.py
index 7e97d0e7b..5df50aa30 100644
--- a/tests/test_evaluation/test_metrics/test_scienceqa.py
+++ b/tests/test_evaluation/test_metrics/test_scienceqa.py
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch
 from mmengine.evaluator import Evaluator
 
 from mmpretrain.structures import DataSample
@@ -10,12 +9,12 @@ class TestScienceQAMetric:
     def test_evaluate(self):
         meta_info = {
             'choices': ['A', 'B', 'C', 'D'],
-            'prediction': 'A',
+            'pred_answer': 'A',
             'grade': 'grade1',
             'subject': 'language science',
-            'answer': 1,
+            'gt_answer': 1,
             'hint': 'hint',
-            'image': torch.ones((3, 224, 224))
+            'has_image': True
         }
         data_sample = DataSample(metainfo=meta_info)
         data_samples = [data_sample for _ in range(10)]

From 74f24658e7e037746cc2d7641de662e9cb16278a Mon Sep 17 00:00:00 2001
From: liuyuan <3463423099@qq.com>
Date: Mon, 22 May 2023 11:57:18 +0800
Subject: [PATCH 6/6] [Fix]: Delete GQA

---
 mmpretrain/datasets/gqa_dataset.py | 1 -
 1 file changed, 1 deletion(-)
 delete mode 100644 mmpretrain/datasets/gqa_dataset.py

diff --git a/mmpretrain/datasets/gqa_dataset.py b/mmpretrain/datasets/gqa_dataset.py
deleted file mode 100644
index ef101fec6..000000000
--- a/mmpretrain/datasets/gqa_dataset.py
+++ /dev/null
@@ -1 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
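After the whole series is applied, the metric is registered as `ScienceQAMetric` and is driven the same way the unit test drives it. The snippet below is a sketch of possible downstream usage; the sample field values and the config variable names are assumptions rather than part of these patches:

from mmengine.evaluator import Evaluator

from mmpretrain.structures import DataSample

# Hypothetical, correctly answered, text-only question (values are invented).
sample = DataSample(
    metainfo=dict(
        choices=['solid', 'liquid', 'gas'],
        pred_answer='B',
        gt_answer=1,
        grade='grade3',
        subject='natural science',
        hint='Think about ice.',
        has_image=False))
evaluator = Evaluator(dict(type='mmpretrain.ScienceQAMetric'))
evaluator.process([sample])
metrics = evaluator.evaluate(1)  # e.g. {'all_acc': 1.0, 'acc_natural': 1.0, ...}

# In a config file the same metric would typically be declared as:
# val_evaluator = dict(type='ScienceQAMetric')
# test_evaluator = val_evaluator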