Mirror of https://github.com/open-mmlab/mmclassification.git (synced 2025-06-03 21:53:55 +08:00)
# Copyright (c) OpenMMLab. All rights reserved.
import string
from collections import defaultdict
from functools import partial
from typing import Optional, Union

import mmengine
import torch
from mmengine.model import BaseModel

from mmpretrain.datasets import CleanCaption
from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample
from .ofa_modules import OFAEncoderDecoder


class TreeNode:
    """A trie node holding child nodes keyed by token id."""

    def __init__(self):
        self.child = defaultdict(TreeNode)


class Trie:
    """A token-id trie used to constrain decoding to a closed answer set."""

    def __init__(self, eos):
        self.root = TreeNode()
        self.eos = eos

    def insert(self, word):
        """Insert a token-id sequence into the trie."""
        cur = self.root
        for c in word:
            cur = cur.child[c]

    def get_next_layer(self, word):
        """Return the token ids allowed after the prefix ``word``.

        If the prefix is not in the trie, only the EOS token is allowed.
        """
        cur = self.root
        for c in word:
            cur = cur.child.get(c)
            if cur is None:
                return [self.eos]
        return list(cur.child.keys())
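

# Illustrative behaviour of the trie (the token ids below are made up, not
# taken from any tokenizer): after inserting the sequences [5, 8] and [5, 9],
# ``get_next_layer([5])`` returns [8, 9], while ``get_next_layer([7])``
# returns ``[eos]`` because no inserted answer starts with token 7.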


def apply_constraint(
    input_ids: torch.Tensor,
    logits: torch.Tensor,
    decoder_prompts: Optional[list],
    num_beams: int,
    constraint_trie: Trie = None,
):
    """Mask the next-token logits so only allowed tokens can be decoded.

    Tokens are forced to follow the decoder prompt first and then, if a
    ``constraint_trie`` is given, to stay within the inserted answer set.
    """
    if decoder_prompts is None and constraint_trie is None:
        return logits

    mask = logits.new_zeros(logits[:, -1, :].size(), dtype=torch.bool)
    input_ids = input_ids.view(-1, num_beams, input_ids.shape[-1])
    for batch_id, beam_sent in enumerate(input_ids):
        for beam_id, sent in enumerate(beam_sent):
            if decoder_prompts is None:
                prompt_len = 0
            else:
                prompt_len = len(decoder_prompts[batch_id])

            if sent.size(0) - 1 < prompt_len:
                allowed_tokens = [decoder_prompts[batch_id][sent.size(0) - 1]]
                mask[batch_id * num_beams + beam_id, allowed_tokens] = True
            elif constraint_trie is not None:
                answer_tokens = [0] + sent[prompt_len + 1:].tolist()
                allowed_tokens = constraint_trie.get_next_layer(answer_tokens)
                mask[batch_id * num_beams + beam_id, allowed_tokens] = True
            else:
                mask[batch_id * num_beams + beam_id, :] = True
    logits[:, -1, :].masked_fill_(~mask, float('-inf'))
    return logits
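

# Shape walk-through with assumed sizes (not taken from the code above): with
# a batch of 2 samples and ``num_beams=3``, ``input_ids`` arrives flattened as
# ``(6, cur_len)`` and is viewed as ``(2, 3, cur_len)``; the mask row for
# batch ``b`` and beam ``k`` is therefore ``b * num_beams + k``, matching the
# flattened beam layout of ``logits``.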


@MODELS.register_module()
class OFA(BaseModel):
    """The OFA model for multiple tasks.

    Args:
        encoder_cfg (dict): The config of the encoder, accepting the keyword
            arguments of :class:`OFAEncoder`.
        decoder_cfg (dict): The config of the decoder, accepting the keyword
            arguments of :class:`OFADecoder`.
        vocab_size (int): The size of the vocabulary.
        embedding_dim (int): The embedding dimensions of both the encoder
            and the decoder.
        tokenizer (dict | PreTrainedTokenizer): The tokenizer to encode
            the text.
        task (str): The task name, supported tasks are "caption", "vqa" and
            "refcoco".
        prompt (str, optional): The prompt template for the following tasks.
            If None, use the default prompt:

            - **caption**: ' what does the image describe?'
            - **refcoco**: ' which region does the text " {} " describe?'

            Defaults to None.
        ans2label (str | Sequence | None): The answer-to-label mapping for
            the vqa task. If a string, it should be a pickle or json file.
            The sequence constrains the output answers. Defaults to None,
            which means no constraint.
        generation_cfg (dict): The extra generation config, accepting the
            keyword arguments of :class:`~transformers.GenerationConfig`.
            Defaults to an empty dict.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or no specified type, it will use
            "MultiModalDataPreprocessor" as type. See
            :class:`MultiModalDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (dict, optional): The initialization config. Defaults to
            None.
    """

    support_tasks = {'caption', 'vqa', 'refcoco'}

    def __init__(
        self,
        encoder_cfg,
        decoder_cfg,
        vocab_size,
        embedding_dim,
        tokenizer,
        task,
        prompt=None,
        ans2label: Union[dict, str, None] = None,
        generation_cfg=dict(),
        data_preprocessor: Optional[dict] = None,
        init_cfg=None,
    ):
        if data_preprocessor is None:
            data_preprocessor = {}
        if isinstance(data_preprocessor, dict):
            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
            data_preprocessor = MODELS.build(data_preprocessor)

        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        if isinstance(tokenizer, dict):
            self.tokenizer = TOKENIZER.build(tokenizer)
        else:
            self.tokenizer = tokenizer

        if task not in self.support_tasks:
            raise ValueError(f'Unsupported task {task}, please select '
                             f'the task from {self.support_tasks}.')
        self.prompt = prompt
        self.task = task

        if isinstance(ans2label, str):
            self.ans2label = mmengine.load(ans2label)
        else:
            self.ans2label = ans2label

        if self.task == 'vqa' and self.ans2label is not None:
            # Build a trie over the tokenized candidate answers so that
            # generation can be constrained to this closed answer set.
            self.constraint_trie = Trie(eos=self.tokenizer.eos_token_id)
            answers = [f' {answer}' for answer in self.ans2label]
            answer_tokens = self.tokenizer(answers, padding=False)
            for answer_token in answer_tokens['input_ids']:
                self.constraint_trie.insert(answer_token)
        else:
            self.constraint_trie = None

        # User-specified values in `generation_cfg` override these defaults.
        generation_cfg = {
            'num_beams': 5,
            'max_new_tokens': 20,
            'no_repeat_ngram_size': 3,
            **generation_cfg,
        }
        self.model = OFAEncoderDecoder(
            encoder_cfg=encoder_cfg,
            decoder_cfg=decoder_cfg,
            padding_idx=self.tokenizer.pad_token_id,
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            generation_cfg=generation_cfg,
        )

    def forward(
        self,
        images: torch.Tensor,
        data_samples: Optional[list] = None,
        mode: str = 'predict',
        **kwargs,
    ):
        """The unified entry for a forward process in both training and test.

        The method accepts the following modes:

        - "predict": Forward and return a list of data samples containing
          the prediction results.

        Args:
            images (torch.Tensor): The preprocessed image tensor of shape
                ``(N, C, H, W)``.
            data_samples (List[DataSample], optional): The annotation data
                of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'predict'.
        """
        if mode == 'predict':
            return self.predict(images, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    def predict(
        self,
        images,
        data_samples=None,
        post_process=True,
        **generation_config,
    ):
        """Generate predictions for the current task."""
        text_tokens = self.preprocess_text(data_samples, images.size(0),
                                           images.device)

        if 'images_mask' in data_samples[0]:
            images_mask = torch.tensor([
                sample.get('images_mask') for sample in data_samples
            ]).bool().to(images.device)
        else:
            images_mask = None

        num_beams = generation_config.get(
            'num_beams', getattr(self.model.generation_config, 'num_beams'))
        decoder_prompts = self.get_decoder_prompts(data_samples)
        constrain_fn = partial(
            apply_constraint,
            constraint_trie=self.constraint_trie,
            decoder_prompts=decoder_prompts,
            num_beams=num_beams,
        )

        outputs = self.model.generate(
            input_ids=text_tokens,
            images=images,
            images_mask=images_mask,
            constrain_fn=constrain_fn,
            **generation_config,
        )

        if decoder_prompts is not None:
            # Remove the prefix decoder prompt.
            for prompt_ids, token in zip(decoder_prompts, outputs):
                token[1:len(prompt_ids) + 1] = self.tokenizer.pad_token_id

        if post_process:
            return self.post_process(outputs, data_samples)
        else:
            return outputs

    def get_decoder_prompts(self, data_samples):
        decoder_prompts = []
        if 'decoder_prompt' not in data_samples[0]:
            return None
        for sample in data_samples:
            prompt = ' ' + sample.get('decoder_prompt')
            prompt_ids = self.tokenizer(prompt, add_special_tokens=False)
            prompt_ids = prompt_ids['input_ids']
            decoder_prompts.append(prompt_ids)
        return decoder_prompts

    def preprocess_text(self, data_samples, batch_size, device):
        if self.task == 'caption':
            prompt = self.prompt or ' what does the image describe?'
            prompts = [prompt] * batch_size
            prompts = self.tokenizer(prompts, return_tensors='pt')
            return prompts.input_ids.to(device)
        elif self.task == 'vqa':
            prompts = []
            for sample in data_samples:
                assert 'question' in sample
                prompt = ' ' + sample.get('question')
                prompts.append(prompt)
            prompts = self.tokenizer(
                prompts, return_tensors='pt', padding=True)
            return prompts.input_ids.to(device)
        elif self.task == 'refcoco':
            prompt_template = self.prompt or \
                ' which region does the text " {} " describe?'
            prompts = []
            for sample in data_samples:
                assert 'text' in sample
                prompt = prompt_template.format(sample.get('text'))
                prompts.append(prompt)
            prompts = self.tokenizer(
                prompts, return_tensors='pt', padding=True)
            return prompts.input_ids.to(device)

    def post_process(self, outputs, data_samples):
        """Convert generated token sequences into ``DataSample`` results."""
        out_data_samples = []
        if data_samples is None:
            data_samples = [None] * outputs.size(0)

        for data_sample, token in zip(data_samples, outputs):
            if data_sample is None:
                data_sample = DataSample()

            if self.task == 'caption':
                text = self.tokenizer.decode(token, skip_special_tokens=True)
                text = CleanCaption(
                    lowercase=False,
                    remove_chars=string.punctuation).clean(text)
                data_sample.pred_caption = text
            elif self.task == 'vqa':
                text = self.tokenizer.decode(token, skip_special_tokens=True)
                data_sample.pred_answer = text.strip()
            elif self.task == 'refcoco':
                bbox = token[1:5] - self.tokenizer.bin_offset
                # During training, the bbox is normalized by 512. It's related
                # to the `max_image_size` config in the official repo.
                bbox = bbox / self.tokenizer.num_bins * 512
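                # Worked example with assumed values (not taken from any
                # config): if `num_bins` is 1000 and a predicted bin index is
                # 250, the coordinate becomes 250 / 1000 * 512 = 128 in the
                # resized image space; the division by `scale_factor` below
                # then maps it back to the original image size.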
                scale_factor = data_sample.get('scale_factor', (1, 1))
                bbox[0::2] /= scale_factor[0]
                bbox[1::2] /= scale_factor[1]
                data_sample.pred_bboxes = bbox.unsqueeze(0)
                if 'gt_bboxes' in data_sample:
                    gt_bboxes = bbox.new_tensor(data_sample.gt_bboxes)
                    gt_bboxes[:, 0::2] /= scale_factor[0]
                    gt_bboxes[:, 1::2] /= scale_factor[1]
                    data_sample.gt_bboxes = gt_bboxes
            out_data_samples.append(data_sample)

        return out_data_samples
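

# A minimal usage sketch (illustrative only; every value below is an
# assumption, not a shipped OFA config; see the OFA configs in mmpretrain
# for real settings):
#
#   from mmpretrain.registry import MODELS
#
#   model = MODELS.build(
#       dict(
#           type='OFA',
#           task='caption',
#           vocab_size=...,         # must match the pretrained checkpoint
#           embedding_dim=...,
#           encoder_cfg=dict(...),  # keyword arguments of OFAEncoder
#           decoder_cfg=dict(...),  # keyword arguments of OFADecoder
#           tokenizer=dict(...),    # any tokenizer registered in TOKENIZER
#       ))
#   data_samples = model.predict(images, data_samples)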