Mirror of https://github.com/open-mmlab/mmpretrain.git (synced 2025-06-03 14:59:18 +08:00)
* [Feature]: Add GQA dataset
* [Feature]: Add GQA
* [Feature]: Add GQA UT
* [Fix]: Fix hint
* [Feature]: Add BLIP2 GQA
* [Fix]: Fix lint
* [Feature]: Update anno link
* [Fix]: Update docstring
* [Feature]: Update all links
31 lines · 959 B · Python
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.evaluator import Evaluator

from mmpretrain.structures import DataSample


class TestScienceQAMetric:

    def test_evaluate(self):
        # Every prediction matches its ground-truth answer, so the
        # reported GQA accuracy should be 1.0.
        meta_info = {
            'pred_answer': 'dog',
            'gt_answer': 'dog',
        }
        data_sample = DataSample(metainfo=meta_info)
        data_samples = [data_sample for _ in range(10)]
        evaluator = Evaluator(dict(type='mmpretrain.GQAAcc'))
        evaluator.process(data_samples)
        res = evaluator.evaluate(4)
        assert res['GQA/acc'] == 1.0

        # No prediction matches its ground-truth answer, so the
        # reported GQA accuracy should be 0.0.
        meta_info = {
            'pred_answer': 'dog',
            'gt_answer': 'cat',
        }
        data_sample = DataSample(metainfo=meta_info)
        data_samples = [data_sample for _ in range(10)]
        evaluator = Evaluator(dict(type='mmpretrain.GQAAcc'))
        evaluator.process(data_samples)
        res = evaluator.evaluate(4)
        assert res['GQA/acc'] == 0.0
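The test above exercises the metric only through MMEngine's generic Evaluator interface; the implementation of mmpretrain.GQAAcc itself is not shown. As a rough illustration of how the 'GQA/acc' key could be produced, below is a minimal, hypothetical exact-match accuracy metric built on mmengine.evaluator.BaseMetric. It is a sketch under the assumption that the metric simply compares pred_answer with gt_answer; the real GQAAcc may additionally normalize answers before comparison, so this should not be read as the actual mmpretrain implementation.

# A minimal sketch of an exact-match answer-accuracy metric in the
# MMEngine style. NOT the actual mmpretrain.GQAAcc code; it only
# illustrates how the 'GQA/acc' key asserted in the test could arise.
from mmengine.evaluator import BaseMetric


class ExactMatchAcc(BaseMetric):
    # Keys returned by compute_metrics() are prefixed with this string,
    # yielding e.g. 'GQA/acc' in Evaluator.evaluate() results.
    default_prefix = 'GQA'

    def process(self, data_batch, data_samples):
        # Each sample carries the predicted and ground-truth answer in its
        # metainfo (exposed as dict items once the Evaluator has converted
        # the DataSample objects).
        for sample in data_samples:
            self.results.append(
                sample.get('pred_answer') == sample.get('gt_answer'))

    def compute_metrics(self, results):
        # Fraction of samples whose prediction exactly matches the answer.
        return {'acc': sum(results) / len(results)}

In a config file, a metric like the one tested above would typically be attached through the evaluator field, e.g. val_evaluator = dict(type='GQAAcc') when the config resolves inside the mmpretrain default scope; that usage pattern follows the usual OpenMMLab conventions and is an assumption rather than something stated in the file shown here.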