[Feature] Support OCR-VQA dataset (#1621)
* support ocrvqa dataset
* minor
* remove abs path
* refine README
parent dbfb84ccbd, commit bb415b91be
configs/_base_/datasets/ocrvqa.py (new file)
@@ -0,0 +1,81 @@

```python
# data settings
data_preprocessor = dict(
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='CleanCaption', keys=['question', 'gt_answer']),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=[],
    ),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(480, 480),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='CleanCaption', keys=['question', 'gt_answer']),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=[],
    ),
]

train_dataloader = dict(
    batch_size=16,
    num_workers=8,
    dataset=dict(
        type='OCRVQA',
        data_root='data/ocrvqa',
        data_prefix='images',
        ann_file='annotations/dataset.json',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
    drop_last=True,
)

val_dataloader = dict(
    batch_size=64,
    num_workers=8,
    dataset=dict(
        type='OCRVQA',
        data_root='data/ocrvqa',
        data_prefix='images',
        ann_file='annotations/dataset.json',
        split='val',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='VQAAcc')

test_dataloader = dict(
    batch_size=64,
    num_workers=8,
    dataset=dict(
        type='OCRVQA',
        data_root='data/ocrvqa',
        data_prefix='images',
        ann_file='annotations/dataset.json',
        split='test',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = dict(type='VQAAcc')
```
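As orientation for this new base config: the nested `dataset` dicts are resolved through the `DATASETS` registry at runtime. Below is a rough sketch (not part of this commit) of building the validation split directly from such a dict; the `data/ocrvqa` paths are the ones the config assumes and must already be prepared locally.

```python
# Illustrative sketch only: build the OCR-VQA validation split via the
# registry, mirroring what the dataloader config above resolves to.
# Assumes data/ocrvqa/{images, annotations/dataset.json} exist locally.
import mmpretrain.datasets  # noqa: F401  # importing registers the dataset classes
from mmpretrain.registry import DATASETS

val_dataset = DATASETS.build(
    dict(
        type='OCRVQA',
        data_root='data/ocrvqa',
        data_prefix='images',
        ann_file='annotations/dataset.json',
        split='val',
        pipeline=[],  # pipeline omitted here; the config above uses test_pipeline
    ))
print(len(val_dataset))
print(val_dataset.get_data_info(0)['question'])
```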
configs/blip/README.md
@@ -70,6 +70,12 @@ python tools/test.py configs/blip/blip-base_8xb32_caption.py https://download.op

```markdown
| :------------------------- | :--------: | :------: | :----------------------------------: | :-------------------------------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_vqa`\* | 361.48     | 40.59#   | [config](./blip-base_8xb32_okvqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty-capflit_vqa_20230505-81488941.pth)  |

### Visual Question Answering on OCR-VQA

| Model                      | Params (M) | Accuracy | Config                                | Download                                                                                                               |
| :------------------------- | :--------: | :------: | :-----------------------------------: | :-------------------------------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_vqa`\* | 361.48     | 28.30#   | [config](./blip-base_8xb32_ocrvqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty-capflit_vqa_20230505-81488941.pth)  |

### Image-To-Text Retrieval on COCO

| Model | Params (M) | Recall@1 | Recall@5 | Config | Download |
```
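The OCR-VQA row above points at the converted BLIP VQA checkpoint. A rough sketch (not part of this commit) of how such an evaluation could be run programmatically with mmengine's `Runner`, which is roughly what the `tools/test.py` command shown in the README context wraps; the `work_dirs` path is an arbitrary choice, and the data is assumed to be prepared under `data/ocrvqa`.

```python
# Rough evaluation sketch: test the released BLIP VQA checkpoint from the
# table on the OCR-VQA test split.
# Assumes OCR-VQA data is prepared under data/ocrvqa; work_dir is arbitrary.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/blip/blip-base_8xb32_ocrvqa.py')
cfg.load_from = ('https://download.openmmlab.com/mmclassification/v1/blip/'
                 'blip-base_3rdparty-capflit_vqa_20230505-81488941.pth')
cfg.work_dir = 'work_dirs/blip-base_8xb32_ocrvqa'

runner = Runner.from_cfg(cfg)
print(runner.test())  # reports the VQAAcc metric configured in the dataset config
```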
configs/blip/blip-base_8xb32_ocrvqa.py (new file)
@@ -0,0 +1,75 @@

```python
_base_ = [
    '../_base_/datasets/ocrvqa.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='BlipVQA',
    tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
    vision_backbone=dict(
        type='VisionTransformer',
        arch='b',
        img_size=480,
        patch_size=16,
        out_type='raw'),
    multimodal_backbone=dict(
        type='XBertEncoder',
        med_config=dict(
            architectures=['BertModel'],
            attention_probs_dropout_prob=0.1,
            hidden_act='gelu',
            hidden_dropout_prob=0.1,
            hidden_size=768,
            initializer_range=0.02,
            intermediate_size=3072,
            layer_norm_eps=1e-12,
            max_position_embeddings=512,
            model_type='bert',
            num_attention_heads=12,
            num_hidden_layers=12,
            pad_token_id=0,
            add_type_embeddings=False,
            vocab_size=30524,
            encoder_width=768,
            add_cross_attention=True),
    ),
    head=dict(
        type='VQAGenerationHead',
        decoder=dict(
            type='XBertLMHeadDecoder',
            med_config=dict(
                architectures=['BertModel'],
                attention_probs_dropout_prob=0.1,
                hidden_act='gelu',
                hidden_dropout_prob=0.1,
                hidden_size=768,
                initializer_range=0.02,
                intermediate_size=3072,
                layer_norm_eps=1e-12,
                max_position_embeddings=512,
                model_type='bert',
                num_attention_heads=12,
                num_hidden_layers=12,
                pad_token_id=0,
                add_type_embeddings=False,
                vocab_size=30524,
                encoder_width=768,
                add_cross_attention=True),
        ),
        inference_method='generate',
    ),
)

# schedule settings
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=0.05)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)

param_scheduler = [dict(type='CosineAnnealingLR', by_epoch=True)]

train_cfg = dict(max_epochs=10, by_epoch=True)
val_cfg = dict()
test_cfg = dict()

# runtime settings
randomness = dict(seed=42)
```
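For completeness, a minimal single-process training sketch under the same assumptions (data prepared under `data/ocrvqa`, arbitrary `work_dirs` path). A real fine-tuning run would normally start from pretrained BLIP weights via `load_from` and use a distributed launcher, neither of which this commit prescribes.

```python
# Minimal single-process training sketch for this config (illustrative only).
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/blip/blip-base_8xb32_ocrvqa.py')
cfg.work_dir = 'work_dirs/blip-base_8xb32_ocrvqa'  # arbitrary output directory
# cfg.load_from = '<pretrained BLIP checkpoint>'   # usually set when fine-tuning

runner = Runner.from_cfg(cfg)
runner.train()  # 10 epochs of AdamW + cosine annealing, as set above
```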
mmpretrain/datasets/__init__.py

```diff
@@ -40,6 +40,7 @@ if WITH_MULTIMODAL:
     from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
     from .gqa_dataset import GQA
     from .nocaps import NoCaps
+    from .ocr_vqa import OCRVQA
     from .refcoco import RefCOCO
     from .scienceqa import ScienceQA
     from .textvqa import TextVQA
@@ -47,7 +48,6 @@ if WITH_MULTIMODAL:

     __all__.extend([
         'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
-        'FlamingoEvalCOCOVQA', 'RefCOCO', 'VisualGenomeQA', 'ScienceQA',
-        'NoCaps'
-        'GQA', 'TextVQA'
+        'FlamingoEvalCOCOVQA', 'OCRVQA', 'RefCOCO', 'VisualGenomeQA',
+        'ScienceQA', 'NoCaps', 'GQA', 'TextVQA'
     ])
```
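With the import and `__all__` entries in place (and assuming the multimodal dependencies behind the `WITH_MULTIMODAL` guard are installed), the new dataset is both importable from the package namespace and resolvable by name in the registry, which is what the `type='OCRVQA'` configs rely on. A tiny sanity check:

```python
# Sanity check (illustrative): the class is exported and registered under the
# name used in the configs.  Requires the multimodal extras to be installed.
from mmpretrain.datasets import OCRVQA
from mmpretrain.registry import DATASETS

assert DATASETS.get('OCRVQA') is OCRVQA
```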
mmpretrain/datasets/ocr_vqa.py (new file)
@@ -0,0 +1,91 @@

```python
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class OCRVQA(BaseDataset):
    """OCR-VQA dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str): Annotation file path for training and validation.
        split (str): 'train', 'val' or 'test'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self, data_root: str, data_prefix: str, ann_file: str,
                 split: str, **kwargs):

        assert split in ['train', 'val', 'test'], \
            '`split` must be train, val or test'
        self.split = split
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list."""

        split_dict = {1: 'train', 2: 'val', 3: 'test'}

        annotations = mmengine.load(self.ann_file)

        # ann example
        # "761183272": {
        #     "imageURL": \
        #         "http://ecx.images-amazon.com/images/I/61Y5cOdHJbL.jpg",
        #     "questions": [
        #         "Who wrote this book?",
        #         "What is the title of this book?",
        #         "What is the genre of this book?",
        #         "Is this a games related book?",
        #         "What is the year printed on this calendar?"],
        #     "answers": [
        #         "Sandra Boynton",
        #         "Mom's Family Wall Calendar 2016",
        #         "Calendars",
        #         "No",
        #         "2016"],
        #     "title": "Mom's Family Wall Calendar 2016",
        #     "authorName": "Sandra Boynton",
        #     "genre": "Calendars",
        #     "split": 1
        # },

        data_list = []

        for key, ann in annotations.items():
            if self.split != split_dict[ann['split']]:
                continue

            # skip entries whose image is not a local .jpg/.png file
            extension = osp.splitext(ann['imageURL'])[1]
            if extension not in ['.jpg', '.png']:
                continue

            img_path = mmengine.join_path(self.data_prefix['img_path'],
                                          key + extension)
            # emit one data sample per (question, answer) pair of this image
            for question, answer in zip(ann['questions'], ann['answers']):
                data_info = {}
                data_info['img_path'] = img_path
                data_info['question'] = question
                data_info['gt_answer'] = answer
                data_info['gt_answer_weight'] = [1.0]

                data_info['imageURL'] = ann['imageURL']
                data_info['title'] = ann['title']
                data_info['authorName'] = ann['authorName']
                data_info['genre'] = ann['genre']

                data_list.append(data_info)

        return data_list
```
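Since `load_data_list` only reads the annotation JSON and never opens the images, the flattening behaviour can be exercised without downloading the dataset. A self-contained sketch using a toy annotation entry (the field layout follows the example in the comment above; the temporary paths are illustrative):

```python
# Illustrative only: one annotation entry with two questions becomes two
# flattened samples.  No images are needed; only the JSON file is read.
import json
import tempfile

from mmpretrain.datasets import OCRVQA

toy_ann = {
    '761183272': {
        'imageURL': 'http://ecx.images-amazon.com/images/I/61Y5cOdHJbL.jpg',
        'questions': ['Who wrote this book?', 'What is the genre of this book?'],
        'answers': ['Sandra Boynton', 'Calendars'],
        'title': "Mom's Family Wall Calendar 2016",
        'authorName': 'Sandra Boynton',
        'genre': 'Calendars',
        'split': 3,  # 3 -> 'test' in split_dict
    }
}

with tempfile.TemporaryDirectory() as tmp:
    with open(f'{tmp}/dataset.json', 'w') as f:
        json.dump(toy_ann, f)

    ds = OCRVQA(
        data_root=tmp,
        data_prefix='images',    # images/ does not need to exist for loading
        ann_file='dataset.json',
        split='test',
        pipeline=[],
    )
    assert len(ds) == 2          # one sample per (question, answer) pair
    info = ds.get_data_info(0)
    print(info['img_path'], info['question'], '->', info['gt_answer'])
```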