[Feature] support TextVQA dataset (#1596)

* [Support] Support TextVQA dataset

* add folder structure

* fix readme
This commit is contained in:
Wangbo Zhao(黑色枷锁) 2023-06-02 11:50:38 +08:00 committed by GitHub
parent bc3c4a35ee
commit 3a277ee9e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 112 additions and 12 deletions

View File

@ -48,9 +48,9 @@ python tools/test.py configs/blip/blip-base_8xb32_caption.py https://download.op
### Image Caption on NoCaps
| Model | Params (M) | SPICE | CIDER | Config | Download |
| :----------------------------- | :--------: | :---: | :----: | :----------------------------------: | :---------------------------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_caption`\* | 223.97 | 14.69 | 109.12 | [config](./blip-base_8x32_nocaps.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-caption_20230419-a5b71af3.pth) |
| Model | Params (M) | SPICE | CIDER | Config | Download |
| :----------------------------- | :--------: | :---: | :----: | :-----------------------------------: | :--------------------------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_caption`\* | 223.97 | 14.69 | 109.12 | [config](./blip-base_8xb32_nocaps.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-caption_20230419-a5b71af3.pth) |
### Visual Grounding on RefCOCO

View File

@ -42,17 +42,12 @@ if WITH_MULTIMODAL:
from .nocaps import NoCaps
from .refcoco import RefCOCO
from .scienceqa import ScienceQA
from .textvqa import TextVQA
from .visual_genome import VisualGenomeQA
__all__.extend([
'COCOCaption',
'COCORetrieval',
'COCOVQA',
'FlamingoEvalCOCOCaption',
'FlamingoEvalCOCOVQA',
'RefCOCO',
'VisualGenomeQA',
'ScienceQA',
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
'FlamingoEvalCOCOVQA', 'RefCOCO', 'VisualGenomeQA', 'ScienceQA',
'NoCaps',
'GQA',
'GQA', 'TextVQA'
])

View File

@ -0,0 +1,105 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class TextVQA(BaseDataset):
    """TextVQA dataset.

    val image:
        https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
    test image:
        https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
    val json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
    test json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json

    folder structure:

    .. code-block::

        data/textvqa
        ├── annotations
        │   ├── TextVQA_0.5.1_test.json
        │   └── TextVQA_0.5.1_val.json
        └── images
            ├── test_images
            └── train_images

    Args:
        data_root (str): The root directory for ``data_prefix``, ``ann_file``
            and ``question_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Returns:
            List[dict]: Each item holds ``question``, ``question_id``,
            ``image_id`` and ``img_path``; when the annotation provides
            answers (val split), ``gt_answer`` and ``gt_answer_weight``
            are added as the unique answers and their relative
            frequencies.
        """
        annotations = mmengine.load(self.ann_file)['data']

        data_list = []
        for ann in annotations:
            # ann example
            # {
            #     'question': 'what is the brand of...is camera?',
            #     'image_id': '003a8ae2ef43b901',
            #     'image_classes': [
            #         'Cassette deck', 'Printer', ...
            #     ],
            #     'flickr_original_url': 'https://farm2.static...04a6_o.jpg',
            #     'flickr_300k_url': 'https://farm2.static...04a6_o.jpg',
            #     'image_width': 1024,
            #     'image_height': 664,
            #     'answers': [
            #         'nous les gosses',
            #         'dakota',
            #         'clos culombu',
            #         'dakota digital' ...
            #     ],
            #     'question_tokens':
            #         ['what', 'is', 'the', 'brand', 'of', 'this', 'camera'],
            #     'question_id': 34602,
            #     'set_name': 'val'
            # }
            data_info = dict(
                question=ann['question'],
                question_id=ann['question_id'],
                image_id=ann['image_id'],
                # Images are stored as <image_id>.jpg under the image prefix.
                img_path=mmengine.join_path(self.data_prefix['img_path'],
                                            ann['image_id'] + '.jpg'),
            )

            # The test split ships without answers; only attach ground
            # truth when the annotation provides it.
            if 'answers' in ann:
                # Collapse the raw human answers into unique answers
                # weighted by how often each one was given.
                count = Counter(ann['answers'])
                total = sum(count.values())
                data_info['gt_answer'] = list(count.keys())
                data_info['gt_answer_weight'] = [
                    num / total for num in count.values()
                ]

            data_list.append(data_info)

        return data_list