[Fix] Fix scienceqa (#1581)
parent
023d6869bd
commit
be389eb846
|
@ -96,8 +96,6 @@ class ScienceQA(BaseDataset):
|
|||
'img_path':
|
||||
file_backend.join_path(img_prefix, data_id, ann['image'])
|
||||
if ann['image'] is not None else None,
|
||||
'caption':
|
||||
ann['caption'],
|
||||
'has_image':
|
||||
True if ann['image'] is not None else False,
|
||||
}
|
||||
|
|
|
@ -83,7 +83,7 @@ class ScienceQAMetric(BaseMetric):
|
|||
data_sample.get('pred_answer'), choices, self.options)
|
||||
result['grade'] = data_sample.get('grade')
|
||||
result['subject'] = data_sample.get('subject')
|
||||
result['answer'] = self.options[data_sample.get('gt_answer')]
|
||||
result['answer'] = data_sample.get('gt_answer')
|
||||
hint = data_sample.get('hint')
|
||||
has_image = data_sample.get('has_image', False)
|
||||
result[
|
||||
|
|
|
@ -24,3 +24,21 @@ class TestScienceQAMetric:
|
|||
assert res['acc_grade_1_6'] == 0.0
|
||||
assert res['acc_language'] == 0.0
|
||||
assert res['all_acc'] == 0.0
|
||||
|
||||
meta_info = {
|
||||
'choices': ['A', 'B', 'C', 'D'],
|
||||
'pred_answer': 'A',
|
||||
'grade': 'grade1',
|
||||
'subject': 'language science',
|
||||
'gt_answer': 0,
|
||||
'hint': 'hint',
|
||||
'has_image': True
|
||||
}
|
||||
data_sample = DataSample(metainfo=meta_info)
|
||||
data_samples = [data_sample for _ in range(10)]
|
||||
evaluator = Evaluator(dict(type='mmpretrain.ScienceQAMetric'))
|
||||
evaluator.process(data_samples)
|
||||
res = evaluator.evaluate(4)
|
||||
assert res['acc_grade_1_6'] == 1.0
|
||||
assert res['acc_language'] == 1.0
|
||||
assert res['all_acc'] == 1.0
|
||||
|
|
Loading…
Reference in New Issue