[Fix] Fix scienceqa (#1581)

pull/1586/head
Yuan Liu 2023-05-22 16:10:17 +08:00 committed by GitHub
parent 023d6869bd
commit be389eb846
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 3 deletions

View File

@ -96,8 +96,6 @@ class ScienceQA(BaseDataset):
'img_path':
file_backend.join_path(img_prefix, data_id, ann['image'])
if ann['image'] is not None else None,
'caption':
ann['caption'],
'has_image':
True if ann['image'] is not None else False,
}

View File

@ -83,7 +83,7 @@ class ScienceQAMetric(BaseMetric):
data_sample.get('pred_answer'), choices, self.options)
result['grade'] = data_sample.get('grade')
result['subject'] = data_sample.get('subject')
result['answer'] = self.options[data_sample.get('gt_answer')]
result['answer'] = data_sample.get('gt_answer')
hint = data_sample.get('hint')
has_image = data_sample.get('has_image', False)
result[

View File

@ -24,3 +24,21 @@ class TestScienceQAMetric:
assert res['acc_grade_1_6'] == 0.0
assert res['acc_language'] == 0.0
assert res['all_acc'] == 0.0
meta_info = {
'choices': ['A', 'B', 'C', 'D'],
'pred_answer': 'A',
'grade': 'grade1',
'subject': 'language science',
'gt_answer': 0,
'hint': 'hint',
'has_image': True
}
data_sample = DataSample(metainfo=meta_info)
data_samples = [data_sample for _ in range(10)]
evaluator = Evaluator(dict(type='mmpretrain.ScienceQAMetric'))
evaluator.process(data_samples)
res = evaluator.evaluate(4)
assert res['acc_grade_1_6'] == 1.0
assert res['acc_language'] == 1.0
assert res['all_acc'] == 1.0