[Feature] Support Flickr30k Retrieval dataset (#1625)
* format * remove abs path * init add flickr30k caption * remove abs dir * update blip readme * add convert sscripts * minor * minorpull/1653/head
parent
a1cfe888e2
commit
6d7fe91a98
|
@ -0,0 +1,92 @@
|
||||||
|
# data settings
|
||||||
|
|
||||||
|
data_preprocessor = dict(
|
||||||
|
type='MultiModalDataPreprocessor',
|
||||||
|
mean=[122.770938, 116.7460125, 104.09373615],
|
||||||
|
std=[68.5005327, 66.6321579, 70.32316305],
|
||||||
|
to_rgb=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='RandomResizedCrop',
|
||||||
|
scale=384,
|
||||||
|
interpolation='bicubic',
|
||||||
|
backend='pillow'),
|
||||||
|
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||||
|
dict(type='CleanCaption', keys='gt_caption'),
|
||||||
|
dict(
|
||||||
|
type='PackInputs',
|
||||||
|
algorithm_keys=['gt_caption'],
|
||||||
|
meta_keys=['image_id'],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='Resize',
|
||||||
|
scale=(384, 384),
|
||||||
|
interpolation='bicubic',
|
||||||
|
backend='pillow'),
|
||||||
|
dict(type='PackInputs', meta_keys=['image_id']),
|
||||||
|
]
|
||||||
|
|
||||||
|
train_dataloader = dict(
|
||||||
|
batch_size=32,
|
||||||
|
num_workers=5,
|
||||||
|
dataset=dict(
|
||||||
|
type='Flickr30kCaption',
|
||||||
|
data_root='data/flickr30k',
|
||||||
|
ann_file='annotations/dataset_flickr30k.json',
|
||||||
|
data_prefix='images',
|
||||||
|
split='train',
|
||||||
|
pipeline=train_pipeline),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
persistent_workers=True,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_dataloader = dict(
|
||||||
|
batch_size=16,
|
||||||
|
num_workers=5,
|
||||||
|
dataset=dict(
|
||||||
|
type='Flickr30kCaption',
|
||||||
|
data_root='data/flickr30k',
|
||||||
|
ann_file='annotations/dataset_flickr30k.json',
|
||||||
|
data_prefix='images',
|
||||||
|
split='val',
|
||||||
|
pipeline=test_pipeline,
|
||||||
|
),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# refer tools/dataset_converters/convert_flickr30k_ann.py
|
||||||
|
val_evaluator = dict(
|
||||||
|
type='COCOCaption',
|
||||||
|
ann_file='data/flickr30k_val_gt.json',
|
||||||
|
)
|
||||||
|
|
||||||
|
# # If you want standard test, please manually configure the test dataset
|
||||||
|
test_dataloader = dict(
|
||||||
|
batch_size=16,
|
||||||
|
num_workers=5,
|
||||||
|
dataset=dict(
|
||||||
|
type='Flickr30kCaption',
|
||||||
|
data_root='data/flickr30k',
|
||||||
|
ann_file='annotations/dataset_flickr30k.json',
|
||||||
|
data_prefix='images',
|
||||||
|
split='test',
|
||||||
|
pipeline=test_pipeline,
|
||||||
|
),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# refer tools/dataset_converters/convert_flickr30k_ann.py
|
||||||
|
test_evaluator = dict(
|
||||||
|
type='COCOCaption',
|
||||||
|
ann_file='data/flickr30k_test_gt.json',
|
||||||
|
)
|
|
@ -0,0 +1,112 @@
|
||||||
|
# data settings
|
||||||
|
data_preprocessor = dict(
|
||||||
|
type='MultiModalDataPreprocessor',
|
||||||
|
mean=[122.770938, 116.7460125, 104.09373615],
|
||||||
|
std=[68.5005327, 66.6321579, 70.32316305],
|
||||||
|
to_rgb=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
rand_increasing_policies = [
|
||||||
|
dict(type='AutoContrast'),
|
||||||
|
dict(type='Equalize'),
|
||||||
|
dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
|
||||||
|
dict(
|
||||||
|
type='Brightness', magnitude_key='magnitude',
|
||||||
|
magnitude_range=(0, 0.0)),
|
||||||
|
dict(type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0)),
|
||||||
|
dict(
|
||||||
|
type='Shear',
|
||||||
|
magnitude_key='magnitude',
|
||||||
|
magnitude_range=(0, 0.3),
|
||||||
|
direction='horizontal'),
|
||||||
|
dict(
|
||||||
|
type='Shear',
|
||||||
|
magnitude_key='magnitude',
|
||||||
|
magnitude_range=(0, 0.3),
|
||||||
|
direction='vertical'),
|
||||||
|
]
|
||||||
|
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='RandomResizedCrop',
|
||||||
|
scale=384,
|
||||||
|
crop_ratio_range=(0.5, 1.0),
|
||||||
|
interpolation='bicubic'),
|
||||||
|
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||||
|
dict(
|
||||||
|
type='RandAugment',
|
||||||
|
policies=rand_increasing_policies,
|
||||||
|
num_policies=2,
|
||||||
|
magnitude_level=5),
|
||||||
|
dict(type='CleanCaption', keys='text'),
|
||||||
|
dict(
|
||||||
|
type='PackInputs',
|
||||||
|
algorithm_keys=['text', 'is_matched'],
|
||||||
|
meta_keys=['image_id']),
|
||||||
|
]
|
||||||
|
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='Resize',
|
||||||
|
scale=(384, 384),
|
||||||
|
interpolation='bicubic',
|
||||||
|
backend='pillow'),
|
||||||
|
dict(type='CleanCaption', keys='text'),
|
||||||
|
dict(
|
||||||
|
type='PackInputs',
|
||||||
|
algorithm_keys=['text', 'gt_text_id', 'gt_image_id'],
|
||||||
|
meta_keys=['image_id']),
|
||||||
|
]
|
||||||
|
|
||||||
|
train_dataloader = dict(
|
||||||
|
batch_size=32,
|
||||||
|
num_workers=16,
|
||||||
|
dataset=dict(
|
||||||
|
type='Flickr30kRetrieval',
|
||||||
|
data_root='data/flickr30k',
|
||||||
|
ann_file='annotations/dataset_flickr30k.json',
|
||||||
|
data_prefix='images',
|
||||||
|
split='train',
|
||||||
|
pipeline=train_pipeline),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
persistent_workers=True,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_dataloader = dict(
|
||||||
|
batch_size=64,
|
||||||
|
num_workers=16,
|
||||||
|
dataset=dict(
|
||||||
|
type='Flickr30kRetrieval',
|
||||||
|
data_root='data/flickr30k',
|
||||||
|
ann_file='annotations/dataset_flickr30k.json',
|
||||||
|
data_prefix='images',
|
||||||
|
split='val',
|
||||||
|
pipeline=test_pipeline,
|
||||||
|
test_mode=True, # This is required for evaluation
|
||||||
|
),
|
||||||
|
sampler=dict(type='SequentialSampler', subsample_type='sequential'),
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_evaluator = dict(type='RetrievalRecall', topk=(1, 5, 10))
|
||||||
|
|
||||||
|
# If you want standard test, please manually configure the test dataset
|
||||||
|
test_dataloader = dict(
|
||||||
|
batch_size=64,
|
||||||
|
num_workers=16,
|
||||||
|
dataset=dict(
|
||||||
|
type='Flickr30kRetrieval',
|
||||||
|
data_root='data/flickr30k',
|
||||||
|
ann_file='annotations/dataset_flickr30k.json',
|
||||||
|
data_prefix='images',
|
||||||
|
split='test',
|
||||||
|
pipeline=test_pipeline,
|
||||||
|
test_mode=True, # This is required for evaluation
|
||||||
|
),
|
||||||
|
sampler=dict(type='SequentialSampler', subsample_type='sequential'),
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
test_evaluator = val_evaluator
|
|
@ -52,6 +52,12 @@ python tools/test.py configs/blip/blip-base_8xb32_caption.py https://download.op
|
||||||
| :----------------------------- | :--------: | :---: | :----: | :-----------------------------------: | :--------------------------------------------------------------------------------------------------------------: |
|
| :----------------------------- | :--------: | :---: | :----: | :-----------------------------------: | :--------------------------------------------------------------------------------------------------------------: |
|
||||||
| `blip-base_3rdparty_caption`\* | 223.97 | 14.69 | 109.12 | [config](./blip-base_8xb32_nocaps.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-caption_20230419-a5b71af3.pth) |
|
| `blip-base_3rdparty_caption`\* | 223.97 | 14.69 | 109.12 | [config](./blip-base_8xb32_nocaps.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-caption_20230419-a5b71af3.pth) |
|
||||||
|
|
||||||
|
### Image Caption on Flickr30k
|
||||||
|
|
||||||
|
| Model | Params (M) | SPICE | CIDER | Config | Download |
|
||||||
|
| :----------------------------- | :--------: | :---: | :---: | :----------------------------------------------: | :----------------------------------------------------------------------------------------------------: |
|
||||||
|
| `blip-base_3rdparty_caption`\* | 223.97 | 15.58 | 68.89 | [config](./blip-base_8xb32_caption_flickr30k.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-caption_20230419-a5b71af3.pth) |
|
||||||
|
|
||||||
### Visual Grounding on RefCOCO
|
### Visual Grounding on RefCOCO
|
||||||
|
|
||||||
| Model | Params (M) | Accuracy (testA) | Accuracy (testB) | Config | Download |
|
| Model | Params (M) | Accuracy (testA) | Accuracy (testB) | Config | Download |
|
||||||
|
@ -88,6 +94,18 @@ python tools/test.py configs/blip/blip-base_8xb32_caption.py https://download.op
|
||||||
| :------------------------------- | :--------: | :------: | :------: | :--------------------------------------: | :----------------------------------------------------------------------------------------------------: |
|
| :------------------------------- | :--------: | :------: | :------: | :--------------------------------------: | :----------------------------------------------------------------------------------------------------: |
|
||||||
| `blip-base_3rdparty_retrieval`\* | 447.49 | 64.82 | 86.28 | [config](./blip-base_8xb32_retrieval.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |
|
| `blip-base_3rdparty_retrieval`\* | 447.49 | 64.82 | 86.28 | [config](./blip-base_8xb32_retrieval.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |
|
||||||
|
|
||||||
|
### Image-To-Text Retrieval on Flickr30k
|
||||||
|
|
||||||
|
| Model | Params (M) | Recall@1 | Recall@5 | Config | Download |
|
||||||
|
| :------------------------------- | :--------: | :------: | :------: | :------------------------------------------------: | :------------------------------------------------------------------------------------------: |
|
||||||
|
| `blip-base_3rdparty_retrieval`\* | 447.49 | 95.10# | 99.60# | [config](./blip-base_8xb32_retrieval_flickr30k.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |
|
||||||
|
|
||||||
|
### Text-To-Image Retrieval on Flickr30k
|
||||||
|
|
||||||
|
| Model | Params (M) | Recall@1 | Recall@5 | Config | Download |
|
||||||
|
| :------------------------------- | :--------: | :------: | :------: | :------------------------------------------------: | :------------------------------------------------------------------------------------------: |
|
||||||
|
| `blip-base_3rdparty_retrieval`\* | 447.49 | 85.26# | 96.58# | [config](./blip-base_8xb32_retrieval_flickr30k.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |
|
||||||
|
|
||||||
### NLVR on NLVR2
|
### NLVR on NLVR2
|
||||||
|
|
||||||
| Model | Params (M) | Top-1 (%) | Config | Download |
|
| Model | Params (M) | Top-1 (%) | Config | Download |
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
_base_ = [
|
||||||
|
'../_base_/datasets/flickr30k_caption.py',
|
||||||
|
'../_base_/default_runtime.py',
|
||||||
|
]
|
||||||
|
|
||||||
|
# model settings
|
||||||
|
model = dict(
|
||||||
|
type='BlipCaption',
|
||||||
|
vision_encoder=dict(
|
||||||
|
type='VisionTransformer',
|
||||||
|
arch='b',
|
||||||
|
img_size=384,
|
||||||
|
patch_size=16,
|
||||||
|
out_type='raw',
|
||||||
|
),
|
||||||
|
tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
|
||||||
|
decoder_head=dict(
|
||||||
|
type='SeqGenerationHead',
|
||||||
|
decoder=dict(
|
||||||
|
type='XBertLMHeadDecoder',
|
||||||
|
med_config=dict(
|
||||||
|
architectures=['BertModel'],
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
hidden_act='gelu',
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
hidden_size=768,
|
||||||
|
initializer_range=0.02,
|
||||||
|
intermediate_size=3072,
|
||||||
|
layer_norm_eps=1e-12,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
model_type='bert',
|
||||||
|
num_attention_heads=12,
|
||||||
|
num_hidden_layers=12,
|
||||||
|
pad_token_id=0,
|
||||||
|
add_type_embeddings=False,
|
||||||
|
vocab_size=30524,
|
||||||
|
encoder_width=768,
|
||||||
|
add_cross_attention=True),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
prompt='a picture of ',
|
||||||
|
max_txt_len=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
# schedule settings
|
||||||
|
optim_wrapper = dict(optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05))
|
||||||
|
|
||||||
|
param_scheduler = [
|
||||||
|
dict(
|
||||||
|
type='CosineAnnealingLR',
|
||||||
|
by_epoch=True,
|
||||||
|
begin=0,
|
||||||
|
end=10,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
train_cfg = dict(max_epochs=10)
|
||||||
|
val_cfg = dict()
|
||||||
|
test_cfg = dict()
|
|
@ -0,0 +1,83 @@
|
||||||
|
_base_ = [
|
||||||
|
'../_base_/datasets/flickr30k_retrieval.py',
|
||||||
|
'../_base_/default_runtime.py',
|
||||||
|
]
|
||||||
|
|
||||||
|
# model settings
|
||||||
|
model = dict(
|
||||||
|
type='BlipRetrieval',
|
||||||
|
tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
|
||||||
|
vision_backbone=dict(
|
||||||
|
type='VisionTransformer',
|
||||||
|
arch='b',
|
||||||
|
img_size=384,
|
||||||
|
patch_size=16,
|
||||||
|
out_type='raw',
|
||||||
|
),
|
||||||
|
text_backbone=dict(
|
||||||
|
type='XBertEncoder',
|
||||||
|
med_config=dict(
|
||||||
|
architectures=['BertModel'],
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
hidden_act='gelu',
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
hidden_size=768,
|
||||||
|
initializer_range=0.02,
|
||||||
|
intermediate_size=3072,
|
||||||
|
layer_norm_eps=1e-12,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
model_type='bert',
|
||||||
|
num_attention_heads=12,
|
||||||
|
num_hidden_layers=12,
|
||||||
|
pad_token_id=0,
|
||||||
|
add_type_embeddings=False,
|
||||||
|
vocab_size=30524,
|
||||||
|
encoder_width=768,
|
||||||
|
add_cross_attention=True),
|
||||||
|
),
|
||||||
|
vision_neck=dict(
|
||||||
|
type='Linear',
|
||||||
|
in_features=768,
|
||||||
|
out_features=256,
|
||||||
|
),
|
||||||
|
text_neck=dict(
|
||||||
|
type='Linear',
|
||||||
|
in_features=768,
|
||||||
|
out_features=256,
|
||||||
|
),
|
||||||
|
head=dict(
|
||||||
|
type='ITCHead',
|
||||||
|
embed_dim=256,
|
||||||
|
),
|
||||||
|
multimodal_head=dict(
|
||||||
|
type='ITMHead',
|
||||||
|
hidden_size=768,
|
||||||
|
with_pooler=False,
|
||||||
|
),
|
||||||
|
topk=256,
|
||||||
|
max_txt_len=35,
|
||||||
|
)
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=0.04)
|
||||||
|
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
|
||||||
|
|
||||||
|
# learning rate scheduler
|
||||||
|
param_scheduler = [dict(type='CosineAnnealingLR', by_epoch=True)]
|
||||||
|
|
||||||
|
# runtime settings
|
||||||
|
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=6)
|
||||||
|
val_cfg = dict(type='RetrievalValLoop')
|
||||||
|
test_cfg = dict(type='RetrievalTestLoop')
|
||||||
|
|
||||||
|
randomness = dict(seed=42)
|
||||||
|
|
||||||
|
default_hooks = dict(logger=dict(interval=1))
|
||||||
|
|
||||||
|
custom_hooks = [
|
||||||
|
dict(
|
||||||
|
type='WarmupParamHook',
|
||||||
|
param_name='alpha',
|
||||||
|
module_name='head',
|
||||||
|
warmup_epochs=2)
|
||||||
|
]
|
|
@ -38,6 +38,8 @@ if WITH_MULTIMODAL:
|
||||||
from .coco_retrieval import COCORetrieval
|
from .coco_retrieval import COCORetrieval
|
||||||
from .coco_vqa import COCOVQA
|
from .coco_vqa import COCOVQA
|
||||||
from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
|
from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
|
||||||
|
from .flickr30k_caption import Flickr30kCaption
|
||||||
|
from .flickr30k_retrieval import Flickr30kRetrieval
|
||||||
from .gqa_dataset import GQA
|
from .gqa_dataset import GQA
|
||||||
from .nocaps import NoCaps
|
from .nocaps import NoCaps
|
||||||
from .ocr_vqa import OCRVQA
|
from .ocr_vqa import OCRVQA
|
||||||
|
@ -50,6 +52,7 @@ if WITH_MULTIMODAL:
|
||||||
|
|
||||||
__all__.extend([
|
__all__.extend([
|
||||||
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
|
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
|
||||||
'FlamingoEvalCOCOVQA', 'OCRVQA', 'RefCOCO', 'VisualGenomeQA',
|
'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
|
||||||
'ScienceQA', 'NoCaps', 'GQA', 'TextVQA', 'VSR', 'VizWiz'
|
'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
|
||||||
|
'VSR', 'VizWiz', 'OCRVQA'
|
||||||
])
|
])
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import mmengine
|
||||||
|
from mmengine.dataset import BaseDataset
|
||||||
|
from mmengine.fileio import get_file_backend
|
||||||
|
|
||||||
|
from mmpretrain.registry import DATASETS
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module()
|
||||||
|
class Flickr30kCaption(BaseDataset):
|
||||||
|
"""Flickr30k Caption dataset. To generate coco-style GT annotation for
|
||||||
|
evaluation, please refer to
|
||||||
|
tools/dataset_converters/convert_flickr30k_ann.py.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root (str): The root directory for ``data_prefix``, ``ann_file``
|
||||||
|
and ``question_file``.
|
||||||
|
data_prefix (str): The directory of images.
|
||||||
|
ann_file (str): Annotation file path for training and validation.
|
||||||
|
split (str): 'train', 'val' or 'test'.
|
||||||
|
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_root: str, data_prefix: str, ann_file: str,
|
||||||
|
split: str, **kwarg):
|
||||||
|
|
||||||
|
assert split in ['train', 'val', 'test'], \
|
||||||
|
'`split` must be train, val or test'
|
||||||
|
self.split = split
|
||||||
|
super().__init__(
|
||||||
|
data_root=data_root,
|
||||||
|
data_prefix=dict(img_path=data_prefix),
|
||||||
|
ann_file=ann_file,
|
||||||
|
**kwarg,
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_data_list(self) -> List[dict]:
|
||||||
|
"""Load data list."""
|
||||||
|
img_prefix = self.data_prefix['img_path']
|
||||||
|
annotations = mmengine.load(self.ann_file)
|
||||||
|
file_backend = get_file_backend(img_prefix)
|
||||||
|
|
||||||
|
data_list = []
|
||||||
|
|
||||||
|
for img in annotations['images']:
|
||||||
|
|
||||||
|
# img_example={
|
||||||
|
# "sentids": [0, 1, 2],
|
||||||
|
# "imgid": 0,
|
||||||
|
# "sentences": [
|
||||||
|
# {"raw": "Two men in green shirts standing in a yard.",
|
||||||
|
# "imgid": 0, "sentid": 0},
|
||||||
|
# {"raw": "A man in a blue shirt standing in a garden.",
|
||||||
|
# "imgid": 0, "sentid": 1},
|
||||||
|
# {"raw": "Two friends enjoy time spent together.",
|
||||||
|
# "imgid": 0, "sentid": 2}
|
||||||
|
# ],
|
||||||
|
# "split": "train",
|
||||||
|
# "filename": "1000092795.jpg"
|
||||||
|
# },
|
||||||
|
|
||||||
|
if img['split'] != self.split:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for sentence in img['sentences']:
|
||||||
|
data_info = {
|
||||||
|
'image_id': img['imgid'],
|
||||||
|
'img_path': file_backend.join_path(img_prefix,
|
||||||
|
img['filename']),
|
||||||
|
'gt_caption': sentence['raw']
|
||||||
|
}
|
||||||
|
|
||||||
|
data_list.append(data_info)
|
||||||
|
|
||||||
|
return data_list
|
|
@ -0,0 +1,110 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import mmengine
|
||||||
|
from mmengine import get_file_backend
|
||||||
|
|
||||||
|
from mmpretrain.registry import DATASETS
|
||||||
|
from .base_dataset import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module()
|
||||||
|
class Flickr30kRetrieval(BaseDataset):
|
||||||
|
"""Flickr30k Retrieval dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root (str): The root directory for ``data_prefix``, ``ann_file``
|
||||||
|
and ``question_file``.
|
||||||
|
data_prefix (str): The directory of images.
|
||||||
|
ann_file (str): Annotation file path for training and validation.
|
||||||
|
split (str): 'train', 'val' or 'test'.
|
||||||
|
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_root: str, data_prefix: str, ann_file: str,
|
||||||
|
split: str, **kwarg):
|
||||||
|
|
||||||
|
assert split in ['train', 'val', 'test'], \
|
||||||
|
'`split` must be train, val or test'
|
||||||
|
self.split = split
|
||||||
|
super().__init__(
|
||||||
|
data_root=data_root,
|
||||||
|
data_prefix=dict(img_path=data_prefix),
|
||||||
|
ann_file=ann_file,
|
||||||
|
**kwarg,
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_data_list(self) -> List[dict]:
|
||||||
|
"""Load data list."""
|
||||||
|
# get file backend
|
||||||
|
img_prefix = self.data_prefix['img_path']
|
||||||
|
file_backend = get_file_backend(img_prefix)
|
||||||
|
|
||||||
|
annotations = mmengine.load(self.ann_file)
|
||||||
|
|
||||||
|
# mapping img_id to img filename
|
||||||
|
img_dict = OrderedDict()
|
||||||
|
img_idx = 0
|
||||||
|
sentence_idx = 0
|
||||||
|
train_list = []
|
||||||
|
for img in annotations['images']:
|
||||||
|
|
||||||
|
# img_example={
|
||||||
|
# "sentids": [0, 1, 2],
|
||||||
|
# "imgid": 0,
|
||||||
|
# "sentences": [
|
||||||
|
# {"raw": "Two men in green shirts standing in a yard.",
|
||||||
|
# "imgid": 0, "sentid": 0},
|
||||||
|
# {"raw": "A man in a blue shirt standing in a garden.",
|
||||||
|
# "imgid": 0, "sentid": 1},
|
||||||
|
# {"raw": "Two friends enjoy time spent together.",
|
||||||
|
# "imgid": 0, "sentid": 2}
|
||||||
|
# ],
|
||||||
|
# "split": "train",
|
||||||
|
# "filename": "1000092795.jpg"
|
||||||
|
# },
|
||||||
|
|
||||||
|
if img['split'] != self.split:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# create new idx for image
|
||||||
|
train_image = dict(
|
||||||
|
ori_id=img['imgid'],
|
||||||
|
image_id=img_idx, # used for evaluation
|
||||||
|
img_path=file_backend.join_path(img_prefix, img['filename']),
|
||||||
|
text=[],
|
||||||
|
gt_text_id=[],
|
||||||
|
gt_image_id=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
for sentence in img['sentences']:
|
||||||
|
ann = {}
|
||||||
|
ann['text'] = sentence['raw']
|
||||||
|
ann['ori_id'] = sentence['sentid']
|
||||||
|
ann['text_id'] = sentence_idx # used for evaluation
|
||||||
|
|
||||||
|
ann['image_ori_id'] = train_image['ori_id']
|
||||||
|
ann['image_id'] = train_image['image_id']
|
||||||
|
ann['img_path'] = train_image['img_path']
|
||||||
|
ann['is_matched'] = True
|
||||||
|
|
||||||
|
# 1. prepare train data list item
|
||||||
|
train_list.append(ann)
|
||||||
|
# 2. prepare eval data list item based on img dict
|
||||||
|
train_image['text'].append(ann['text'])
|
||||||
|
train_image['gt_text_id'].append(ann['text_id'])
|
||||||
|
train_image['gt_image_id'].append(ann['image_id'])
|
||||||
|
|
||||||
|
sentence_idx += 1
|
||||||
|
|
||||||
|
img_dict[img['imgid']] = train_image
|
||||||
|
img_idx += 1
|
||||||
|
|
||||||
|
self.img_size = len(img_dict)
|
||||||
|
self.text_size = len(train_list)
|
||||||
|
|
||||||
|
# return needed format data list
|
||||||
|
if self.test_mode:
|
||||||
|
return list(img_dict.values())
|
||||||
|
return train_list
|
|
@ -0,0 +1,56 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
"""Create COCO-Style GT annotations based on raw annotation of Flickr30k.
|
||||||
|
|
||||||
|
GT annotations are used for evaluation in image caption task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
with open('dataset_flickr30k.json', 'r') as f:
|
||||||
|
annotations = json.load(f)
|
||||||
|
ann_list = []
|
||||||
|
img_list = []
|
||||||
|
splits = ['train', 'val', 'test']
|
||||||
|
for split in splits:
|
||||||
|
for img in annotations['images']:
|
||||||
|
|
||||||
|
# img_example={
|
||||||
|
# "sentids": [0, 1, 2],
|
||||||
|
# "imgid": 0,
|
||||||
|
# "sentences": [
|
||||||
|
# {"raw": "Two men in green shirts standing in a yard.",
|
||||||
|
# "imgid": 0, "sentid": 0},
|
||||||
|
# {"raw": "A man in a blue shirt standing in a garden.",
|
||||||
|
# "imgid": 0, "sentid": 1},
|
||||||
|
# {"raw": "Two friends enjoy time spent together.",
|
||||||
|
# "imgid": 0, "sentid": 2}
|
||||||
|
# ],
|
||||||
|
# "split": "train",
|
||||||
|
# "filename": "1000092795.jpg"
|
||||||
|
# },
|
||||||
|
|
||||||
|
if img['split'] != split:
|
||||||
|
continue
|
||||||
|
|
||||||
|
img_list.append({'id': img['imgid']})
|
||||||
|
|
||||||
|
for sentence in img['sentences']:
|
||||||
|
ann_info = {
|
||||||
|
'image_id': img['imgid'],
|
||||||
|
'id': sentence['sentid'],
|
||||||
|
'caption': sentence['raw']
|
||||||
|
}
|
||||||
|
ann_list.append(ann_info)
|
||||||
|
|
||||||
|
json_file = {'annotations': ann_list, 'images': img_list}
|
||||||
|
|
||||||
|
# generate flickr30k_train_gt.json, flickr30k_val_gt.json
|
||||||
|
# and flickr30k_test_gt.json
|
||||||
|
with open(f'flickr30k_{split}_gt.json', 'w') as f:
|
||||||
|
json.dump(json_file, f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue