[Enhancement] Support training of BLIP2 (#1700)
* [Fix] Fix BEiT pre_norm
* [Enhancement] Support BLIP2 training
* [Fix] Fix quoted strings
* [Fix] Fix init_weights
* [Fix] Fix with_cls_token
* [Fix] Fix tokenizer
* [Fix] Fix quoted strings
* [Fix] Fix predict
* [Fix] Cancel changing BEiT
* [Fix] Add loading hook
* [Fix] Reformat with yapf
* [Fix] Fix prompt
* [Fix] Fix typo
parent fa53174fd9
commit 29d706248c
```diff
@@ -57,7 +57,7 @@ param_scheduler = [
     )
 ]
 
-train_cfg = dict(max_epochs=10)
+train_cfg = dict(by_epoch=True, max_epochs=10)
 val_cfg = dict()
 test_cfg = dict()
```
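The `by_epoch=True` flag makes the epoch-based loop explicit. For context, a minimal config fragment with explanatory comments; only the `train_cfg` line comes from this diff:

```python
# With by_epoch=True the MMEngine runner builds its epoch-based train loop,
# so max_epochs (rather than max_iters) bounds training.
train_cfg = dict(by_epoch=True, max_epochs=10)
val_cfg = dict()
test_cfg = dict()
```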
```diff
@@ -598,7 +598,8 @@ class BertLMHeadModel(BertPreTrainedModel):
         self.init_weights()
 
     def get_output_embeddings(self):
-        return self.cls.predictions.decoder
+        if self.cls is not None:
+            return self.cls.predictions.decoder
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
```
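The guard matters when the model is built without its LM head. A toy sketch (illustrative class, not the PR's code) of the contract callers rely on; for example, Hugging Face weight tying skips models whose `get_output_embeddings()` returns `None`:

```python
import torch.nn as nn


class ToyHeadlessLM(nn.Module):
    """Mimics BertLMHeadModel when ``self.cls`` was never created."""

    def __init__(self):
        super().__init__()
        self.cls = None  # e.g. the Q-Former is used without its LM head

    def get_output_embeddings(self):
        # Same guard as the patched method: only expose the decoder when
        # the head exists; otherwise fall through and return None.
        if self.cls is not None:
            return self.cls.predictions.decoder


model = ToyHeadlessLM()
out_emb = model.get_output_embeddings()
if out_emb is None:  # the pattern used by weight tying / resizing code
    print('no output embeddings, nothing to tie')
```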
```diff
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import torch
 from mmengine.model import BaseModel
```
```diff
@@ -93,14 +93,16 @@ class Blip2Caption(BaseModel):
             param.requires_grad = False
 
         if hasattr(self, 'register_load_state_dict_post_hook'):
-            self.register_load_state_dict_post_hook(self._ignore_llm_keys_hook)
+            self.register_load_state_dict_post_hook(
+                self._ignore_loading_llm_keys_hook)
 
-    def forward(
-        self,
-        images: torch.Tensor,
-        data_samples: Optional[List] = None,
-        mode: str = 'loss',
-    ) -> List[DataSample]:
+        if hasattr(self, '_register_state_dict_hook'):
+            self._register_state_dict_hook(self._igonre_saving_llm_keys_hook)
+
+    def forward(self,
+                images: torch.Tensor,
+                data_samples: Optional[List] = None,
+                mode: str = 'loss'):
         """The unified entry for a forward process in both training and test.
 
         The method should accept two modes: "predict" and "loss":
```
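For reference, a self-contained sketch of the load-side hook registered above (toy module, illustrative names). `register_load_state_dict_post_hook` only exists in newer PyTorch releases, which is presumably what the `hasattr` guard is for; the hook receives an `incompatible_keys` record whose `missing_keys` list can be edited in place, so the deliberately absent `text_backbone.*` (frozen LLM) entries stop being reported:

```python
import re

import torch
import torch.nn as nn


class ToyCaptioner(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_neck = nn.Linear(4, 4)
        self.text_backbone = nn.Linear(4, 4)  # stands in for the frozen LLM

    @staticmethod
    def _ignore_loading_llm_keys_hook(module, incompatible_keys):
        # Drop the LLM entries so they are not reported as missing.
        for key in list(incompatible_keys.missing_keys):
            if re.match('^text_backbone', key):
                incompatible_keys.missing_keys.remove(key)


model = ToyCaptioner()
if hasattr(model, 'register_load_state_dict_post_hook'):
    model.register_load_state_dict_post_hook(model._ignore_loading_llm_keys_hook)

# A checkpoint that intentionally stores only the trainable part:
ckpt = {'vision_neck.weight': torch.zeros(4, 4), 'vision_neck.bias': torch.zeros(4)}
print(model.load_state_dict(ckpt, strict=False))
# -> missing_keys no longer lists text_backbone.* when the hook is supported
```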
```diff
@@ -120,6 +122,8 @@ class Blip2Caption(BaseModel):
         Returns:
             The return type depends on ``mode``.
+
+            - If ``mode="loss"``, return a dict of tensor.
             - If ``mode="predict"``, return a list of
               :obj:`mmpretrain.structures.DataSample`.
         """
         if mode == 'loss':
             return self.loss(images, data_samples)
```
```diff
@@ -128,6 +132,85 @@ class Blip2Caption(BaseModel):
         else:
             raise RuntimeError(f'Invalid mode "{mode}".')
 
+    def loss(self,
+             images: torch.Tensor,
+             data_samples: Optional[list] = None,
+             **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
+
+        Args:
+            images (torch.Tensor): The input tensor with shape
+                (N, C, ...) in general.
+            data_samples (List[DataSample], optional): The annotation
+                data of every samples. Defaults to None.
+            **kwargs: Other keyword arguments accepted by the ``loss``
+                method of :attr:`head`.
+
+        Returns:
+            Dict[str, torch.Tensor]: A dictionary of loss components.
+        """
+
+        # extract image features
+        image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0])
+        image_atts = torch.ones(
+            image_embeds.size()[:-1],
+            dtype=torch.long,
+        ).to(images.device)
+
+        # distill image features to query tokens
+        query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1)
+        query_outputs = self.multimodal_backbone.bert(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+        inputs_opt = self.vision_neck([query_outputs.last_hidden_state])
+        attns_opt = torch.ones(
+            inputs_opt.size()[:-1], dtype=torch.long).to(images.device)
+
+        self.tokenizer.padding_side = 'right'
+
+        prompt = [
+            self.prompt + data_sample.gt_caption + '\n'
+            for data_sample in data_samples
+        ]
+
+        opt_tokens = self.tokenizer(
+            prompt,
+            return_tensors='pt',
+            padding='longest',
+            truncation=True,
+            max_length=self.max_txt_len,
+        ).to(images.device)
+
+        targets = opt_tokens.input_ids.masked_fill(
+            opt_tokens.input_ids == self.tokenizer.pad_token_id, -100)
+        if self.prompt:
+            targets[:, :self.prompt_length] = -100
+
+        empty_targets = (
+            torch.ones(attns_opt.size(),
+                       dtype=torch.long).to(images.device).fill_(-100))
+        targets = torch.cat([empty_targets, targets], dim=1)
+
+        inputs_embeds = (
+            self.text_backbone.model.decoder.embed_tokens(
+                opt_tokens.input_ids))
+        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
+        attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask],
+                                   dim=1)
+
+        outputs = self.text_backbone(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            return_dict=True,
+            labels=targets,
+        )
+        loss = outputs.loss
+
+        return {'loss': loss}
+
     def predict(self,
                 images: torch.Tensor,
                 data_samples: Optional[list] = None,
```
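The heart of the new `loss()` is how the language-modelling targets are built. A self-contained sketch with toy tensors (illustrative sizes; the real model also shifts labels by one inside the causal LM, omitted here): padding and prompt tokens are masked to -100, and one -100 per visual query position is prepended, so only caption tokens contribute to the cross-entropy, whose default `ignore_index` is -100:

```python
import torch
import torch.nn.functional as F

pad_id, prompt_length, num_query = 1, 4, 3   # illustrative values
# token ids for "prompt + caption" after right padding (batch of 2)
input_ids = torch.tensor([
    [5, 6, 7, 8, 20, 21, 22, 2],
    [5, 6, 7, 8, 30, 31, 2, pad_id],
])

# 1) ignore padding
targets = input_ids.masked_fill(input_ids == pad_id, -100)
# 2) ignore the prompt tokens
targets[:, :prompt_length] = -100
# 3) ignore the query-token positions prepended in front of the text
empty_targets = torch.full((input_ids.size(0), num_query), -100, dtype=torch.long)
targets = torch.cat([empty_targets, targets], dim=1)

# Dummy logits of matching length, just to show that -100 entries are skipped.
vocab = 50
logits = torch.randn(targets.size(0), targets.size(1), vocab)
loss = F.cross_entropy(
    logits.reshape(-1, vocab), targets.reshape(-1), ignore_index=-100)
print(targets)
print(loss)
```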
```diff
@@ -146,7 +229,7 @@ class Blip2Caption(BaseModel):
             List[DataSample]: Return list of data samples.
         """
 
-        # extract image features from
+        # extract image features
         image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0])
         image_atts = torch.ones(
             image_embeds.size()[:-1],
```
```diff
@@ -168,16 +251,21 @@ class Blip2Caption(BaseModel):
         prompt = [self.prompt] * image_embeds.size(0)
 
         opt_tokens = self.tokenizer(
-            prompt, return_tensors='pt').to(images.device)
-        input_ids = opt_tokens.input_ids
+            prompt,
+            return_tensors='pt',
+            padding='longest',
+            truncation=True,
+            max_length=self.max_txt_len,
+        ).to(images.device)
         attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask],
                                    dim=1)
 
-        query_embeds = inputs_opt
+        inputs_embeds = (
+            self.text_backbone.get_input_embeddings()(opt_tokens.input_ids))
+        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
 
         outputs = self.text_backbone.generate(
-            input_ids=input_ids,
-            query_embeds=query_embeds,
+            inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             do_sample=False,
             top_p=0.9,
```
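The switch from `input_ids`/`query_embeds` to `inputs_embeds` means the visual queries act as a soft prefix in embedding space. A shape-only sketch with toy tensors (illustrative sizes):

```python
import torch

B, Q, T, D = 2, 32, 5, 2560  # batch, query tokens, prompt tokens, hidden size

inputs_opt = torch.randn(B, Q, D)              # projected Q-Former queries (visual prefix)
attns_opt = torch.ones(B, Q, dtype=torch.long)

prompt_embeds = torch.randn(B, T, D)               # embed_tokens(opt_tokens.input_ids)
prompt_mask = torch.ones(B, T, dtype=torch.long)   # opt_tokens.attention_mask

inputs_embeds = torch.cat([inputs_opt, prompt_embeds], dim=1)  # (B, Q + T, D)
attention_mask = torch.cat([attns_opt, prompt_mask], dim=1)    # (B, Q + T)

print(inputs_embeds.shape, attention_mask.shape)
# The language model's generate() is then driven by inputs_embeds plus
# attention_mask instead of input_ids / query_embeds.
```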
```diff
@@ -192,7 +280,7 @@ class Blip2Caption(BaseModel):
         )
 
         output_text = self.tokenizer.batch_decode(
-            outputs[:, self.prompt_length:], skip_special_tokens=True)
+            outputs, skip_special_tokens=True)
         output_text = [text.strip() for text in output_text]
 
         out_data_samples = []
```
```diff
@@ -208,10 +296,20 @@ class Blip2Caption(BaseModel):
         return out_data_samples
 
     @staticmethod
-    def _ignore_llm_keys_hook(module, incompatible_keys):
+    def _ignore_loading_llm_keys_hook(module, incompatible_keys):
         """Avoid warning missing keys of the LLM model."""
         import re
         llm_pattern = '^text_backbone'
         for key in list(incompatible_keys.missing_keys):
             if re.match(llm_pattern, key):
                 incompatible_keys.missing_keys.remove(key)
+
+    @staticmethod
+    def _igonre_saving_llm_keys_hook(module, state_dict, prefix, metadata):
+        """Avoid saving llm state dict."""
+        import re
+        llm_pattern = '^text_backbone'
+        keys = [k for k, _ in state_dict.items()]
+        for key in keys:
+            if re.match(llm_pattern, key):
+                state_dict.pop(key)
```
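And the save-side counterpart registered earlier, as a self-contained sketch (toy module, illustrative names). `_register_state_dict_hook` is a private PyTorch API (hence the `hasattr` guard in the PR); the hook runs after `state_dict()` is assembled, so popping the frozen `text_backbone.*` keys keeps checkpoints small, since the frozen LLM can be restored from its own pretrained weights at build time:

```python
import re

import torch.nn as nn


class ToyCaptioner(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_neck = nn.Linear(4, 4)
        self.text_backbone = nn.Linear(4, 4)  # stands in for the frozen LLM

    @staticmethod
    def _igonre_saving_llm_keys_hook(module, state_dict, prefix, metadata):
        # Materialize the key list first, then drop the LLM entries.
        for key in [k for k in state_dict]:
            if re.match('^text_backbone', key):
                state_dict.pop(key)


model = ToyCaptioner()
if hasattr(model, '_register_state_dict_hook'):
    model._register_state_dict_hook(model._igonre_saving_llm_keys_hook)

print(sorted(model.state_dict().keys()))
# ['vision_neck.bias', 'vision_neck.weight'] -- no text_backbone.* entries
```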