# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.evaluator import Evaluator

from mmpretrain.structures import DataSample


class TestScienceQAMetric:

    def test_evaluate(self):
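        # Case 1: gt_answer 1 points at choice 'B', which differs from the
        # predicted answer 'A', so every reported accuracy is expected to be 0.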
        meta_info = {
            'choices': ['A', 'B', 'C', 'D'],
            'pred_answer': 'A',
            'grade': 'grade1',
            'subject': 'language science',
            'gt_answer': 1,
            'hint': 'hint',
            'has_image': True
        }
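        # Build one DataSample and feed 10 identical copies through the
        # generic mmengine Evaluator wrapper around ScienceQAMetric.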
        data_sample = DataSample(metainfo=meta_info)
        data_samples = [data_sample for _ in range(10)]
        evaluator = Evaluator(dict(type='mmpretrain.ScienceQAMetric'))
        evaluator.process(data_samples)
        res = evaluator.evaluate(4)
        assert res['acc_grade_1_6'] == 0.0
        assert res['acc_language'] == 0.0
        assert res['all_acc'] == 0.0
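
        # Case 2: gt_answer 0 points at choice 'A', which matches the
        # predicted answer, so every reported accuracy is expected to be 1.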
        meta_info = {
            'choices': ['A', 'B', 'C', 'D'],
            'pred_answer': 'A',
            'grade': 'grade1',
            'subject': 'language science',
            'gt_answer': 0,
            'hint': 'hint',
            'has_image': True
        }
        data_sample = DataSample(metainfo=meta_info)
        data_samples = [data_sample for _ in range(10)]
        evaluator = Evaluator(dict(type='mmpretrain.ScienceQAMetric'))
        evaluator.process(data_samples)
        res = evaluator.evaluate(4)
        assert res['acc_grade_1_6'] == 1.0
        assert res['acc_language'] == 1.0
        assert res['all_acc'] == 1.0