diff --git a/mmpretrain/datasets/scienceqa.py b/mmpretrain/datasets/scienceqa.py index f6ae7f83..96826448 100644 --- a/mmpretrain/datasets/scienceqa.py +++ b/mmpretrain/datasets/scienceqa.py @@ -69,7 +69,7 @@ class ScienceQA(BaseDataset): ann['question'], 'choices': ann['choices'], - 'answer': + 'gt_answer': ann['answer'], 'hint': ann['hint'], @@ -98,6 +98,8 @@ class ScienceQA(BaseDataset): if ann['image'] is not None else None, 'caption': ann['caption'], + 'has_image': + True if ann['image'] is not None else False, } data_list.append(data_info) diff --git a/mmpretrain/evaluation/metrics/scienceqa.py b/mmpretrain/evaluation/metrics/scienceqa.py index f81d4ec1..fab5c97c 100644 --- a/mmpretrain/evaluation/metrics/scienceqa.py +++ b/mmpretrain/evaluation/metrics/scienceqa.py @@ -56,7 +56,7 @@ class ScienceQAMetric(BaseMetric): """Process one batch of data samples. data_samples should contain the following keys: - 1. prediction (str): The prediction from the model, + 1. pred_answer (str): The prediction from the model, from ['A', 'B', 'C', 'D', 'E'] 2. choices (List(str)): The choices for the question, from ['A', 'B', 'C', 'D', 'E'] @@ -66,7 +66,7 @@ class ScienceQAMetric(BaseMetric): 5. answer (str): The answer for the question, from ['A', 'B', 'C', 'D', 'E'] 6. hint (str): The hint for the question - 7. image (torch.Tensor): The image for the question + 7. has_image (bool): Whether or not the question has image The processed results should be stored in ``self.results``, which will @@ -80,7 +80,7 @@ class ScienceQAMetric(BaseMetric): result = dict() choices = data_sample.get('choices') result['prediction'] = get_pred_idx( - data_sample.get('prediction'), choices, self.options) + data_sample.get('pred_answer'), choices, self.options) result['grade'] = data_sample.get('grade') result['subject'] = data_sample.get('subject') result['answer'] = self.options[data_sample.get('answer')] @@ -89,7 +89,7 @@ class ScienceQAMetric(BaseMetric): result[ 'no_context'] = True if image is None and hint is None else False # noqa result['has_text'] = True if hint is not None else False - result['has_image'] = True if image is not None else False + result['has_image'] = data_sample.get('has_image', False) # Save the result to `self.results`. self.results.append(result)