add model file

parent 16ac01aa88
commit 4260679d33

@@ -0,0 +1,78 @@
+_base_ = [
+    '../_base_/datasets/coco_caption.py',
+    '../_base_/default_runtime.py',
+]
+
+# dataset settings
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        scale=(224, 224),
+        interpolation='bicubic',
+        backend='pillow'),
+    dict(type='PackInputs', meta_keys=['image_id']),
+]
+
+val_dataloader = dict(batch_size=1, dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+# model settings
+model = dict(
+    type='MplugOwlForConditionalGeneration',
+    vision_encoder=dict(
+        type='MplugOwlVisionModel',
+        hidden_size=1024,
+        image_size=224,
+        patch_size=14,
+        intermediate_size=4096,
+        num_attention_heads=16,
+        attention_dropout=0.0,
+        layer_norm_eps=1e-6,
+        num_hidden_layers=24,
+        pretrained=  # noqa
+        ''  # noqa
+    ),
+    abstractor_model=dict(
+        type='MplugOwlVisualAbstractorModel',
+        language_hidden_size=4096,
+        num_hidden_layers=6,
+        hidden_size=1024,
+        num_attention_heads=16,
+        intermediate_size=4096,
+        attention_probs_dropout_prob=0.1,
+        layer_norm_eps=1e-6,
+        encoder_hidden_size=1024,
+        pretrained=  # noqa
+        ''  # noqa
+    ),
+    lang_encoder=dict(
+        type='AutoModelForCausalLM', name_or_path='YOUR_PATH_TO_LLAMA'),
+    tokenizer=dict(type='LlamaTokenizer', name_or_path='YOUR_PATH_TO_LLAMA'),
+    task='caption',
+    prompt_template="The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\nHuman: <image>\nHuman: how many cats are there?\nAI: ",
+    # raw_prompts=[
+    #     '<Img><ImageHere></Img> Describe this image in detail.',
+    #     '<Img><ImageHere></Img> Take a look at this image and describe what you notice.',  # noqa
+    #     '<Img><ImageHere></Img> Please provide a detailed description of the picture.',  # noqa
+    #     '<Img><ImageHere></Img> Could you describe the contents of this image for me?',  # noqa
+    # ],
+    # max_txt_len=160,
+    # end_sym='###'
+)
+
+# schedule settings
+optim_wrapper = dict(optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05))
+
+param_scheduler = [
+    dict(
+        type='CosineAnnealingLR',
+        by_epoch=True,
+        begin=0,
+        end=5,
+    )
+]
+
+train_cfg = dict(by_epoch=True, max_epochs=5)
+val_cfg = dict()
+test_cfg = dict()
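Review note: a minimal sketch (not part of this commit) of how a config like the one above is typically consumed through MMEngine; the config path and work_dir are hypothetical.

```python
from mmengine.config import Config
from mmengine.runner import Runner

# Hypothetical path; place the file wherever the repo keeps mPLUG-Owl configs.
cfg = Config.fromfile('configs/mplug_owl/mplug-owl_caption.py')
cfg.work_dir = 'work_dirs/mplug_owl_caption'  # hypothetical output directory

runner = Runner.from_cfg(cfg)
runner.test()  # runs the caption evaluation defined by test_dataloader/test_cfg
```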
@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .mplug_owl import mPLUGOwl
+from .mplug_owl import MplugOwl
 
-__all__ = ['mPLUGOwl']
+__all__ = ['MplugOwl']
@@ -165,7 +165,6 @@ class MplugOwlVisionAttention(BaseModel):
         return outputs
 
 
-QuickGLUE = MODELS.bulid(dict(type="QuickGELU"))
 
 
 class MplugOwlMLP(BaseModel):
@@ -318,8 +317,8 @@ class MplugOwlVisionEncoder(BaseModel):
         )
 
 
+@MODELS.register_module()
 class MplugOwlVisionModel(BaseModel):
-    main_input_name = "pixel_values"
 
     def __init__(self, hidden_size=1024, image_size=224, patch_size=14, intermediate_size=4096, num_attention_heads=16, attention_dropout=0.0, layer_norm_eps=1e-6, num_hidden_layers=24):
         super().__init__()
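Review note: as a sanity check on the hyper-parameters above, the ViT sequence length follows directly from `image_size` and `patch_size`; a sketch, not part of the commit (the +1 assumes a [CLS] token, which the `cls_token` referenced later in this file suggests).

```python
image_size, patch_size = 224, 14
num_patches = (image_size // patch_size) ** 2  # 16 x 16 = 256 patches
seq_len = num_patches + 1                      # +1 assuming a [CLS] token -> 257
print(seq_len)
```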
@@ -660,10 +659,12 @@ class MplugOwlVisualAbstractorEncoder(BaseModel):
         )
 
 
+@MODELS.register_module()
 class MplugOwlVisualAbstractorModel(BaseModel):
     def __init__(self, language_hidden_size, num_hidden_layers=6, hidden_size=1024, num_attention_heads=16, intermediate_size=4096, attention_probs_dropout_prob=0.1, layer_norm_eps=1e-6, encoder_hidden_size=1024):
         super().__init__()
 
+        self.language_hidden_size = language_hidden_size
         self.encoder = MplugOwlVisualAbstractorEncoder(num_hidden_layers, hidden_size, num_attention_heads, intermediate_size, attention_probs_dropout_prob, layer_norm_eps, encoder_hidden_size)
         self.visual_fc = torch.nn.Linear(hidden_size, language_hidden_size)
         self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
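Review note: the new `language_hidden_size` attribute and the `visual_fc`/`vit_eos` pair define the projection into the LLM's embedding space; a standalone sketch of the shape flow (dummy tensors, sizes taken from the config above), not part of the commit.

```python
import torch

hidden_size, language_hidden_size, num_queries = 1024, 4096, 64
visual_fc = torch.nn.Linear(hidden_size, language_hidden_size)
vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))

seq = torch.randn(2, num_queries, hidden_size)                     # (batch, queries, 1024)
seq = visual_fc(seq)                                               # -> (2, 64, 4096)
seq = torch.cat([seq, vit_eos.repeat(seq.shape[0], 1, 1)], dim=1)  # append learned EOS
print(seq.shape)                                                   # torch.Size([2, 65, 4096])
```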
@@ -816,35 +817,32 @@ class MplugOwlVisualAbstractorModel(BaseModel):
         sequence_output = self.visual_fc(sequence_output)
         sequence_output = torch.cat([sequence_output, self.vit_eos.repeat(sequence_output.shape[0], 1, 1)], dim=1)
 
-        return BaseModelOutputWithPooling(
-            last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
-            hidden_states=encoder_outputs.hidden_states,
+        return (
+            sequence_output,
+            pooled_output,
+            encoder_outputs.hidden_states,
         )
 
 
+@MODELS.register_module()
-class MplugOwlModel(BaseModel):
-    main_input_name = "pixel_values"
+class MplugOwl(BaseModel):
 
-    def __init__(self, config: MplugOwlConfig, *inputs, **kwargs):
-        super().__init__(config, *inputs, **kwargs)
+    def __init__(self, vision_encoder, abstractor_model, lang_encoder, num_query_tokens=64):
+        super().__init__()
 
-        self.vision_model = MplugOwlVisionModel(config.vision_config)
+        self.vision_model = MODELS.build(vision_encoder)
+
+        self.abstractor = MODELS.build(abstractor_model)
 
         self.query_tokens = nn.Parameter(
-            torch.zeros(1, config.num_query_tokens, config.visual_abstractor_config.hidden_size)
+            torch.zeros(1, num_query_tokens, self.abstractor.language_hidden_size)
         )
-        self.abstractor = MplugOwlVisualAbstractorModel(
-            config.visual_abstractor_config, config.text_config.hidden_size
-        )
 
-        # if config.use_decoder_only_language_model:
-        # from llama.modeling_llama import LlamaForCausalLM
-        language_model = AutoModelForCausalLM.from_config(config.text_config)
-        # else:
-        #     language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
-        self.language_model = language_model
+        # language_model = AutoModelForCausalLM.from_config(config.text_config)
+        self.language_model = MODELS.build(lang_encoder)
 
         # Initialize weights and apply final processing
         self.post_init()
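Review note: the rewrite swaps HF-style config objects for MMEngine registry dicts; a minimal sketch of that pattern (registry import as used in mmpretrain, the concrete cfg dict is illustrative and mirrors the config file above), not part of the commit.

```python
from mmpretrain.registry import MODELS

# Illustrative cfg; works once MplugOwlVisionModel is registered by this file.
cfg = dict(type='MplugOwlVisionModel', hidden_size=1024, num_hidden_layers=24)
vision_model = MODELS.build(cfg)  # equivalent to MplugOwlVisionModel(hidden_size=1024, ...)
```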
@@ -939,144 +937,327 @@ class MplugOwlModel(BaseModel):
         return vision_outputs
 
 
-# Hack for bloomz
-def bloom_forward(
-    self,
-    input_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    head_mask: Optional[torch.LongTensor] = None,
-    inputs_embeds: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    **deprecated_arguments,
-) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
-    if deprecated_arguments.pop("position_ids", False) is not False:
-        # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
-        warnings.warn(
-            "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
-            " passing `position_ids`.",
-            FutureWarning,
-        )
+class MplugOwlForConditionalGeneration(BaseModel):
+
+    def __init__(self, vision_encoder, abstractor_model, lang_encoder, lang_tokenizer, num_query_tokens=64):
+        super().__init__()
+
+        self.vision_model = MODELS.build(vision_encoder)
+
+        self.abstractor = MODELS.build(abstractor_model)
+
+        self.query_tokens = nn.Parameter(
+            torch.zeros(1, num_query_tokens, self.abstractor.language_hidden_size)
+        )
-    if len(deprecated_arguments) > 0:
-        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
-
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-    )
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        self.language_model = MODELS.build(lang_encoder)
 
-    if input_ids is not None and inputs_embeds is not None:
-        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-    elif input_ids is not None:
-        batch_size, seq_length = input_ids.shape
-    elif inputs_embeds is not None:
-        batch_size, seq_length, _ = inputs_embeds.shape
-    else:
-        raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        self.language_tokenizer = TOKENIZER.build(lang_tokenizer)
 
-    if past_key_values is None:
-        past_key_values = tuple([None] * len(self.h))
+
+        # Initialize weights and apply final processing
+        self.post_init()
+        self.main_input_name = "input_ids"
+        from transformers import GenerationConfig
 
-    # Prepare head mask if needed
-    # 1.0 in head_mask indicate we keep the head
-    # attention_probs has shape batch_size x num_heads x N x N
-    # head_mask has shape n_layer x batch x num_heads x N x N
-    head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+        self.generation_config = GenerationConfig(
+            max_length=512, do_sample=True, top_k=3, pad_token_id=0, unk_token_id=0, bos_token_id=1, eos_token_id=2
+        )
 
-    if inputs_embeds is None:
-        inputs_embeds = self.word_embeddings(input_ids)
-    inputs_embeds = self.word_embeddings_layernorm(inputs_embeds)
-
-    hidden_states = inputs_embeds
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
 
-    presents = () if use_cache else None
-    all_self_attentions = () if output_attentions else None
-    all_hidden_states = () if output_hidden_states else None
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
 
-    if self.gradient_checkpointing and self.training:
-        if use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-            )
-            use_cache = False
+    def set_output_embeddings(self, new_embeddings):
+        self.language_model.set_output_embeddings(new_embeddings)
 
-    # Compute alibi tensor: check build_alibi_tensor documentation
-    seq_length_with_past = seq_length
-    past_key_values_length = 0
-    if past_key_values[0] is not None:
-        past_key_values_length = past_key_values[0][0].shape[2]
-        seq_length_with_past = seq_length_with_past + past_key_values_length
-    if attention_mask is None:
-        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
-    else:
-        attention_mask = attention_mask.to(hidden_states.device)
+    def get_output_embeddings(self) -> nn.Module:
+        return self.language_model.get_output_embeddings()
 
-    alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+    def get_encoder(self):
+        return self.language_model.get_encoder()
 
-    causal_mask = self._prepare_attn_mask(
-        attention_mask,
-        input_shape=(batch_size, seq_length),
-        past_key_values_length=past_key_values_length,
-    )
+    def get_decoder(self):
+        return self.language_model.get_decoder()
 
-    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
+    def _tie_weights(self):
+        if not self.config.use_decoder_only_language_model:
+            self.language_model.encoder.embed_tokens = self.language_model.shared
+            self.language_model.decoder.embed_tokens = self.language_model.shared
 
-        if self.gradient_checkpointing and self.training:
+    def _preprocess_accelerate(self):
+        r"""
+        Some pre-processing hacks to make the model `accelerate` compatible. Check
+        https://github.com/huggingface/transformers/pull/21707 for more details.
+        """
+        hf_device_map = self.hf_device_map
 
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    # None for past_key_value
-                    return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                return custom_forward
-
-            outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(block),
-                hidden_states,
-                alibi,
-                causal_mask,
-                layer_past,
-                head_mask[i],
-            )
-        else:
-            outputs = block(
-                hidden_states,
-                layer_past=layer_past,
-                attention_mask=causal_mask,
-                head_mask=head_mask[i],
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                alibi=alibi,
-            )
+        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
+            # warn users about unexpected behavior when using multi-GPU + mPLUG-Owl + `accelerate`.
+            MMLogger.warning(
+                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
+                " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
+                " Please pass a `device_map` that contains `language_model` to remove this warning."
+                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
+                " more details on creating a `device_map` for large models.",
+            )
 
-        hidden_states = outputs[0]
-        if use_cache is True:
-            presents = presents + (outputs[1],)
-
-        if output_attentions:
-            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
-
-    # Add last hidden state
-    hidden_states = self.ln_f(hidden_states)
-
-    if output_hidden_states:
-        all_hidden_states = all_hidden_states + (hidden_states,)
-
-    if not return_dict:
-        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
-
-    return BaseModelOutputWithPastAndCrossAttentions(
-        last_hidden_state=hidden_states,
-        past_key_values=presents,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attentions,
-    )
+        if hasattr(self.language_model, "_hf_hook"):
+            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        input_ids: torch.FloatTensor,
+        num_images,
+        non_padding_mask: Optional[torch.LongTensor] = None,
+        non_media_mask: Optional[torch.LongTensor] = None,
+        prompt_mask: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        r"""
+        Returns:
+
+        SFT example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import MplugOwlProcessor, MplugOwlForConditionalGeneration
+        >>> import torch
+
+        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        >>> processor = MplugOwlProcessor.from_pretrained("MAGAer13/mplug-owl-llama-7b")
+        >>> model = MplugOwlForConditionalGeneration.from_pretrained(
+        ...     "MAGAer13/mplug-owl-llama-7b", torch_dtype=torch.float16
+        ... )
+        >>> model.to(device)  # doctest: +IGNORE_RESULT
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> prompt = [
+        ...     "The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\nHuman: <image>\nHuman: how many cats are there?\nAI: "
+        ... ]
+        >>> inputs = processor(images=[image], text=prompt, return_tensors="pt").to(device, torch.float16)
+
+        >>> generated_ids = model.generate(**inputs)
+        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+        >>> print(generated_text)
+        There are two cats in the image.
+        ```"""
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(self.vision_model.embeddings.cls_token.data.dtype)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # get text embedding
+        text_tokens_ = input_ids.clone()
+        batch_size = input_ids.shape[0]
+        # labels = text_tokens_[:, 1:].clone().contiguous()
+
+        media_token_indices = [
+            # [:-1] since we would not use the last token for embedding
+            get_media_indices(text_tokens_[i][:-1])
+            for i in range(batch_size)
+        ]
+        text_tokens_[text_tokens_ < 0] = 1  # Not used
+        # text_tokens = text_tokens_[:, :-1].contiguous()
+        text_embeds = self.get_input_embeddings()(text_tokens_)  # Temporally Embedding
+        if hasattr(self.language_model, 'transformer') and hasattr(self.language_model.transformer, 'word_embeddings_layernorm'):
+            text_embeds = self.language_model.transformer.word_embeddings_layernorm(text_embeds)
+
+        if pixel_values is not None:
+            image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
+
+            image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+
+            query_features = self.abstractor(
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_attention_mask,
+            )["last_hidden_state"]
+            torch.ones(query_features.size()[:-1], dtype=torch.long).to(query_features.device)
+            img_seq_length = query_features.shape[1]
+
+        num_images_per_sample = num_images.long().cpu().tolist()
+
+        text_chunk_embeds = []
+        img_idx = 0
+        for b in range(batch_size):
+            start = 0
+            result = []
+            if len(media_token_indices[b]) > 0:
+                for i, pos in enumerate(media_token_indices[b]):
+                    if pos > start:
+                        result.append(text_embeds[b, start:pos])
+                    result.append(query_features[img_idx + i])
+                    start = pos + img_seq_length
+            if start < text_embeds.shape[1]:
+                result.append(text_embeds[b, start:])
+
+            img_idx += num_images_per_sample[b]
+            text_chunk_embeds.append(torch.cat(result, dim=0))
+
+        # Actual Input Embeddings
+        input_embeds = torch.stack(text_chunk_embeds, dim=0)
+
+        # Create causal mask and position ids
+        _, loss_mask, position_ids = get_ltor_masks_and_position_ids_from_embeddings(input_embeds)
+
+        # Calculate the loss_mask
+        non_padding_mask = non_padding_mask.long()
+        non_media_mask = non_media_mask.long()
+        prompt_mask = prompt_mask.long()  # TODO How to deal with prompt mask
+        # from icecream import ic
+        # non_padding_mask = non_padding_mask[:,:-1]
+        # non_media_mask = non_media_mask[:,:-1]
+        # prompt_mask = prompt_mask[:,:-1]
+        # attention_mask = attention_mask[:,:-1]
+        loss_mask = loss_mask[:, :-1]
+
+        loss_mask = loss_mask * non_padding_mask * non_media_mask * prompt_mask
+        labels[:, 1:][loss_mask != 1] = -100
+        # Forward into GPT
+        outputs = self.language_model(
+            inputs_embeds=input_embeds,
+            attention_mask=attention_mask,
+            labels=labels,
+            return_dict=return_dict,
+            output_attentions=self.config.output_attentions,
+        )
+        # outputs.loss = (outputs.loss * loss_mask.view(-1)
+        #                 ).sum()/loss_mask.sum()
+        return outputs
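Review note: the label masking at the end of `forward` is easiest to see on a toy batch; a sketch, not part of the commit (`-100` is the index `CrossEntropyLoss` ignores).

```python
import torch

labels = torch.tensor([[5, 6, 7, 8]])
loss_mask = torch.tensor([[1, 0, 1]])  # aligned with labels[:, 1:]
labels[:, 1:][loss_mask != 1] = -100   # masked positions drop out of the loss
print(labels)                          # tensor([[   5,    6, -100,    8]])
```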
 
+    @torch.no_grad()
+    def generate(
+        self,
+        pixel_values: torch.FloatTensor = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        isdecoder=True,
+        **generate_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Overrides `generate` function to be able to use the model as a conditional generator.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
+                Input images to be processed.
+            input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+                The sequence used as a prompt for the generation.
+            attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+                Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            captions (list): A list of strings of length batch_size * num_captions.
+        """
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(self.vision_model.embeddings.cls_token.data.dtype)
+
+        if input_ids is None:
+            return self.language_model.generate(attention_mask=attention_mask, **generate_kwargs)
+
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(*input_ids.shape)
+
+        batch_size = input_ids.size(0)
+        media_token_indices = [get_media_indices(input_ids[i]) for i in range(batch_size)]
+        num_images_per_sample = [len(x) for x in media_token_indices]
+        input_ids = input_ids.clone()  # prevent inplace modify
+        input_ids[input_ids < 0] = 0  # Not used
+
+        if hasattr(self, "hf_device_map"):
+            # preprocess for `accelerate`
+            self._preprocess_accelerate()
+        batch_size = input_ids.shape[0]
+        # get text embedding
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+        if hasattr(self.language_model, 'transformer') and hasattr(self.language_model.transformer, 'word_embeddings_layernorm'):
+            inputs_embeds = self.language_model.transformer.word_embeddings_layernorm(inputs_embeds)
+        # get visual embedding
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(input_ids.device)
+            with torch.no_grad():
+                image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
+                image_attention_mask = torch.ones(
+                    image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device
+                )
+                query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+                query_outputs = self.abstractor(
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_attention_mask,
+                    return_dict=True,
+                )
+                query_output = query_outputs["last_hidden_state"]
+                image_embeds = query_output
+            img_seq_length = image_embeds.shape[1]
+
+        # ===================
+        # Get actual input embeddings
+        # ===================
+        text_chunk_embeds = []
+        text_chunk_attns = []
+        img_idx = 0
+
+        for b in range(batch_size):
+            start = 0
+            result = []
+            result_attn = []
+            for i, pos in enumerate(media_token_indices[b]):
+                if pos > start:
+                    result.append(inputs_embeds[b, start:pos])
+                    result_attn.append(attention_mask[b, start:pos])
+                result.append(image_embeds[img_idx + i])
+                result_attn.append(torch.ones(image_embeds[img_idx + i].shape[0], device=inputs_embeds.device))
+                start = pos + img_seq_length
+            if start < inputs_embeds.shape[1]:
+                result.append(inputs_embeds[b, start:])
+                result_attn.append(attention_mask[b, start:])
+
+            img_idx += num_images_per_sample[b]
+            text_chunk_embeds.append(torch.cat(result, dim=0))
+            text_chunk_attns.append(torch.cat(result_attn, dim=0))
+        inputs_embeds = torch.stack(text_chunk_embeds, dim=0)
+        attention_mask = torch.stack(text_chunk_attns, dim=0)
+
+        outputs = self.language_model.generate(
+            inputs_embeds=inputs_embeds,
+            # input_ids=input_ids,
+            attention_mask=attention_mask,
+            **generate_kwargs,
+        )
+
+        return outputs
 
+    def prepare_inputs_for_generation(
+        self, input_ids, pixel_values=None, past_key_values=None, attention_mask=None, **model_kwargs
+    ):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        # # cut decoder_input_ids if past_key_values is used
+        # if past_key_values is not None:
+        #     input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "pixel_values": pixel_values,
+            "attention_mask": attention_mask,
+            "is_decoder": True,
+        }
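Review note: both `forward` and `generate` rely on the same splicing loop to interleave text embeddings with visual features at `<image>` placeholder positions; a self-contained sketch with toy sizes, not part of the commit (`get_media_indices` is assumed to return the placeholder positions, which the prompt pre-expands to `img_seq_length` tokens).

```python
import torch

img_seq_length = 4                 # toy value; in the model it is num_query_tokens + 1 EOS
text_embeds = torch.randn(10, 8)   # 10 prompt positions, embedding dim 8
image_feats = torch.randn(img_seq_length, 8)
media_pos = [3]                    # an <image> placeholder span starting at position 3

result, start = [], 0
for pos in media_pos:
    if pos > start:
        result.append(text_embeds[start:pos])  # text before the image
    result.append(image_feats)                 # splice in the visual features
    start = pos + img_seq_length               # skip over the placeholder span
if start < text_embeds.shape[0]:
    result.append(text_embeds[start:])         # trailing text
inputs_embeds = torch.cat(result, dim=0)
print(inputs_embeds.shape)                     # torch.Size([10, 8])
```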