[Feature] Add support for vsr dataset (#1634)
* add VSR dataset * [Fix] Modify example and load gt_answer as string. --------- Co-authored-by: ZhangYuanhan-AI <yuanhan002@ntu.edu.sg>pull/1636/head
parent
53648baca5
commit
7581b76233
|
@ -0,0 +1,81 @@
|
|||
# data settings
|
||||
|
||||
data_preprocessor = dict(
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
|
||||
meta_keys=['question_id', 'image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(480, 480),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(
|
||||
type='CleanCaption',
|
||||
keys=['question'],
|
||||
),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
|
||||
meta_keys=['question_id', 'image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='VSR',
|
||||
data_root='data/coco',
|
||||
data_prefix='',
|
||||
ann_file='annotations/train.json',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='VSR',
|
||||
data_root='data/coco',
|
||||
data_prefix='',
|
||||
ann_file='annotations/val.json',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='VSRAcc')
|
||||
|
||||
test_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='VSR',
|
||||
data_root='data/coco',
|
||||
data_prefix='',
|
||||
ann_file='annotations/test.json',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
test_evaluator = val_evaluator
|
|
@ -45,9 +45,10 @@ if WITH_MULTIMODAL:
|
|||
from .scienceqa import ScienceQA
|
||||
from .textvqa import TextVQA
|
||||
from .visual_genome import VisualGenomeQA
|
||||
from .vsr import VSR
|
||||
|
||||
__all__.extend([
|
||||
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
|
||||
'FlamingoEvalCOCOVQA', 'OCRVQA', 'RefCOCO', 'VisualGenomeQA',
|
||||
'ScienceQA', 'NoCaps', 'GQA', 'TextVQA'
|
||||
'ScienceQA', 'NoCaps', 'GQA', 'TextVQA', 'VSR'
|
||||
])
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import mmengine
|
||||
from mmengine.dataset import BaseDataset
|
||||
|
||||
from mmpretrain.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class VSR(BaseDataset):
|
||||
"""VSR: Visual Spatial Reasoning dataset.
|
||||
|
||||
Args:
|
||||
data_root (str): The root directory for ``data_prefix``, ``ann_file``
|
||||
and ``question_file``.
|
||||
data_prefix (str): The directory of images.
|
||||
ann_file (str, optional): Annotation file path for training and
|
||||
validation. Defaults to an empty string.
|
||||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data_root: str,
|
||||
data_prefix: str,
|
||||
ann_file: str = '',
|
||||
**kwarg):
|
||||
super().__init__(
|
||||
data_root=data_root,
|
||||
data_prefix=dict(img_path=data_prefix),
|
||||
ann_file=ann_file,
|
||||
**kwarg,
|
||||
)
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load data list."""
|
||||
annotations = mmengine.load(self.ann_file)
|
||||
|
||||
data_list = []
|
||||
for ann in annotations:
|
||||
# ann example
|
||||
# {
|
||||
# "image": "train2017/000000372029.jpg",
|
||||
# "question": "The dog is on the surfboard.",
|
||||
# "answer": true
|
||||
# }
|
||||
data_info = dict()
|
||||
data_info['img_path'] = mmengine.join_path(
|
||||
self.data_prefix['img_path'], ann['image'])
|
||||
data_info['question'] = ann['question']
|
||||
data_info['gt_answer'] = 'yes' if ann['answer'] else 'no'
|
||||
|
||||
data_list.append(data_info)
|
||||
|
||||
return data_list
|
Loading…
Reference in New Issue