From 7581b762335a1666ce36d1bfc1c06794903dfb88 Mon Sep 17 00:00:00 2001
From: Yike Yuan <32432002+yyk-wew@users.noreply.github.com>
Date: Thu, 15 Jun 2023 19:17:02 +0800
Subject: [PATCH] [Feature] Add support for vsr dataset (#1634)

* add VSR dataset

* [Fix] Modify example and load gt_answer as string.

---------

Co-authored-by: ZhangYuanhan-AI
---
 configs/_base_/datasets/vsr.py  | 81 +++++++++++++++++++++++++++++++++
 mmpretrain/datasets/__init__.py |  3 +-
 mmpretrain/datasets/vsr.py      | 55 ++++++++++++++++++++++
 3 files changed, 138 insertions(+), 1 deletion(-)
 create mode 100644 configs/_base_/datasets/vsr.py
 create mode 100644 mmpretrain/datasets/vsr.py

diff --git a/configs/_base_/datasets/vsr.py b/configs/_base_/datasets/vsr.py
new file mode 100644
index 00000000..0fa9b899
--- /dev/null
+++ b/configs/_base_/datasets/vsr.py
@@ -0,0 +1,81 @@
+# data settings
+
+data_preprocessor = dict(
+    mean=[122.770938, 116.7460125, 104.09373615],
+    std=[68.5005327, 66.6321579, 70.32316305],
+    to_rgb=True,
+)
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=384,
+        interpolation='bicubic',
+        backend='pillow'),
+    dict(
+        type='PackInputs',
+        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
+        meta_keys=['question_id', 'image_id'],
+    ),
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        scale=(480, 480),
+        interpolation='bicubic',
+        backend='pillow'),
+    dict(
+        type='CleanCaption',
+        keys=['question'],
+    ),
+    dict(
+        type='PackInputs',
+        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
+        meta_keys=['question_id', 'image_id'],
+    ),
+]
+
+train_dataloader = dict(
+    batch_size=16,
+    num_workers=8,
+    dataset=dict(
+        type='VSR',
+        data_root='data/coco',
+        data_prefix='',
+        ann_file='annotations/train.json',
+        pipeline=train_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    persistent_workers=True,
+    drop_last=True,
+)
+
+val_dataloader = dict(
+    batch_size=16,
+    num_workers=8,
+    dataset=dict(
+        type='VSR',
+        data_root='data/coco',
+        data_prefix='',
+        ann_file='annotations/val.json',
+        pipeline=test_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    persistent_workers=True,
+)
+val_evaluator = dict(type='VSRAcc')
+
+test_dataloader = dict(
+    batch_size=16,
+    num_workers=8,
+    dataset=dict(
+        type='VSR',
+        data_root='data/coco',
+        data_prefix='',
+        ann_file='annotations/test.json',
+        pipeline=test_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    persistent_workers=True,
+)
+test_evaluator = val_evaluator
diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py
index b7267133..d9e241a2 100644
--- a/mmpretrain/datasets/__init__.py
+++ b/mmpretrain/datasets/__init__.py
@@ -45,9 +45,10 @@ if WITH_MULTIMODAL:
     from .scienceqa import ScienceQA
     from .textvqa import TextVQA
     from .visual_genome import VisualGenomeQA
+    from .vsr import VSR
 
     __all__.extend([
         'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
         'FlamingoEvalCOCOVQA', 'OCRVQA', 'RefCOCO', 'VisualGenomeQA',
-        'ScienceQA', 'NoCaps', 'GQA', 'TextVQA'
+        'ScienceQA', 'NoCaps', 'GQA', 'TextVQA', 'VSR'
     ])
diff --git a/mmpretrain/datasets/vsr.py b/mmpretrain/datasets/vsr.py
new file mode 100644
index 00000000..7b109592
--- /dev/null
+++ b/mmpretrain/datasets/vsr.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class VSR(BaseDataset):
+    """VSR: Visual Spatial Reasoning dataset.
+
+    Args:
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``.
+        data_prefix (str): The directory of images.
+        ann_file (str, optional): Annotation file path for training and
+            validation. Defaults to an empty string.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+    """
+
+    def __init__(self,
+                 data_root: str,
+                 data_prefix: str,
+                 ann_file: str = '',
+                 **kwargs):
+        super().__init__(
+            data_root=data_root,
+            data_prefix=dict(img_path=data_prefix),
+            ann_file=ann_file,
+            **kwargs,
+        )
+
+    def load_data_list(self) -> List[dict]:
+        """Load data list."""
+        annotations = mmengine.load(self.ann_file)
+
+        data_list = []
+        for ann in annotations:
+            # ann example
+            # {
+            #     "image": "train2017/000000372029.jpg",
+            #     "question": "The dog is on the surfboard.",
+            #     "answer": true
+            # }
+            data_info = dict()
+            data_info['img_path'] = mmengine.join_path(
+                self.data_prefix['img_path'], ann['image'])
+            data_info['question'] = ann['question']
+            data_info['gt_answer'] = 'yes' if ann['answer'] else 'no'
+
+            data_list.append(data_info)
+
+        return data_list
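
To consume the new configs/_base_/datasets/vsr.py from a model config, the
usual MMEngine `_base_` inheritance applies; a minimal, hypothetical sketch
(the file path and the extra base configs are placeholders, not part of this
patch):

# configs/some_model/some_model_vsr.py  (hypothetical path)
_base_ = [
    '../_base_/datasets/vsr.py',
    # model and schedule base configs would be listed here
]

# Everything defined in vsr.py (train_dataloader, val_evaluator, ...) is
# inherited; individual keys can be overridden and are dict-merged, e.g.:
val_dataloader = dict(batch_size=8)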
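
A quick smoke test for the new dataset class, as a minimal sketch: it feeds
VSR a throwaway annotation file in the format shown in `load_data_list` above
and checks the bool-to-string answer mapping. The paths and the temporary-file
setup are illustrative assumptions, and the import requires mmpretrain with
its multimodal dependencies installed (VSR is only exported when
WITH_MULTIMODAL is true).

import json
import tempfile

from mmpretrain.datasets import VSR

# One annotation entry in the format load_data_list expects.
ann = [{
    'image': 'train2017/000000372029.jpg',
    'question': 'The dog is on the surfboard.',
    'answer': True,
}]
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump(ann, f)

# An empty pipeline returns the raw data_info dict, so no image file
# needs to exist on disk for this check.
ds = VSR(data_root='data/coco', data_prefix='', ann_file=f.name, pipeline=[])
sample = ds[0]
assert sample['gt_answer'] == 'yes'  # bool answer stored as a string
print(sample['img_path'], sample['question'])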