# File: configs/_base_/datasets/vizwiz.py
# Base dataset settings for VizWiz visual question answering.

# data settings

# Per-channel normalisation statistics; ``to_rgb=True`` asks the preprocessor
# to convert the channel order before normalising.
data_preprocessor = dict(
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

# Training: random-resized crop to 384, then pack the question and the
# weighted ground-truth answers.
# NOTE(review): unlike ``test_pipeline`` there is no ``CleanCaption`` step on
# ``question`` here - confirm the asymmetry is intended.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        interpolation='bicubic',
        backend='pillow'),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    ),
]

# Validation / test: fixed 480x480 resize, clean the question text, then pack.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(480, 480),
        interpolation='bicubic',
        backend='pillow'),
    dict(
        type='CleanCaption',
        keys=['question'],
    ),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    ),
]

# NOTE(review): ``data_root`` already ends in ``Images`` and ``ann_file`` is a
# relative path, so the annotations are presumably looked up under
# ``data/vizwiz/Images/Annotations/`` - confirm this matches the intended
# on-disk layout.
train_dataloader = dict(
    batch_size=16,
    num_workers=8,
    dataset=dict(
        type='VizWiz',
        data_root='data/vizwiz/Images',
        data_prefix='',
        ann_file='Annotations/train.json',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
    drop_last=True,
)

val_dataloader = dict(
    batch_size=16,
    num_workers=8,
    dataset=dict(
        type='VizWiz',
        data_root='data/vizwiz/Images',
        data_prefix='',
        ann_file='Annotations/val.json',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='VizWizAcc')

# The test split is not scored locally: predictions are written to
# ``vqa_test.json`` by the ``ReportVQA`` evaluator.
test_dataloader = dict(
    batch_size=16,
    num_workers=8,
    dataset=dict(
        type='VizWiz',
        data_root='data/vizwiz/Images',
        data_prefix='',
        ann_file='Annotations/test.json',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = dict(type='ReportVQA', file_path='vqa_test.json')
# Copyright (c) OpenMMLab. All rights reserved.
# File: mmpretrain/datasets/vizwiz.py
#
# Package wiring (mmpretrain/datasets/__init__.py, inside the
# ``if WITH_MULTIMODAL:`` section): ``from .vizwiz import VizWiz`` is added
# after the ``visual_genome`` import, and ``'VizWiz'`` is appended to the list
# passed to ``__all__.extend``.
from collections import Counter
from typing import List, Tuple

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class VizWiz(BaseDataset):
    """VizWiz dataset.

    The annotation file is a list of entries shaped like::

        {
            "image": "VizWiz_val_00000001.jpg",
            "question": "Can you tell me what this medicine is please?",
            "answers": [
                {"answer": "no", "answer_confidence": "yes"},
                {"answer": "unanswerable", "answer_confidence": "yes"},
                {"answer": "night time", "answer_confidence": "maybe"},
                {"answer": "night time medicine",
                 "answer_confidence": "yes"},
                ...
            ],
            "answer_type": "other",
            "answerable": 1
        }

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    @staticmethod
    def _weigh_answers(answers: List[dict]) -> Tuple[List[str], List[float]]:
        """Deduplicate the confident answers and weight them by frequency.

        Only answers with ``answer_confidence == 'yes'`` whose text is not
        ``'unanswerable'`` are kept. Each distinct kept answer receives the
        fraction of kept answers it accounts for.

        Args:
            answers (List[dict]): The ``answers`` list of one annotation.

        Returns:
            Tuple[List[str], List[float]]: Distinct answers in first-seen
            order and their weights. Both lists are empty when no answer
            passes the filter.
        """
        kept = [
            item['answer'] for item in answers
            if item['answer_confidence'] == 'yes'
            and item['answer'] != 'unanswerable'
        ]
        counts = Counter(kept)
        total = len(kept)
        # ``counts`` is empty whenever ``total`` is 0, so no division by zero
        # can happen here.
        return list(counts.keys()), [num / total for num in counts.values()]

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Every returned item holds ``question`` and ``img_path``. Entries
        carrying ``answerable == 1`` additionally get ``gt_answer`` and
        ``gt_answer_weight``. Entries without an ``answerable`` field are
        kept without ground truth.

        NOTE(review): entries whose ``answerable`` flag is present but not 1
        are skipped here - confirm this is the intended filtering.

        Returns:
            List[dict]: The parsed samples.
        """
        annotations = mmengine.load(self.ann_file)

        data_list = []
        for ann in annotations:
            data_info = dict()
            data_info['question'] = ann['question']
            data_info['img_path'] = mmengine.join_path(
                self.data_prefix['img_path'], ann['image'])

            if 'answerable' not in ann:
                data_list.append(data_info)
            elif ann['answerable'] == 1:
                # Read (rather than pop) the answers so the loaded annotation
                # is left untouched.
                gt_answer, gt_answer_weight = self._weigh_answers(
                    ann['answers'])
                data_info['gt_answer'] = gt_answer
                data_info['gt_answer_weight'] = gt_answer_weight
                data_list.append(data_info)

        return data_list