[Feature] Support Flickr30k Retrieval dataset (#1625)
* format * remove abs path * init add flickr30k caption * remove abs dir * update blip readme * add convert sscripts * minor * minorpull/1653/head
parent
a1cfe888e2
commit
6d7fe91a98
|
@ -0,0 +1,92 @@
|
|||
# data settings
|
||||
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='CleanCaption', keys='gt_caption'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['gt_caption'],
|
||||
meta_keys=['image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(384, 384),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='PackInputs', meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type='Flickr30kCaption',
|
||||
data_root='data/flickr30k',
|
||||
ann_file='annotations/dataset_flickr30k.json',
|
||||
data_prefix='images',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type='Flickr30kCaption',
|
||||
data_root='data/flickr30k',
|
||||
ann_file='annotations/dataset_flickr30k.json',
|
||||
data_prefix='images',
|
||||
split='val',
|
||||
pipeline=test_pipeline,
|
||||
),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
# refer tools/dataset_converters/convert_flickr30k_ann.py
|
||||
val_evaluator = dict(
|
||||
type='COCOCaption',
|
||||
ann_file='data/flickr30k_val_gt.json',
|
||||
)
|
||||
|
||||
# # If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type='Flickr30kCaption',
|
||||
data_root='data/flickr30k',
|
||||
ann_file='annotations/dataset_flickr30k.json',
|
||||
data_prefix='images',
|
||||
split='test',
|
||||
pipeline=test_pipeline,
|
||||
),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
# refer tools/dataset_converters/convert_flickr30k_ann.py
|
||||
test_evaluator = dict(
|
||||
type='COCOCaption',
|
||||
ann_file='data/flickr30k_test_gt.json',
|
||||
)
|
|
@ -0,0 +1,112 @@
|
|||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
rand_increasing_policies = [
|
||||
dict(type='AutoContrast'),
|
||||
dict(type='Equalize'),
|
||||
dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
|
||||
dict(
|
||||
type='Brightness', magnitude_key='magnitude',
|
||||
magnitude_range=(0, 0.0)),
|
||||
dict(type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0)),
|
||||
dict(
|
||||
type='Shear',
|
||||
magnitude_key='magnitude',
|
||||
magnitude_range=(0, 0.3),
|
||||
direction='horizontal'),
|
||||
dict(
|
||||
type='Shear',
|
||||
magnitude_key='magnitude',
|
||||
magnitude_range=(0, 0.3),
|
||||
direction='vertical'),
|
||||
]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
crop_ratio_range=(0.5, 1.0),
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies=rand_increasing_policies,
|
||||
num_policies=2,
|
||||
magnitude_level=5),
|
||||
dict(type='CleanCaption', keys='text'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text', 'is_matched'],
|
||||
meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(384, 384),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='CleanCaption', keys='text'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text', 'gt_text_id', 'gt_image_id'],
|
||||
meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=16,
|
||||
dataset=dict(
|
||||
type='Flickr30kRetrieval',
|
||||
data_root='data/flickr30k',
|
||||
ann_file='annotations/dataset_flickr30k.json',
|
||||
data_prefix='images',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=16,
|
||||
dataset=dict(
|
||||
type='Flickr30kRetrieval',
|
||||
data_root='data/flickr30k',
|
||||
ann_file='annotations/dataset_flickr30k.json',
|
||||
data_prefix='images',
|
||||
split='val',
|
||||
pipeline=test_pipeline,
|
||||
test_mode=True, # This is required for evaluation
|
||||
),
|
||||
sampler=dict(type='SequentialSampler', subsample_type='sequential'),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_evaluator = dict(type='RetrievalRecall', topk=(1, 5, 10))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=16,
|
||||
dataset=dict(
|
||||
type='Flickr30kRetrieval',
|
||||
data_root='data/flickr30k',
|
||||
ann_file='annotations/dataset_flickr30k.json',
|
||||
data_prefix='images',
|
||||
split='test',
|
||||
pipeline=test_pipeline,
|
||||
test_mode=True, # This is required for evaluation
|
||||
),
|
||||
sampler=dict(type='SequentialSampler', subsample_type='sequential'),
|
||||
persistent_workers=True,
|
||||
)
|
||||
test_evaluator = val_evaluator
|
|
@ -52,6 +52,12 @@ python tools/test.py configs/blip/blip-base_8xb32_caption.py https://download.op
|
|||
| :----------------------------- | :--------: | :---: | :----: | :-----------------------------------: | :--------------------------------------------------------------------------------------------------------------: |
|
||||
| `blip-base_3rdparty_caption`\* | 223.97 | 14.69 | 109.12 | [config](./blip-base_8xb32_nocaps.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-caption_20230419-a5b71af3.pth) |
|
||||
|
||||
### Image Caption on Flickr30k
|
||||
|
||||
| Model | Params (M) | SPICE | CIDER | Config | Download |
|
||||
| :----------------------------- | :--------: | :---: | :---: | :----------------------------------------------: | :----------------------------------------------------------------------------------------------------: |
|
||||
| `blip-base_3rdparty_caption`\* | 223.97 | 15.58 | 68.89 | [config](./blip-base_8xb32_caption_flickr30k.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-caption_20230419-a5b71af3.pth) |
|
||||
|
||||
### Visual Grounding on RefCOCO
|
||||
|
||||
| Model | Params (M) | Accuracy (testA) | Accuracy (testB) | Config | Download |
|
||||
|
@ -88,6 +94,18 @@ python tools/test.py configs/blip/blip-base_8xb32_caption.py https://download.op
|
|||
| :------------------------------- | :--------: | :------: | :------: | :--------------------------------------: | :----------------------------------------------------------------------------------------------------: |
|
||||
| `blip-base_3rdparty_retrieval`\* | 447.49 | 64.82 | 86.28 | [config](./blip-base_8xb32_retrieval.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |
|
||||
|
||||
### Image-To-Text Retrieval on Flickr30k
|
||||
|
||||
| Model | Params (M) | Recall@1 | Recall@5 | Config | Download |
|
||||
| :------------------------------- | :--------: | :------: | :------: | :------------------------------------------------: | :------------------------------------------------------------------------------------------: |
|
||||
| `blip-base_3rdparty_retrieval`\* | 447.49 | 95.10# | 99.60# | [config](./blip-base_8xb32_retrieval_flickr30k.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |
|
||||
|
||||
### Text-To-Image Retrieval on Flickr30k
|
||||
|
||||
| Model | Params (M) | Recall@1 | Recall@5 | Config | Download |
|
||||
| :------------------------------- | :--------: | :------: | :------: | :------------------------------------------------: | :------------------------------------------------------------------------------------------: |
|
||||
| `blip-base_3rdparty_retrieval`\* | 447.49 | 85.26# | 96.58# | [config](./blip-base_8xb32_retrieval_flickr30k.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |
|
||||
|
||||
### NLVR on NLVR2
|
||||
|
||||
| Model | Params (M) | Top-1 (%) | Config | Download |
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/flickr30k_caption.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='BlipCaption',
|
||||
vision_encoder=dict(
|
||||
type='VisionTransformer',
|
||||
arch='b',
|
||||
img_size=384,
|
||||
patch_size=16,
|
||||
out_type='raw',
|
||||
),
|
||||
tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
|
||||
decoder_head=dict(
|
||||
type='SeqGenerationHead',
|
||||
decoder=dict(
|
||||
type='XBertLMHeadDecoder',
|
||||
med_config=dict(
|
||||
architectures=['BertModel'],
|
||||
attention_probs_dropout_prob=0.1,
|
||||
hidden_act='gelu',
|
||||
hidden_dropout_prob=0.1,
|
||||
hidden_size=768,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=3072,
|
||||
layer_norm_eps=1e-12,
|
||||
max_position_embeddings=512,
|
||||
model_type='bert',
|
||||
num_attention_heads=12,
|
||||
num_hidden_layers=12,
|
||||
pad_token_id=0,
|
||||
add_type_embeddings=False,
|
||||
vocab_size=30524,
|
||||
encoder_width=768,
|
||||
add_cross_attention=True),
|
||||
),
|
||||
),
|
||||
prompt='a picture of ',
|
||||
max_txt_len=20,
|
||||
)
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05))
|
||||
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=10,
|
||||
)
|
||||
]
|
||||
|
||||
train_cfg = dict(max_epochs=10)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
|
@ -0,0 +1,83 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/flickr30k_retrieval.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='BlipRetrieval',
|
||||
tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='b',
|
||||
img_size=384,
|
||||
patch_size=16,
|
||||
out_type='raw',
|
||||
),
|
||||
text_backbone=dict(
|
||||
type='XBertEncoder',
|
||||
med_config=dict(
|
||||
architectures=['BertModel'],
|
||||
attention_probs_dropout_prob=0.1,
|
||||
hidden_act='gelu',
|
||||
hidden_dropout_prob=0.1,
|
||||
hidden_size=768,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=3072,
|
||||
layer_norm_eps=1e-12,
|
||||
max_position_embeddings=512,
|
||||
model_type='bert',
|
||||
num_attention_heads=12,
|
||||
num_hidden_layers=12,
|
||||
pad_token_id=0,
|
||||
add_type_embeddings=False,
|
||||
vocab_size=30524,
|
||||
encoder_width=768,
|
||||
add_cross_attention=True),
|
||||
),
|
||||
vision_neck=dict(
|
||||
type='Linear',
|
||||
in_features=768,
|
||||
out_features=256,
|
||||
),
|
||||
text_neck=dict(
|
||||
type='Linear',
|
||||
in_features=768,
|
||||
out_features=256,
|
||||
),
|
||||
head=dict(
|
||||
type='ITCHead',
|
||||
embed_dim=256,
|
||||
),
|
||||
multimodal_head=dict(
|
||||
type='ITMHead',
|
||||
hidden_size=768,
|
||||
with_pooler=False,
|
||||
),
|
||||
topk=256,
|
||||
max_txt_len=35,
|
||||
)
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=0.04)
|
||||
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [dict(type='CosineAnnealingLR', by_epoch=True)]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=6)
|
||||
val_cfg = dict(type='RetrievalValLoop')
|
||||
test_cfg = dict(type='RetrievalTestLoop')
|
||||
|
||||
randomness = dict(seed=42)
|
||||
|
||||
default_hooks = dict(logger=dict(interval=1))
|
||||
|
||||
custom_hooks = [
|
||||
dict(
|
||||
type='WarmupParamHook',
|
||||
param_name='alpha',
|
||||
module_name='head',
|
||||
warmup_epochs=2)
|
||||
]
|
|
@ -38,6 +38,8 @@ if WITH_MULTIMODAL:
|
|||
from .coco_retrieval import COCORetrieval
|
||||
from .coco_vqa import COCOVQA
|
||||
from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
|
||||
from .flickr30k_caption import Flickr30kCaption
|
||||
from .flickr30k_retrieval import Flickr30kRetrieval
|
||||
from .gqa_dataset import GQA
|
||||
from .nocaps import NoCaps
|
||||
from .ocr_vqa import OCRVQA
|
||||
|
@ -50,6 +52,7 @@ if WITH_MULTIMODAL:
|
|||
|
||||
__all__.extend([
|
||||
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
|
||||
'FlamingoEvalCOCOVQA', 'OCRVQA', 'RefCOCO', 'VisualGenomeQA',
|
||||
'ScienceQA', 'NoCaps', 'GQA', 'TextVQA', 'VSR', 'VizWiz'
|
||||
'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
|
||||
'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
|
||||
'VSR', 'VizWiz', 'OCRVQA'
|
||||
])
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import mmengine
|
||||
from mmengine.dataset import BaseDataset
|
||||
from mmengine.fileio import get_file_backend
|
||||
|
||||
from mmpretrain.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class Flickr30kCaption(BaseDataset):
|
||||
"""Flickr30k Caption dataset. To generate coco-style GT annotation for
|
||||
evaluation, please refer to
|
||||
tools/dataset_converters/convert_flickr30k_ann.py.
|
||||
|
||||
Args:
|
||||
data_root (str): The root directory for ``data_prefix``, ``ann_file``
|
||||
and ``question_file``.
|
||||
data_prefix (str): The directory of images.
|
||||
ann_file (str): Annotation file path for training and validation.
|
||||
split (str): 'train', 'val' or 'test'.
|
||||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
"""
|
||||
|
||||
def __init__(self, data_root: str, data_prefix: str, ann_file: str,
|
||||
split: str, **kwarg):
|
||||
|
||||
assert split in ['train', 'val', 'test'], \
|
||||
'`split` must be train, val or test'
|
||||
self.split = split
|
||||
super().__init__(
|
||||
data_root=data_root,
|
||||
data_prefix=dict(img_path=data_prefix),
|
||||
ann_file=ann_file,
|
||||
**kwarg,
|
||||
)
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load data list."""
|
||||
img_prefix = self.data_prefix['img_path']
|
||||
annotations = mmengine.load(self.ann_file)
|
||||
file_backend = get_file_backend(img_prefix)
|
||||
|
||||
data_list = []
|
||||
|
||||
for img in annotations['images']:
|
||||
|
||||
# img_example={
|
||||
# "sentids": [0, 1, 2],
|
||||
# "imgid": 0,
|
||||
# "sentences": [
|
||||
# {"raw": "Two men in green shirts standing in a yard.",
|
||||
# "imgid": 0, "sentid": 0},
|
||||
# {"raw": "A man in a blue shirt standing in a garden.",
|
||||
# "imgid": 0, "sentid": 1},
|
||||
# {"raw": "Two friends enjoy time spent together.",
|
||||
# "imgid": 0, "sentid": 2}
|
||||
# ],
|
||||
# "split": "train",
|
||||
# "filename": "1000092795.jpg"
|
||||
# },
|
||||
|
||||
if img['split'] != self.split:
|
||||
continue
|
||||
|
||||
for sentence in img['sentences']:
|
||||
data_info = {
|
||||
'image_id': img['imgid'],
|
||||
'img_path': file_backend.join_path(img_prefix,
|
||||
img['filename']),
|
||||
'gt_caption': sentence['raw']
|
||||
}
|
||||
|
||||
data_list.append(data_info)
|
||||
|
||||
return data_list
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
import mmengine
|
||||
from mmengine import get_file_backend
|
||||
|
||||
from mmpretrain.registry import DATASETS
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class Flickr30kRetrieval(BaseDataset):
|
||||
"""Flickr30k Retrieval dataset.
|
||||
|
||||
Args:
|
||||
data_root (str): The root directory for ``data_prefix``, ``ann_file``
|
||||
and ``question_file``.
|
||||
data_prefix (str): The directory of images.
|
||||
ann_file (str): Annotation file path for training and validation.
|
||||
split (str): 'train', 'val' or 'test'.
|
||||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
"""
|
||||
|
||||
def __init__(self, data_root: str, data_prefix: str, ann_file: str,
|
||||
split: str, **kwarg):
|
||||
|
||||
assert split in ['train', 'val', 'test'], \
|
||||
'`split` must be train, val or test'
|
||||
self.split = split
|
||||
super().__init__(
|
||||
data_root=data_root,
|
||||
data_prefix=dict(img_path=data_prefix),
|
||||
ann_file=ann_file,
|
||||
**kwarg,
|
||||
)
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load data list."""
|
||||
# get file backend
|
||||
img_prefix = self.data_prefix['img_path']
|
||||
file_backend = get_file_backend(img_prefix)
|
||||
|
||||
annotations = mmengine.load(self.ann_file)
|
||||
|
||||
# mapping img_id to img filename
|
||||
img_dict = OrderedDict()
|
||||
img_idx = 0
|
||||
sentence_idx = 0
|
||||
train_list = []
|
||||
for img in annotations['images']:
|
||||
|
||||
# img_example={
|
||||
# "sentids": [0, 1, 2],
|
||||
# "imgid": 0,
|
||||
# "sentences": [
|
||||
# {"raw": "Two men in green shirts standing in a yard.",
|
||||
# "imgid": 0, "sentid": 0},
|
||||
# {"raw": "A man in a blue shirt standing in a garden.",
|
||||
# "imgid": 0, "sentid": 1},
|
||||
# {"raw": "Two friends enjoy time spent together.",
|
||||
# "imgid": 0, "sentid": 2}
|
||||
# ],
|
||||
# "split": "train",
|
||||
# "filename": "1000092795.jpg"
|
||||
# },
|
||||
|
||||
if img['split'] != self.split:
|
||||
continue
|
||||
|
||||
# create new idx for image
|
||||
train_image = dict(
|
||||
ori_id=img['imgid'],
|
||||
image_id=img_idx, # used for evaluation
|
||||
img_path=file_backend.join_path(img_prefix, img['filename']),
|
||||
text=[],
|
||||
gt_text_id=[],
|
||||
gt_image_id=[],
|
||||
)
|
||||
|
||||
for sentence in img['sentences']:
|
||||
ann = {}
|
||||
ann['text'] = sentence['raw']
|
||||
ann['ori_id'] = sentence['sentid']
|
||||
ann['text_id'] = sentence_idx # used for evaluation
|
||||
|
||||
ann['image_ori_id'] = train_image['ori_id']
|
||||
ann['image_id'] = train_image['image_id']
|
||||
ann['img_path'] = train_image['img_path']
|
||||
ann['is_matched'] = True
|
||||
|
||||
# 1. prepare train data list item
|
||||
train_list.append(ann)
|
||||
# 2. prepare eval data list item based on img dict
|
||||
train_image['text'].append(ann['text'])
|
||||
train_image['gt_text_id'].append(ann['text_id'])
|
||||
train_image['gt_image_id'].append(ann['image_id'])
|
||||
|
||||
sentence_idx += 1
|
||||
|
||||
img_dict[img['imgid']] = train_image
|
||||
img_idx += 1
|
||||
|
||||
self.img_size = len(img_dict)
|
||||
self.text_size = len(train_list)
|
||||
|
||||
# return needed format data list
|
||||
if self.test_mode:
|
||||
return list(img_dict.values())
|
||||
return train_list
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""Create COCO-Style GT annotations based on raw annotation of Flickr30k.
|
||||
|
||||
GT annotations are used for evaluation in image caption task.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def main():
|
||||
with open('dataset_flickr30k.json', 'r') as f:
|
||||
annotations = json.load(f)
|
||||
ann_list = []
|
||||
img_list = []
|
||||
splits = ['train', 'val', 'test']
|
||||
for split in splits:
|
||||
for img in annotations['images']:
|
||||
|
||||
# img_example={
|
||||
# "sentids": [0, 1, 2],
|
||||
# "imgid": 0,
|
||||
# "sentences": [
|
||||
# {"raw": "Two men in green shirts standing in a yard.",
|
||||
# "imgid": 0, "sentid": 0},
|
||||
# {"raw": "A man in a blue shirt standing in a garden.",
|
||||
# "imgid": 0, "sentid": 1},
|
||||
# {"raw": "Two friends enjoy time spent together.",
|
||||
# "imgid": 0, "sentid": 2}
|
||||
# ],
|
||||
# "split": "train",
|
||||
# "filename": "1000092795.jpg"
|
||||
# },
|
||||
|
||||
if img['split'] != split:
|
||||
continue
|
||||
|
||||
img_list.append({'id': img['imgid']})
|
||||
|
||||
for sentence in img['sentences']:
|
||||
ann_info = {
|
||||
'image_id': img['imgid'],
|
||||
'id': sentence['sentid'],
|
||||
'caption': sentence['raw']
|
||||
}
|
||||
ann_list.append(ann_info)
|
||||
|
||||
json_file = {'annotations': ann_list, 'images': img_list}
|
||||
|
||||
# generate flickr30k_train_gt.json, flickr30k_val_gt.json
|
||||
# and flickr30k_test_gt.json
|
||||
with open(f'flickr30k_{split}_gt.json', 'w') as f:
|
||||
json.dump(json_file, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue