[Feature] Support NoCap dataset based on BLIP. (#1582)
* [Feature] Support nocaps dataset * precommit * Use official coco format * add nocp readme * fix readme --------- Co-authored-by: mzr1996 <mzr1996@163.com>pull/1637/head
parent
46a523ef63
commit
a779c8c5a7
|
@ -0,0 +1,41 @@
|
|||
# data settings

# Per-channel normalization statistics and channel-order flag handed to the
# multimodal data preprocessor.
# NOTE(review): values look like 0-255-scale statistics -- confirm.
data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

# Evaluation pipeline: load the image, resize it to 384x384 with bicubic
# interpolation through the pillow backend, then pack the inputs while
# keeping ``image_id`` in the meta information.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(384, 384),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='PackInputs', meta_keys=['image_id']),
]

# Validation loader over the NoCaps validation annotations; sampling order
# is fixed (``shuffle=False``).
val_dataloader = dict(
    batch_size=16,
    num_workers=5,
    dataset=dict(
        type='NoCaps',
        data_root='data/nocaps/',
        data_prefix=dict(img_path='images/'),
        ann_file='annotations/nocaps_val_4500_captions.json',
        pipeline=test_pipeline,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

# The evaluator is configured with an output directory only; no annotation
# file is passed to it here.
val_evaluator = dict(
    type='NocapsSave',
    save_dir='./',
)

# If you want standard test, please manually configure the test dataset
# The test loop reuses the validation loader and evaluator objects as-is.
test_dataloader = val_dataloader
test_evaluator = val_evaluator
|
|
@ -46,6 +46,12 @@ python tools/test.py configs/blip/blip-base_8xb32_caption.py https://download.op
|
|||
| :----------------------------- | :--------: | :----: | :----: | :------------------------------------: | :------------------------------------------------------------------------------------------------------------: |
|
||||
| `blip-base_3rdparty_caption`\* | 223.97 | 40.12 | 132.82 | [config](./blip-base_8xb32_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-caption_20230419-a5b71af3.pth) |
|
||||
|
||||
### Image Caption on NoCaps
|
||||
|
||||
| Model | Params (M) | SPICE | CIDER | Config | Download |
|
||||
| :----------------------------- | :--------: | :---: | :----: | :----------------------------------: | :---------------------------------------------------------------------------------------------------------------: |
|
||||
| `blip-base_3rdparty_caption`\* | 223.97 | 14.69 | 109.12 | [config](./blip-base_8xb32_nocaps.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-caption_20230419-a5b71af3.pth) |
|
||||
|
||||
### Visual Grounding on RefCOCO
|
||||
|
||||
| Model | Params (M) | Accuracy (testA) | Accuracy (testB) | Config | Download |
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# Inherit the NoCaps dataset settings and the default runtime settings.
_base_ = [
    '../_base_/datasets/nocaps.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='BlipCaption',
    # Base-size ViT on 384x384 inputs with 16x16 patches; ``out_type='raw'``
    # is forwarded to the backbone unchanged.
    vision_encoder=dict(
        type='VisionTransformer',
        arch='b',
        img_size=384,
        patch_size=16,
        out_type='raw',
    ),
    tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
    decoder_head=dict(
        type='SeqGenerationHead',
        decoder=dict(
            type='XBertLMHeadDecoder',
            # BERT-style hyper-parameters for the text decoder.
            # ``encoder_width`` equals ``hidden_size`` (768) and
            # cross-attention is switched on.
            med_config=dict(
                architectures=['BertModel'],
                attention_probs_dropout_prob=0.1,
                hidden_act='gelu',
                hidden_dropout_prob=0.1,
                hidden_size=768,
                initializer_range=0.02,
                intermediate_size=3072,
                layer_norm_eps=1e-12,
                max_position_embeddings=512,
                model_type='bert',
                num_attention_heads=12,
                num_hidden_layers=12,
                pad_token_id=0,
                add_type_embeddings=False,
                vocab_size=30524,
                encoder_width=768,
                add_cross_attention=True),
        ),
    ),
    # The trailing space in the prompt is intentional in the original
    # config and must be kept byte-for-byte.
    prompt='a picture of ',
    max_txt_len=20,
)

# Empty loop configs: validation and test loops use their defaults.
val_cfg = dict()
test_cfg = dict()
|
|
@ -39,6 +39,7 @@ if WITH_MULTIMODAL:
|
|||
from .coco_vqa import COCOVQA
|
||||
from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
|
||||
from .gqa_dataset import GQA
|
||||
from .nocaps import NoCaps
|
||||
from .refcoco import RefCOCO
|
||||
from .scienceqa import ScienceQA
|
||||
from .visual_genome import VisualGenomeQA
|
||||
|
@ -52,5 +53,6 @@ if WITH_MULTIMODAL:
|
|||
'RefCOCO',
|
||||
'VisualGenomeQA',
|
||||
'ScienceQA',
|
||||
'NoCaps',
|
||||
'GQA',
|
||||
])
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import mmengine
|
||||
from mmengine.dataset import BaseDataset
|
||||
from mmengine.fileio import get_file_backend
|
||||
from pycocotools.coco import COCO
|
||||
|
||||
from mmpretrain.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
class NoCaps(BaseDataset):
    """NoCaps dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        ann_file (str): Annotation file path.
        data_prefix (dict): Prefix for data field. Defaults to
            ``dict(img_path='')``.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def load_data_list(self) -> List[dict]:
        """Load data list.

        The annotation file is read with the COCO API and one sample is
        emitted per annotation record.

        Returns:
            List[dict]: Samples with the keys ``image_id``, ``img_path``
            (image prefix joined with the image's ``file_name``) and
            ``gt_caption``, which is always ``None`` here.
        """
        img_prefix = self.data_prefix['img_path']
        # Resolve a possibly remote annotation path to a local file for
        # the duration of the COCO index construction.
        with mmengine.get_local_path(self.ann_file) as ann_file:
            coco = COCO(ann_file)

        file_backend = get_file_backend(img_prefix)
        data_list = []
        # NOTE(review): the loop runs over annotations, not images, so an
        # image referenced by several annotations is listed several times.
        for ann in coco.anns.values():
            image_id = ann['image_id']
            image_path = file_backend.join_path(
                img_prefix, coco.imgs[image_id]['file_name'])
            data_info = {
                'image_id': image_id,
                'img_path': image_path,
                'gt_caption': None
            }

            data_list.append(data_info)

        return data_list
|
|
@ -3,6 +3,7 @@ from .caption import COCOCaption
|
|||
from .gqa import GQAAcc
|
||||
from .multi_label import AveragePrecision, MultiLabelMetric
|
||||
from .multi_task import MultiTasksMetric
|
||||
from .nocaps import NocapsSave
|
||||
from .retrieval import RetrievalRecall
|
||||
from .scienceqa import ScienceQAMetric
|
||||
from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
|
||||
|
@ -14,5 +15,5 @@ __all__ = [
|
|||
'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
|
||||
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
|
||||
'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
|
||||
'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc'
|
||||
'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional
|
||||
|
||||
import mmengine
|
||||
|
||||
from mmpretrain.registry import METRICS
|
||||
from mmpretrain.utils import require
|
||||
from .caption import COCOCaption, save_result
|
||||
|
||||
try:
|
||||
from pycocoevalcap.eval import COCOEvalCap
|
||||
from pycocotools.coco import COCO
|
||||
except ImportError:
|
||||
COCOEvalCap = None
|
||||
COCO = None
|
||||
|
||||
|
||||
@METRICS.register_module()
class NocapsSave(COCOCaption):
    """Nocaps evaluation wrapper.

    Save the generated captions and transform into coco format.
    The dumped file can be submitted to the official evaluation system.

    Args:
        save_dir (str): Directory the prediction file is written to. It is
            created if it does not exist. Defaults to './'.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    @require('pycocoevalcap')
    def __init__(self,
                 save_dir: str = './',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        # ``super(COCOCaption, self)`` starts the method lookup *after*
        # ``COCOCaption``, so ``COCOCaption.__init__`` is skipped on purpose.
        # NOTE(review): presumably because that initializer expects an
        # annotation file, which this saver does not need -- confirm.
        super(COCOCaption, self).__init__(
            collect_device=collect_device, prefix=prefix)
        self.save_dir = save_dir

    def compute_metrics(self, results: List):
        """Dump the processed results instead of computing metrics.

        Args:
            results (List): The processed results of each batch.

        Returns:
            dict: Always an empty dict; no metric values are produced.
        """
        mmengine.mkdir_or_exist(self.save_dir)
        # Write the results under the base name ``nocap_pred`` in
        # ``save_dir``, asking the helper to de-duplicate on ``image_id``.
        save_result(
            result=results,
            result_dir=self.save_dir,
            filename='nocap_pred',
            remove_duplicate='image_id',
        )

        return dict()
|
Loading…
Reference in New Issue