add model file

pull/1775/head
qingtian5 2023-08-28 12:16:24 +08:00
parent 16ac01aa88
commit 4260679d33
4 changed files with 401 additions and 142 deletions

View File

View File

@@ -0,0 +1,78 @@
_base_ = [
    '../_base_/datasets/coco_caption.py',
    '../_base_/default_runtime.py',
]

# dataset settings
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(224, 224),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='PackInputs', meta_keys=['image_id']),
]

val_dataloader = dict(batch_size=1, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

# model settings
model = dict(
    type='MplugOwlForConditionalGeneration',
    vision_encoder=dict(
        type='MplugOwlVisionModel',
        hidden_size=1024,
        image_size=224,
        patch_size=14,
        intermediate_size=4096,
        num_attention_heads=16,
        attention_dropout=0.0,
        layer_norm_eps=1e-6,
        num_hidden_layers=24,
        pretrained=  # noqa
        ''  # noqa
    ),
    abstractor_model=dict(
        type='MplugOwlVisualAbstractorModel',
        language_hidden_size=4096,
        num_hidden_layers=6,
        hidden_size=1024,
        num_attention_heads=16,
        intermediate_size=4096,
        attention_probs_dropout_prob=0.1,
        layer_norm_eps=1e-6,
        encoder_hidden_size=1024,
        pretrained=  # noqa
        ''  # noqa
    ),
    lang_encoder=dict(
        type='AutoModelForCausalLM', name_or_path='YOUR_PATH_TO_LLAMA'),
    tokenizer=dict(type='LlamaTokenizer', name_or_path='YOUR_PATH_TO_LLAMA'),
    task='caption',
    prompt_template="The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\nHuman: <image>\nHuman: how many cats are there?\nAI: ",
    # raw_prompts=[
    #     '<Img><ImageHere></Img> Describe this image in detail.',
    #     '<Img><ImageHere></Img> Take a look at this image and describe what you notice.',  # noqa
    #     '<Img><ImageHere></Img> Please provide a detailed description of the picture.',  # noqa
    #     '<Img><ImageHere></Img> Could you describe the contents of this image for me?',  # noqa
    # ],
    # max_txt_len=160,
    # end_sym='###'
)

# schedule settings
optim_wrapper = dict(optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05))

param_scheduler = [
    dict(
        type='CosineAnnealingLR',
        by_epoch=True,
        begin=0,
        end=5,
    )
]

train_cfg = dict(by_epoch=True, max_epochs=5)
val_cfg = dict()
test_cfg = dict()
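
The config above only wires registered modules together through the MMPretrain registry. A minimal usage sketch follows (not part of this commit); it assumes the `YOUR_PATH_TO_LLAMA` placeholder and the empty `pretrained=''` fields have been filled in, that the file lives under `configs/` so the `_base_` paths resolve, and the config path shown is hypothetical:

```python
from mmengine.config import Config

from mmpretrain.registry import MODELS

# hypothetical path; adjust to wherever this config file is actually saved
cfg = Config.fromfile('configs/mplug_owl/mplug-owl_coco-caption.py')

# builds 'MplugOwlForConditionalGeneration' with the nested vision_encoder,
# abstractor_model, lang_encoder and tokenizer dicts defined above
model = MODELS.build(cfg.model)
model.eval()
```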

View File

@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from .mplug_owl import mPLUGOwl
+from .mplug_owl import MplugOwl
-__all__ = ['mPLUGOwl']
+__all__ = ['MplugOwl']

View File

@@ -165,7 +165,6 @@ class MplugOwlVisionAttention(BaseModel):
return outputs
QuickGLUE = MODELS.build(dict(type="QuickGELU"))
class MplugOwlMLP(BaseModel):
@@ -318,8 +317,8 @@ class MplugOwlVisionEncoder(BaseModel):
)
@MODELS.register_module()
class MplugOwlVisionModel(BaseModel):
main_input_name = "pixel_values"
def __init__(self, hidden_size=1024, image_size=224, patch_size=14, intermediate_size=4096, num_attention_heads=16, attention_dropout=0.0, layer_norm_eps=1e-6, num_hidden_layers=24):
super().__init__()
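
For reference, a hedged sketch (not in this commit) of exercising the vision tower on its own, using the same keyword arguments as the config's `vision_encoder` dict. The expected token count follows from ViT-L/14 at 224x224: (224/14)^2 = 256 patch tokens plus one CLS token, each of width 1024. The call signature mirrors how `self.vision_model(pixel_values, return_dict=True)` is used later in this file.

```python
import torch

from mmpretrain.registry import MODELS

# assumes mmpretrain is installed and MplugOwlVisionModel is registered
vision_encoder = MODELS.build(
    dict(
        type='MplugOwlVisionModel',
        hidden_size=1024,
        image_size=224,
        patch_size=14,
        intermediate_size=4096,
        num_attention_heads=16,
        attention_dropout=0.0,
        layer_norm_eps=1e-6,
        num_hidden_layers=24))

with torch.no_grad():
    out = vision_encoder(torch.randn(1, 3, 224, 224), return_dict=True)
# expected: out.last_hidden_state.shape == (1, 257, 1024)
```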
@@ -660,10 +659,12 @@ class MplugOwlVisualAbstractorEncoder(BaseModel):
)
@MODELS.register_module()
class MplugOwlVisualAbstractorModel(BaseModel):
def __init__(self, language_hidden_size, num_hidden_layers=6, hidden_size=1024, num_attention_heads=16, intermediate_size=4096, attention_probs_dropout_prob=0.1, layer_norm_eps=1e-6, encoder_hidden_size=1024):
super().__init__()
self.language_hidden_size = language_hidden_size
self.encoder = MplugOwlVisualAbstractorEncoder(num_hidden_layers, hidden_size, num_attention_heads, intermediate_size, attention_probs_dropout_prob, layer_norm_eps, encoder_hidden_size)
self.visual_fc = torch.nn.Linear(hidden_size, language_hidden_size)
self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
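
As a shape sanity check (a hedged sketch, not part of the commit), the tail of the abstractor forward shown in the next hunk projects the query outputs into the language model's embedding space with `visual_fc` and appends the learned `vit_eos` token. With the config values above (`hidden_size=1024`, `language_hidden_size=4096`, 64 query tokens) that yields 65 tokens of width 4096 per image:

```python
import torch

batch_size, num_query_tokens = 2, 64
hidden_size, language_hidden_size = 1024, 4096

# stand-ins for the self.visual_fc / self.vit_eos created in __init__ above
visual_fc = torch.nn.Linear(hidden_size, language_hidden_size)
vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))

# per-query features coming out of the abstractor encoder
sequence_output = torch.randn(batch_size, num_query_tokens, hidden_size)
sequence_output = visual_fc(sequence_output)
sequence_output = torch.cat(
    [sequence_output, vit_eos.repeat(sequence_output.shape[0], 1, 1)], dim=1)

# sequence_output.shape == (2, 65, 4096): 64 query tokens plus one EOS token,
# already sized for the language model's hidden dimension
```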
@@ -816,35 +817,32 @@ class MplugOwlVisualAbstractorModel(BaseModel):
sequence_output = self.visual_fc(sequence_output)
sequence_output = torch.cat([sequence_output, self.vit_eos.repeat(sequence_output.shape[0], 1, 1)], dim=1)
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
return (
sequence_output,
pooled_output,
encoder_outputs.hidden_states,
)
@MODELS.register_module()
class MplugOwlModel(BaseModel):
main_input_name = "pixel_values"
class MplugOwl(BaseModel):
def __init__(self, vision_encoder, abstractor_model, lang_encoder, num_query_tokens=64):
super().__init__()
def __init__(self, config: MplugOwlConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.vision_model = MODELS.build(vision_encoder)
self.vision_model = MplugOwlVisionModel(config.vision_config)
self.abstractor = MODELS.build(abstractor_model)
self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens, config.visual_abstractor_config.hidden_size)
)
self.abstractor = MplugOwlVisualAbstractorModel(
config.visual_abstractor_config, config.text_config.hidden_size
torch.zeros(1, num_query_tokens, self.abstractor.language_hidden_size)
)
# if config.use_decoder_only_language_model:
# from llama.modeling_llama import LlamaForCausalLM
language_model = AutoModelForCausalLM.from_config(config.text_config)
# language_model = AutoModelForCausalLM.from_config(config.text_config)
# else:
# language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
self.language_model = language_model
self.language_model = MODELS.build(lang_encoder)
# Initialize weights and apply final processing
self.post_init()
@@ -939,144 +937,327 @@ class MplugOwlModel(BaseModel):
return vision_outputs
# Hack for bloomz
def bloom_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
class MplugOwlForConditionalGeneration(BaseModel):
def __init__(self, vision_encoder, abstractor_model, lang_encoder, lang_tokenizer, num_query_tokens=64):
super().__init__()
self.vision_model = MODELS.build(vision_encoder)
self.abstractor = MODELS.build(abstractor_model)
self.query_tokens = nn.Parameter(
torch.zeros(1, num_query_tokens, self.abstractor.language_hidden_size)
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.language_model = MODELS.build(lang_encoder)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
self.language_tokenizer = TOKENIZER.build(lang_tokenizer)
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Initialize weights and apply final processing
self.post_init()
self.main_input_name = "input_ids"
from transformers import GenerationConfig
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
self.generation_config = GenerationConfig(
max_length=512, do_sample=True, top_k=3, pad_token_id=0, unk_token_id=0, bos_token_id=1, eos_token_id=2
)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
inputs_embeds = self.word_embeddings_layernorm(inputs_embeds)
hidden_states = inputs_embeds
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
def get_output_embeddings(self) -> nn.Module:
return self.language_model.get_output_embeddings()
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
def get_encoder(self):
return self.language_model.get_encoder()
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
def get_decoder(self):
return self.language_model.get_decoder()
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
self.language_model.decoder.embed_tokens = self.language_model.shared
if self.gradient_checkpointing and self.training:
def _preprocess_accelerate(self):
r"""
Some pre-processing hacks to make the model `accelerate` compatible. Check
https://github.com/huggingface/transformers/pull/21707 for more details.
"""
hf_device_map = self.hf_device_map
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
alibi,
causal_mask,
layer_past,
head_mask[i],
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
# warn users about unexpected behavior when using multi-GPU + mPLUG-Owl + `accelerate`.
MMLogger.get_current_instance().warning(
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
" Please pass a `device_map` that contains `language_model` to remove this warning."
" Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
" more details on creating a `device_map` for large models.",
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
def forward(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.FloatTensor,
num_images,
non_padding_mask: Optional[torch.LongTensor] = None,
non_media_mask: Optional[torch.LongTensor] = None,
prompt_mask: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
):
r"""
Returns:
SFT example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import MplugOwlProcessor, MplugOwlForConditionalGeneration
>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> processor = MplugOwlProcessor.from_pretrained("MAGAer13/mplug-owl-llama-7b")
>>> model = MplugOwlForConditionalGeneration.from_pretrained(
... "MAGAer13/mplug-owl-llama-7b", torch_dtype=torch.float16
... )
>>> model.to(device) # doctest: +IGNORE_RESULT
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> prompt = [
... "The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\nHuman: <image>\nHuman: how many cats are there?\nAI: "
... ]
>>> inputs = processor(images=[image], text=prompt, return_tensors="pt").to(device, torch.float16)
>>> generated_ids = model.generate(**inputs)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
>>> print(generated_text)
There are two cats in the image.
```"""
if pixel_values is not None:
pixel_values = pixel_values.to(self.vision_model.embeddings.cls_token.data.dtype)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# get text embedding
text_tokens_ = input_ids.clone()
batch_size = input_ids.shape[0]
# labels = text_tokens_[:, 1:].clone().contiguous()
media_token_indices = [
# [:-1] since we would not use the last token for embedding
get_media_indices(text_tokens_[i][:-1])
for i in range(batch_size)
]
text_tokens_[text_tokens_ < 0] = 1 # Not used
# text_tokens = text_tokens_[:, :-1].contiguous()
text_embeds = self.get_input_embeddings()(text_tokens_) # Temporally Embedding
if hasattr(self.language_model, 'transformer') and hasattr(self.language_model.transformer, 'word_embeddings_layernorm'):
text_embeds = self.language_model.transformer.word_embeddings_layernorm(text_embeds)
if pixel_values is not None:
image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_features = self.abstractor(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
)["last_hidden_state"]
torch.ones(query_features.size()[:-1], dtype=torch.long).to(query_features.device)
img_seq_length = query_features.shape[1]
num_images_per_sample = num_images.long().cpu().tolist()
text_chunk_embeds = []
img_idx = 0
for b in range(batch_size):
start = 0
result = []
if len(media_token_indices[b]) > 0:
for i, pos in enumerate(media_token_indices[b]):
if pos > start:
result.append(text_embeds[b, start:pos])
result.append(query_features[img_idx + i])
start = pos + img_seq_length
if start < text_embeds.shape[1]:
result.append(text_embeds[b, start:])
img_idx += num_images_per_sample[b]
text_chunk_embeds.append(torch.cat(result, dim=0))
# Actual Input Embeddings
input_embeds = torch.stack(text_chunk_embeds, dim=0)
# Create causal mask and position ids
_, loss_mask, position_ids = get_ltor_masks_and_position_ids_from_embeddings(input_embeds)
# Calculate the loss_mask
non_padding_mask = non_padding_mask.long()
non_media_mask = non_media_mask.long()
prompt_mask = prompt_mask.long() # TODO How to deal with prompt mask
# from icecream import ic
# non_padding_mask = non_padding_mask[:,:-1]
# non_media_mask = non_media_mask[:,:-1]
# prompt_mask = prompt_mask[:,:-1]
# attention_mask = attention_mask[:,:-1]
loss_mask = loss_mask[:, :-1]
loss_mask = loss_mask * non_padding_mask * non_media_mask * prompt_mask
labels[:, 1:][loss_mask != 1] = -100
# Forward into GPT
outputs = self.language_model(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
labels=labels,
return_dict=return_dict,
output_attentions=self.config.output_attentions,
)
# outputs.loss = (outputs.loss * loss_mask.view(-1)
# ).sum()/loss_mask.sum()
return outputs
@torch.no_grad()
def generate(
self,
pixel_values: torch.FloatTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
isdecoder=True,
**generate_kwargs,
) -> torch.LongTensor:
"""
Overrides `generate` function to be able to use the model as a conditional generator.
Args:
pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
Input images to be processed.
input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
The sequence used as a prompt for the generation.
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
Mask to avoid performing attention on padding token indices
Returns:
captions (list): A list of strings of length batch_size * num_captions.
"""
if pixel_values is not None:
pixel_values = pixel_values.to(self.vision_model.embeddings.cls_token.data.dtype)
if input_ids is None:
return self.language_model.generate(attention_mask=attention_mask, **generate_kwargs)
if attention_mask is None:
attention_mask = input_ids.new_ones(*input_ids.shape)
batch_size = input_ids.size(0)
media_token_indices = [get_media_indices(input_ids[i]) for i in range(batch_size)]
num_images_per_sample = [len(x) for x in media_token_indices]
input_ids = input_ids.clone() # prevent inplace modify
input_ids[input_ids < 0] = 0 # Not used
if hasattr(self, "hf_device_map"):
# preprocess for `accelerate`
self._preprocess_accelerate()
batch_size = input_ids.shape[0]
# get text embedding
inputs_embeds = self.get_input_embeddings()(input_ids)
if hasattr(self.language_model, 'transformer') and hasattr(self.language_model.transformer, 'word_embeddings_layernorm'):
inputs_embeds = self.language_model.transformer.word_embeddings_layernorm(inputs_embeds)
# get visual embedding
if pixel_values is not None:
pixel_values = pixel_values.to(input_ids.device)
with torch.no_grad():
image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
image_attention_mask = torch.ones(
image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device
)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.abstractor(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs["last_hidden_state"]
image_embeds = query_output
img_seq_length = image_embeds.shape[1]
# ===================
# Get actual input embeddings
# ===================
text_chunk_embeds = []
text_chunk_attns = []
img_idx = 0
for b in range(batch_size):
start = 0
result = []
result_attn = []
for i, pos in enumerate(media_token_indices[b]):
if pos > start:
result.append(inputs_embeds[b, start:pos])
result_attn.append(attention_mask[b, start:pos])
result.append(image_embeds[img_idx + i])
result_attn.append(torch.ones(image_embeds[img_idx + i].shape[0], device=inputs_embeds.device))
start = pos + img_seq_length
if start < inputs_embeds.shape[1]:
result.append(inputs_embeds[b, start:])
result_attn.append(attention_mask[b, start:])
img_idx += num_images_per_sample[b]
text_chunk_embeds.append(torch.cat(result, dim=0))
text_chunk_attns.append(torch.cat(result_attn, dim=0))
inputs_embeds = torch.stack(text_chunk_embeds, dim=0)
attention_mask = torch.stack(text_chunk_attns, dim=0)
outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
# input_ids=input_ids,
attention_mask=attention_mask,
**generate_kwargs,
)
return outputs
def prepare_inputs_for_generation(
self, input_ids, pixel_values=None, past_key_values=None, attention_mask=None, **model_kwargs
):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# # cut decoder_input_ids if past_key_values is used
# if past_key_values is not None:
# input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"pixel_values": pixel_values,
"attention_mask": attention_mask,
"is_decoder": True,
}
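
Both `forward` and `generate` above depend on `get_media_indices`, which is defined elsewhere in this file and does not appear in the diff. From its usage here (`input_ids[input_ids < 0] = 0`, `text_tokens_[text_tokens_ < 0] = 1` and the splicing loop), `<image>` placeholders are encoded as negative token ids and the helper returns the start position of each placeholder run so the query features can be spliced in. A minimal reconstruction consistent with that usage (hypothetical, not the committed implementation):

```python
import torch


def get_media_indices(my_list):
    """Return the first index of each run of identical negative (media) ids."""
    if isinstance(my_list, torch.Tensor):
        my_list = my_list.cpu().tolist()
    result = []
    for i in range(len(my_list)):
        # a new media run starts when a negative id differs from its predecessor
        if my_list[i] < 0 and (i == 0 or my_list[i] != my_list[i - 1]):
            result.append(i)
    return result


# e.g. get_media_indices(torch.tensor([1, 5, -1, -1, 7, 9])) -> [2]
```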