mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Feature] Add support for VizWiz dataset. (#1636)
* add vizwiz * update dataset * [Fix] Build img_path in data_sample. * Fix isort. --------- Co-authored-by: ZhangYuanhan-AI <yuanhan002@ntu.edu.sg>
This commit is contained in:
parent
aac398a83f
commit
a673b048a5
80
configs/_base_/datasets/vizwiz.py
Normal file
80
configs/_base_/datasets/vizwiz.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
# data settings
|
||||||
|
|
||||||
|
data_preprocessor = dict(
|
||||||
|
mean=[122.770938, 116.7460125, 104.09373615],
|
||||||
|
std=[68.5005327, 66.6321579, 70.32316305],
|
||||||
|
to_rgb=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='RandomResizedCrop',
|
||||||
|
scale=384,
|
||||||
|
interpolation='bicubic',
|
||||||
|
backend='pillow'),
|
||||||
|
dict(
|
||||||
|
type='PackInputs',
|
||||||
|
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
|
||||||
|
meta_keys=['question_id', 'image_id'],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='Resize',
|
||||||
|
scale=(480, 480),
|
||||||
|
interpolation='bicubic',
|
||||||
|
backend='pillow'),
|
||||||
|
dict(
|
||||||
|
type='CleanCaption',
|
||||||
|
keys=['question'],
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
type='PackInputs',
|
||||||
|
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
|
||||||
|
meta_keys=['question_id', 'image_id'],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
train_dataloader = dict(
|
||||||
|
batch_size=16,
|
||||||
|
num_workers=8,
|
||||||
|
dataset=dict(
|
||||||
|
type='VizWiz',
|
||||||
|
data_root='data/vizwiz/Images',
|
||||||
|
data_prefix='',
|
||||||
|
ann_file='Annotations/train.json',
|
||||||
|
pipeline=train_pipeline),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
persistent_workers=True,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_dataloader = dict(
|
||||||
|
batch_size=16,
|
||||||
|
num_workers=8,
|
||||||
|
dataset=dict(
|
||||||
|
type='VizWiz',
|
||||||
|
data_root='data/vizwiz/Images',
|
||||||
|
data_prefix='',
|
||||||
|
ann_file='Annotations/val.json',
|
||||||
|
pipeline=test_pipeline),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
val_evaluator = dict(type='VizWizAcc')
|
||||||
|
|
||||||
|
test_dataloader = dict(
|
||||||
|
batch_size=16,
|
||||||
|
num_workers=8,
|
||||||
|
dataset=dict(
|
||||||
|
type='VizWiz',
|
||||||
|
data_root='data/vizwiz/Images',
|
||||||
|
data_prefix='',
|
||||||
|
ann_file='Annotations/test.json',
|
||||||
|
pipeline=test_pipeline),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
)
|
||||||
|
test_evaluator = dict(type='ReportVQA', file_path='vqa_test.json')
|
@ -45,10 +45,11 @@ if WITH_MULTIMODAL:
|
|||||||
from .scienceqa import ScienceQA
|
from .scienceqa import ScienceQA
|
||||||
from .textvqa import TextVQA
|
from .textvqa import TextVQA
|
||||||
from .visual_genome import VisualGenomeQA
|
from .visual_genome import VisualGenomeQA
|
||||||
|
from .vizwiz import VizWiz
|
||||||
from .vsr import VSR
|
from .vsr import VSR
|
||||||
|
|
||||||
__all__.extend([
|
__all__.extend([
|
||||||
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
|
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
|
||||||
'FlamingoEvalCOCOVQA', 'OCRVQA', 'RefCOCO', 'VisualGenomeQA',
|
'FlamingoEvalCOCOVQA', 'OCRVQA', 'RefCOCO', 'VisualGenomeQA',
|
||||||
'ScienceQA', 'NoCaps', 'GQA', 'TextVQA', 'VSR'
|
'ScienceQA', 'NoCaps', 'GQA', 'TextVQA', 'VSR', 'VizWiz'
|
||||||
])
|
])
|
||||||
|
112
mmpretrain/datasets/vizwiz.py
Normal file
112
mmpretrain/datasets/vizwiz.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from collections import Counter
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import mmengine
|
||||||
|
from mmengine.dataset import BaseDataset
|
||||||
|
|
||||||
|
from mmpretrain.registry import DATASETS
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module()
|
||||||
|
class VizWiz(BaseDataset):
|
||||||
|
"""VizWiz dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root (str): The root directory for ``data_prefix``, ``ann_file``
|
||||||
|
and ``question_file``.
|
||||||
|
data_prefix (str): The directory of images.
|
||||||
|
ann_file (str, optional): Annotation file path for training and
|
||||||
|
validation. Defaults to an empty string.
|
||||||
|
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
data_root: str,
|
||||||
|
data_prefix: str,
|
||||||
|
ann_file: str = '',
|
||||||
|
**kwarg):
|
||||||
|
super().__init__(
|
||||||
|
data_root=data_root,
|
||||||
|
data_prefix=dict(img_path=data_prefix),
|
||||||
|
ann_file=ann_file,
|
||||||
|
**kwarg,
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_data_list(self) -> List[dict]:
|
||||||
|
"""Load data list."""
|
||||||
|
annotations = mmengine.load(self.ann_file)
|
||||||
|
|
||||||
|
data_list = []
|
||||||
|
for ann in annotations:
|
||||||
|
# {
|
||||||
|
# "image": "VizWiz_val_00000001.jpg",
|
||||||
|
# "question": "Can you tell me what this medicine is please?",
|
||||||
|
# "answers": [
|
||||||
|
# {
|
||||||
|
# "answer": "no",
|
||||||
|
# "answer_confidence": "yes"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "answer": "unanswerable",
|
||||||
|
# "answer_confidence": "yes"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "answer": "night time",
|
||||||
|
# "answer_confidence": "maybe"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "answer": "unanswerable",
|
||||||
|
# "answer_confidence": "yes"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "answer": "night time",
|
||||||
|
# "answer_confidence": "maybe"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "answer": "night time cold medicine",
|
||||||
|
# "answer_confidence": "maybe"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "answer": "night time",
|
||||||
|
# "answer_confidence": "maybe"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "answer": "night time",
|
||||||
|
# "answer_confidence": "maybe"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "answer": "night time",
|
||||||
|
# "answer_confidence": "maybe"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "answer": "night time medicine",
|
||||||
|
# "answer_confidence": "yes"
|
||||||
|
# }
|
||||||
|
# ],
|
||||||
|
# "answer_type": "other",
|
||||||
|
# "answerable": 1
|
||||||
|
# },
|
||||||
|
data_info = dict()
|
||||||
|
data_info['question'] = ann['question']
|
||||||
|
data_info['img_path'] = mmengine.join_path(
|
||||||
|
self.data_prefix['img_path'], ann['image'])
|
||||||
|
|
||||||
|
if 'answerable' not in ann:
|
||||||
|
data_list.append(data_info)
|
||||||
|
else:
|
||||||
|
if ann['answerable'] == 1:
|
||||||
|
# add answer_weight & answer_count, delete duplicate answer
|
||||||
|
answers = []
|
||||||
|
for item in ann.pop('answers'):
|
||||||
|
if item['answer_confidence'] == 'yes' and item[
|
||||||
|
'answer'] != 'unanswerable':
|
||||||
|
answers.append(item['answer'])
|
||||||
|
count = Counter(answers)
|
||||||
|
answer_weight = [i / len(answers) for i in count.values()]
|
||||||
|
data_info['gt_answer'] = list(count.keys())
|
||||||
|
data_info['gt_answer_weight'] = answer_weight
|
||||||
|
# data_info.update(ann)
|
||||||
|
data_list.append(data_info)
|
||||||
|
|
||||||
|
return data_list
|
Loading…
x
Reference in New Issue
Block a user