[Feature] Support OCR-VQA dataset (#1621)

* support ocrvqa dataset

* minor

* remove abs path

* refine README
Yiqin Wang 王逸钦 2023-06-13 10:28:45 +08:00 committed by GitHub
parent dbfb84ccbd
commit bb415b91be
5 changed files with 256 additions and 3 deletions

configs/_base_/datasets/ocrvqa.py

@@ -0,0 +1,81 @@
# data settings
data_preprocessor = dict(
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='CleanCaption', keys=['question', 'gt_answer']),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=[],
    ),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(480, 480),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='CleanCaption', keys=['question', 'gt_answer']),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=[],
    ),
]

train_dataloader = dict(
    batch_size=16,
    num_workers=8,
    dataset=dict(
        type='OCRVQA',
        data_root='data/ocrvqa',
        data_prefix='images',
        ann_file='annotations/dataset.json',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
    drop_last=True,
)

val_dataloader = dict(
    batch_size=64,
    num_workers=8,
    dataset=dict(
        type='OCRVQA',
        data_root='data/ocrvqa',
        data_prefix='images',
        ann_file='annotations/dataset.json',
        split='val',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='VQAAcc')

test_dataloader = dict(
    batch_size=64,
    num_workers=8,
    dataset=dict(
        type='OCRVQA',
        data_root='data/ocrvqa',
        data_prefix='images',
        ann_file='annotations/dataset.json',
        split='test',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = dict(type='VQAAcc')
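This config expects the OCR-VQA images under `data/ocrvqa/images/` and the converted annotation file at `data/ocrvqa/annotations/dataset.json` relative to the repository root. As a quick standalone sanity check (a sketch; the config path is inferred from how the BLIP config later in this commit references it), the file can be loaded with mmengine and the resolved dataset settings inspected without any data on disk:

```python
# Sketch: load the base dataset config and print the resolved train settings.
# The config path is an assumption based on the `_base_` reference below.
from mmengine.config import Config

cfg = Config.fromfile('configs/_base_/datasets/ocrvqa.py')
print(cfg.train_dataloader.dataset)  # type='OCRVQA', split='train', ...
print(cfg.val_evaluator)             # dict(type='VQAAcc')
```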

configs/blip/README.md

@@ -70,6 +70,12 @@ python tools/test.py configs/blip/blip-base_8xb32_caption.py https://download.op
| :------------------------- | :--------: | :------: | :----------------------------------: | :-------------------------------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_vqa`\* | 361.48 | 40.59# | [config](./blip-base_8xb32_okvqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty-capflit_vqa_20230505-81488941.pth) |

### Visual Question Answering on OCR-VQA

| Model | Params (M) | Accuracy | Config | Download |
| :------------------------- | :--------: | :------: | :-----------------------------------: | :-------------------------------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_vqa`\* | 361.48 | 28.30# | [config](./blip-base_8xb32_ocrvqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty-capflit_vqa_20230505-81488941.pth) |

### Image-To-Text Retrieval on COCO

| Model | Params (M) | Recall@1 | Recall@5 | Config | Download |
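For a quick qualitative check of the converted weights listed above, mmpretrain's `VisualQuestionAnsweringInferencer` can be pointed at the `blip-base_3rdparty_vqa` alias from the table. This is only a sketch: the image path is a placeholder, and the alias is assumed to resolve through the model metafiles.

```python
# Sketch: run the released BLIP VQA weights on one image.
# 'path/to/book_cover.jpg' is a placeholder; the model alias comes from
# the table above and is assumed to be registered in the metafiles.
from mmpretrain import VisualQuestionAnsweringInferencer

inferencer = VisualQuestionAnsweringInferencer(model='blip-base_3rdparty_vqa')
print(inferencer('path/to/book_cover.jpg', 'Who wrote this book?'))
```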

configs/blip/blip-base_8xb32_ocrvqa.py

@@ -0,0 +1,75 @@
_base_ = [
    '../_base_/datasets/ocrvqa.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='BlipVQA',
    tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
    vision_backbone=dict(
        type='VisionTransformer',
        arch='b',
        img_size=480,
        patch_size=16,
        out_type='raw'),
    multimodal_backbone=dict(
        type='XBertEncoder',
        med_config=dict(
            architectures=['BertModel'],
            attention_probs_dropout_prob=0.1,
            hidden_act='gelu',
            hidden_dropout_prob=0.1,
            hidden_size=768,
            initializer_range=0.02,
            intermediate_size=3072,
            layer_norm_eps=1e-12,
            max_position_embeddings=512,
            model_type='bert',
            num_attention_heads=12,
            num_hidden_layers=12,
            pad_token_id=0,
            add_type_embeddings=False,
            vocab_size=30524,
            encoder_width=768,
            add_cross_attention=True),
    ),
    head=dict(
        type='VQAGenerationHead',
        decoder=dict(
            type='XBertLMHeadDecoder',
            med_config=dict(
                architectures=['BertModel'],
                attention_probs_dropout_prob=0.1,
                hidden_act='gelu',
                hidden_dropout_prob=0.1,
                hidden_size=768,
                initializer_range=0.02,
                intermediate_size=3072,
                layer_norm_eps=1e-12,
                max_position_embeddings=512,
                model_type='bert',
                num_attention_heads=12,
                num_hidden_layers=12,
                pad_token_id=0,
                add_type_embeddings=False,
                vocab_size=30524,
                encoder_width=768,
                add_cross_attention=True),
        ),
        inference_method='generate',
    ),
)

# schedule settings
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=0.05)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)

param_scheduler = [dict(type='CosineAnnealingLR', by_epoch=True)]

train_cfg = dict(max_epochs=10, by_epoch=True)
val_cfg = dict()
test_cfg = dict()

# runtime settings
randomness = dict(seed=42)
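This fine-tuning config can be launched through mmengine's `Runner`. The snippet below is a minimal sketch: the config path and `work_dir` are assumptions, and setting `load_from` to the BLIP checkpoint from the README is optional if you only want to evaluate.

```python
# Sketch: build a Runner from the new config and start training.
# Paths below are assumptions; distributed launch settings are omitted.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/blip/blip-base_8xb32_ocrvqa.py')
cfg.work_dir = 'work_dirs/blip-base_8xb32_ocrvqa'
# cfg.load_from = 'https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty-capflit_vqa_20230505-81488941.pth'
runner = Runner.from_cfg(cfg)
runner.train()  # or runner.test() after setting cfg.load_from
```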

mmpretrain/datasets/__init__.py

@@ -40,6 +40,7 @@ if WITH_MULTIMODAL:
    from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
    from .gqa_dataset import GQA
    from .nocaps import NoCaps
    from .ocr_vqa import OCRVQA
    from .refcoco import RefCOCO
    from .scienceqa import ScienceQA
    from .textvqa import TextVQA
@@ -47,7 +48,6 @@
    __all__.extend([
        'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
-       'FlamingoEvalCOCOVQA', 'RefCOCO', 'VisualGenomeQA', 'ScienceQA',
-       'NoCaps'
-       'GQA', 'TextVQA'
+       'FlamingoEvalCOCOVQA', 'OCRVQA', 'RefCOCO', 'VisualGenomeQA',
+       'ScienceQA', 'NoCaps', 'GQA', 'TextVQA'
    ])
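With the import and `__all__` entry in place (and the optional multimodal dependencies installed so that `WITH_MULTIMODAL` is true), the new class is both importable and resolvable through the registry. A hypothetical check:

```python
# Sketch: confirm the new dataset is registered under the name used in configs.
import mmpretrain.datasets  # importing the package triggers registration
from mmpretrain.registry import DATASETS

assert DATASETS.get('OCRVQA') is mmpretrain.datasets.OCRVQA
```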

mmpretrain/datasets/ocr_vqa.py

@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class OCRVQA(BaseDataset):
    """OCR-VQA dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str): Annotation file path for training and validation.
        split (str): 'train', 'val' or 'test'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self, data_root: str, data_prefix: str, ann_file: str,
                 split: str, **kwargs):
        assert split in ['train', 'val', 'test'], \
            '`split` must be train, val or test'
        self.split = split
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        split_dict = {1: 'train', 2: 'val', 3: 'test'}

        annotations = mmengine.load(self.ann_file)

        # ann example
        # "761183272": {
        #     "imageURL": \
        #         "http://ecx.images-amazon.com/images/I/61Y5cOdHJbL.jpg",
        #     "questions": [
        #         "Who wrote this book?",
        #         "What is the title of this book?",
        #         "What is the genre of this book?",
        #         "Is this a games related book?",
        #         "What is the year printed on this calendar?"],
        #     "answers": [
        #         "Sandra Boynton",
        #         "Mom's Family Wall Calendar 2016",
        #         "Calendars",
        #         "No",
        #         "2016"],
        #     "title": "Mom's Family Wall Calendar 2016",
        #     "authorName": "Sandra Boynton",
        #     "genre": "Calendars",
        #     "split": 1
        # },

        data_list = []

        for key, ann in annotations.items():
            if self.split != split_dict[ann['split']]:
                continue

            extension = osp.splitext(ann['imageURL'])[1]
            if extension not in ['.jpg', '.png']:
                continue
            img_path = mmengine.join_path(self.data_prefix['img_path'],
                                          key + extension)

            for question, answer in zip(ann['questions'], ann['answers']):
                data_info = {}
                data_info['img_path'] = img_path
                data_info['question'] = question
                data_info['gt_answer'] = answer
                data_info['gt_answer_weight'] = [1.0]
                data_info['imageURL'] = ann['imageURL']
                data_info['title'] = ann['title']
                data_info['authorName'] = ann['authorName']
                data_info['genre'] = ann['genre']

                data_list.append(data_info)

        return data_list
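To illustrate how `load_data_list` flattens each image's question/answer lists into one sample per QA pair, here is a self-contained sketch that writes a single fabricated annotation in the format shown in the comment above and loads it. The record, directory names, and URL are made up for illustration, and the multimodal extras must be installed for the import to work.

```python
# Sketch: round-trip one fake OCR-VQA annotation through the dataset class.
import json
import os
import tempfile

from mmpretrain.datasets import OCRVQA

fake_ann = {
    '761183272': {
        'imageURL': 'http://example.com/61Y5cOdHJbL.jpg',  # fabricated
        'questions': ['Who wrote this book?', 'What is the genre of this book?'],
        'answers': ['Sandra Boynton', 'Calendars'],
        'title': "Mom's Family Wall Calendar 2016",
        'authorName': 'Sandra Boynton',
        'genre': 'Calendars',
        'split': 1,  # 1 -> train, 2 -> val, 3 -> test
    }
}

with tempfile.TemporaryDirectory() as data_root:
    os.makedirs(os.path.join(data_root, 'annotations'))
    with open(os.path.join(data_root, 'annotations', 'dataset.json'), 'w') as f:
        json.dump(fake_ann, f)

    dataset = OCRVQA(
        data_root=data_root,
        data_prefix='images',  # images are never opened by load_data_list
        ann_file='annotations/dataset.json',
        split='train',
        pipeline=[])  # empty pipeline keeps samples as plain dicts

    print(len(dataset))  # 2 -- one sample per question/answer pair
    print(dataset[0]['question'], '->', dataset[0]['gt_answer'])
```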