106 lines
3.5 KiB
Python
106 lines
3.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from collections import Counter
|
|
from typing import List
|
|
|
|
import mmengine
|
|
from mmengine.dataset import BaseDataset
|
|
|
|
from mmpretrain.registry import DATASETS
|
|
|
|
|
|
@DATASETS.register_module()
class TextVQA(BaseDataset):
    """TextVQA dataset.

    val image:
        https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
    test image:
        https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
    val json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
    test json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json

    folder structure:
        data/textvqa
           ├── annotations
           │   ├── TextVQA_0.5.1_test.json
           │   └── TextVQA_0.5.1_val.json
           └── images
               ├── test_images
               └── train_images

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str): Annotation file path. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        # The image directory is handed to the base class under the
        # ``img_path`` key; ``load_data_list`` reads it back from
        # ``self.data_prefix['img_path']``.
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Reads ``self.ann_file`` and turns every record under its top-level
        ``'data'`` key into one sample dict.

        Returns:
            List[dict]: One dict per annotation with the keys ``question``,
            ``question_id``, ``image_id`` and ``img_path``. When the
            annotation carries ``answers``, the dict also holds
            ``gt_answer`` (the distinct answers, in first-seen order) and
            ``gt_answer_weight`` (for each distinct answer, its share of
            all answers given for that question).
        """
        annotations = mmengine.load(self.ann_file)['data']
        img_prefix = self.data_prefix['img_path']

        data_list = []
        for ann in annotations:
            # ann example
            # {
            #   'question': 'what is the brand of...is camera?',
            #   'image_id': '003a8ae2ef43b901',
            #   'image_classes': [
            #       'Cassette deck', 'Printer', ...
            #   ],
            #   'flickr_original_url': 'https://farm2.static...04a6_o.jpg',
            #   'flickr_300k_url': 'https://farm2.static...04a6_o.jpg',
            #   'image_width': 1024,
            #   'image_height': 664,
            #   'answers': [
            #       'nous les gosses',
            #       'dakota',
            #       'clos culombu',
            #       'dakota digital' ...
            #   ],
            #   'question_tokens':
            #       ['what', 'is', 'the', 'brand', 'of', 'this', 'camera'],
            #   'question_id': 34602,
            #   'set_name': 'val'
            # }
            data_info = dict(
                question=ann['question'],
                question_id=ann['question_id'],
                image_id=ann['image_id'],
                # Image files are named after the image id.
                img_path=mmengine.join_path(img_prefix,
                                            ann['image_id'] + '.jpg'),
            )

            # Answers are optional: a record without the ``answers`` key
            # simply yields a sample without ground truth.
            if 'answers' in ann:
                answers = ann['answers']
                count = Counter(answers)
                total = len(answers)
                # Weight of an answer = fraction of annotators who gave it.
                data_info['gt_answer'] = list(count)
                data_info['gt_answer_weight'] = [
                    num / total for num in count.values()
                ]

            data_list.append(data_info)

        return data_list
|