[Enhancement] Support training of BLIP2 (#1700)

* [Fix] Fix BEiT pre_norm

* [Enhancement] Support BLIP2 training

* [Fix] Fix quoted strings

* [Fix] Fix init_weights

* [Fix] Fix with_cls_token

* [Fix] Fix tokenizer

* [Fix] Fix quoted strings

* [Fix] Fix predict

* [Fix] Cancel changing BEiT

* [Fix] Add loading hook

* [Fix] Reformat with yapf

* [Fix] Fix prompt

* [Fix] Fix typo
fanqiNO1 authored 2023-08-10 11:15:38 +08:00, committed by GitHub
commit 29d706248c (parent fa53174fd9)
3 changed files with 117 additions and 18 deletions


@@ -57,7 +57,7 @@ param_scheduler = [
     )
 ]

-train_cfg = dict(max_epochs=10)
+train_cfg = dict(by_epoch=True, max_epochs=10)
 val_cfg = dict()
 test_cfg = dict()
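For context, here is a minimal sketch of where this setting sits in an MMEngine-style schedule config; the optimizer and scheduler entries below are illustrative placeholders, not values from this PR. With by_epoch=True the runner builds an epoch-based training loop, so max_epochs (rather than max_iters) bounds the run and epoch-based schedulers step as expected:

    # Illustrative MMEngine-style schedule config (placeholder values).
    optim_wrapper = dict(optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05))
    param_scheduler = [dict(type='CosineAnnealingLR', by_epoch=True)]
    train_cfg = dict(by_epoch=True, max_epochs=10)  # epoch-based loop, 10 epochs
    val_cfg = dict()
    test_cfg = dict()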


@@ -598,7 +598,8 @@ class BertLMHeadModel(BertPreTrainedModel):
         self.init_weights()

     def get_output_embeddings(self):
-        return self.cls.predictions.decoder
+        if self.cls is not None:
+            return self.cls.predictions.decoder

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
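This None check matters when the Q-Former's BERT is built without its LM prediction head, i.e. with self.cls set to None: generic utilities in the HuggingFace-style model API (weight tying in particular) probe get_output_embeddings() and expect None for headless models rather than an AttributeError. A toy sketch of that calling convention, using an illustrative helper name not taken from the repo:

    def tie_output_to_input(model):
        # Mirrors the HuggingFace tie_weights() convention: a model that
        # reports no output embeddings (returns None) is simply skipped.
        output_embeddings = model.get_output_embeddings()
        if output_embeddings is not None:
            output_embeddings.weight = model.get_input_embeddings().weight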


@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Optional
+from typing import Dict, List, Optional

 import torch
 from mmengine.model import BaseModel
@@ -93,14 +93,16 @@ class Blip2Caption(BaseModel):
                 param.requires_grad = False

         if hasattr(self, 'register_load_state_dict_post_hook'):
-            self.register_load_state_dict_post_hook(self._ignore_llm_keys_hook)
+            self.register_load_state_dict_post_hook(
+                self._ignore_loading_llm_keys_hook)

-    def forward(
-        self,
-        images: torch.Tensor,
-        data_samples: Optional[List] = None,
-        mode: str = 'loss',
-    ) -> List[DataSample]:
+        if hasattr(self, '_register_state_dict_hook'):
+            self._register_state_dict_hook(self._ignore_saving_llm_keys_hook)
+
+    def forward(self,
+                images: torch.Tensor,
+                data_samples: Optional[List] = None,
+                mode: str = 'loss'):
         """The unified entry for a forward process in both training and test.

         The method should accept two modes: "predict" and "loss":
@@ -120,6 +122,8 @@ class Blip2Caption(BaseModel):
         Returns:
             The return type depends on ``mode``.
             - If ``mode="loss"``, return a dict of tensor.
+            - If ``mode="predict"``, return a list of
+              :obj:`mmpretrain.structures.DataSample`.
         """
         if mode == 'loss':
             return self.loss(images, data_samples)
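As a usage sketch (assuming an already-built Blip2Caption instance named model and a preprocessed image batch; this is not code from the PR), the unified forward now serves both training and inference:

    # 'loss' mode returns a dict of loss tensors for the optimizer step.
    losses = model(images, data_samples, mode='loss')
    # 'predict' mode returns a list of DataSample objects carrying the
    # generated captions.
    predictions = model(images, data_samples, mode='predict')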
@@ -128,6 +132,85 @@ class Blip2Caption(BaseModel):
         else:
             raise RuntimeError(f'Invalid mode "{mode}".')

+    def loss(self,
+             images: torch.Tensor,
+             data_samples: Optional[list] = None,
+             **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
+
+        Args:
+            images (torch.Tensor): The input tensor with shape
+                (N, C, ...) in general.
+            data_samples (List[DataSample], optional): The annotation
+                data of every samples. Defaults to None.
+            **kwargs: Other keyword arguments accepted by the ``loss``
+                method of :attr:`head`.
+
+        Returns:
+            Dict[str, torch.Tensor]: A dictionary of loss components.
+        """
+        # extract image features
+        image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0])
+        image_atts = torch.ones(
+            image_embeds.size()[:-1],
+            dtype=torch.long,
+        ).to(images.device)
+
+        # distill image features to query tokens
+        query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1)
+        query_outputs = self.multimodal_backbone.bert(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+        inputs_opt = self.vision_neck([query_outputs.last_hidden_state])
+        attns_opt = torch.ones(
+            inputs_opt.size()[:-1], dtype=torch.long).to(images.device)
+
+        self.tokenizer.padding_side = 'right'
+
+        prompt = [
+            self.prompt + data_sample.gt_caption + '\n'
+            for data_sample in data_samples
+        ]
+
+        opt_tokens = self.tokenizer(
+            prompt,
+            return_tensors='pt',
+            padding='longest',
+            truncation=True,
+            max_length=self.max_txt_len,
+        ).to(images.device)
+
+        targets = opt_tokens.input_ids.masked_fill(
+            opt_tokens.input_ids == self.tokenizer.pad_token_id, -100)
+        if self.prompt:
+            targets[:, :self.prompt_length] = -100
+
+        empty_targets = (
+            torch.ones(attns_opt.size(),
+                       dtype=torch.long).to(images.device).fill_(-100))
+        targets = torch.cat([empty_targets, targets], dim=1)
+
+        inputs_embeds = (
+            self.text_backbone.model.decoder.embed_tokens(
+                opt_tokens.input_ids))
+        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
+        attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask],
+                                   dim=1)
+
+        outputs = self.text_backbone(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            return_dict=True,
+            labels=targets,
+        )
+        loss = outputs.loss
+
+        return {'loss': loss}
+
     def predict(self,
                 images: torch.Tensor,
                 data_samples: Optional[list] = None,
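The key detail in the loss() method added above is the label masking: padding tokens, the prompt tokens, and the positions occupied by the visual query features are all set to -100, the index ignored by cross-entropy, so the language-modeling loss is computed only on the ground-truth caption tokens. A toy illustration with made-up ids and sizes (not real config values):

    import torch

    pad_id, prompt_len, num_query_tokens = 1, 1, 2   # toy sizes, for illustration only
    input_ids = torch.tensor([[5, 6, 7, pad_id]])    # [prompt][caption tokens][padding]
    targets = input_ids.masked_fill(input_ids == pad_id, -100)  # ignore padding
    targets[:, :prompt_len] = -100                   # do not learn to predict the prompt
    empty = torch.full((1, num_query_tokens), -100)  # visual query slots carry no labels
    targets = torch.cat([empty, targets], dim=1)
    print(targets)  # tensor([[-100, -100, -100, 6, 7, -100]])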
@@ -146,7 +229,7 @@ class Blip2Caption(BaseModel):
             List[DataSample]: Return list of data samples.
         """

-        # extract image features from
+        # extract image features
         image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0])
         image_atts = torch.ones(
             image_embeds.size()[:-1],
@@ -168,16 +251,21 @@ class Blip2Caption(BaseModel):
         prompt = [self.prompt] * image_embeds.size(0)

         opt_tokens = self.tokenizer(
-            prompt, return_tensors='pt').to(images.device)
-        input_ids = opt_tokens.input_ids
+            prompt,
+            return_tensors='pt',
+            padding='longest',
+            truncation=True,
+            max_length=self.max_txt_len,
+        ).to(images.device)
         attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask],
                                    dim=1)

-        query_embeds = inputs_opt
+        inputs_embeds = (
+            self.text_backbone.get_input_embeddings()(opt_tokens.input_ids))
+        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)

         outputs = self.text_backbone.generate(
-            input_ids=input_ids,
-            query_embeds=query_embeds,
+            inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             do_sample=False,
             top_p=0.9,
@@ -192,7 +280,7 @@ class Blip2Caption(BaseModel):
         )

         output_text = self.tokenizer.batch_decode(
-            outputs[:, self.prompt_length:], skip_special_tokens=True)
+            outputs, skip_special_tokens=True)
         output_text = [text.strip() for text in output_text]

         out_data_samples = []
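With generation now driven by inputs_embeds, the projected query features are prepended to the embedded text prompt and generate() works purely in embedding space; since no input_ids are passed, the returned sequences contain only newly generated tokens, which is why the prompt_length slicing was dropped from batch_decode above. A rough standalone sketch of the same pattern with a small HuggingFace causal LM (the checkpoint, sizes, and random "visual" features are illustrative, and a reasonably recent transformers release is assumed for inputs_embeds-only generation):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained('facebook/opt-125m')    # small stand-in for the OPT backbone
    lm = AutoModelForCausalLM.from_pretrained('facebook/opt-125m')

    prompt_ids = tok(['a photo of'], return_tensors='pt').input_ids
    prompt_embeds = lm.get_input_embeddings()(prompt_ids)
    visual_embeds = torch.randn(1, 32, prompt_embeds.size(-1))  # fake Q-Former output (32 query slots)
    inputs_embeds = torch.cat([visual_embeds, prompt_embeds], dim=1)
    attention_mask = torch.ones(inputs_embeds.shape[:-1], dtype=torch.long)

    out = lm.generate(inputs_embeds=inputs_embeds,
                      attention_mask=attention_mask,
                      max_new_tokens=10)
    # Only the generated continuation comes back; no prompt slicing needed.
    print(tok.batch_decode(out, skip_special_tokens=True))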
@@ -208,10 +296,20 @@ class Blip2Caption(BaseModel):
         return out_data_samples

     @staticmethod
-    def _ignore_llm_keys_hook(module, incompatible_keys):
+    def _ignore_loading_llm_keys_hook(module, incompatible_keys):
         """Avoid warning missing keys of the LLM model."""
         import re
         llm_pattern = '^text_backbone'
         for key in list(incompatible_keys.missing_keys):
             if re.match(llm_pattern, key):
                 incompatible_keys.missing_keys.remove(key)
+
+    @staticmethod
+    def _ignore_saving_llm_keys_hook(module, state_dict, prefix, metadata):
+        """Avoid saving the LLM state dict."""
+        import re
+        llm_pattern = '^text_backbone'
+        keys = [k for k, _ in state_dict.items()]
+        for key in keys:
+            if re.match(llm_pattern, key):
+                state_dict.pop(key)
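For readers unfamiliar with these two hook mechanisms, here is a self-contained toy module (the class and hook names are hypothetical, not the repo's) showing the intended effect: the state-dict hook keeps the frozen text_backbone weights out of saved checkpoints, and the load-time post-hook keeps those same keys from being reported as missing when such a checkpoint is loaded back. register_load_state_dict_post_hook only exists in recent PyTorch releases (roughly 1.13+), which is why the patch guards both registrations with hasattr:

    import re

    import torch.nn as nn

    class ToyCaptioner(nn.Module):
        """Hypothetical stand-in for a model with a frozen LLM submodule."""

        def __init__(self):
            super().__init__()
            self.vision_neck = nn.Linear(4, 4)     # trainable part, kept in checkpoints
            self.text_backbone = nn.Linear(4, 4)   # stands in for the frozen LLM
            self._register_state_dict_hook(self._drop_llm_keys)
            self.register_load_state_dict_post_hook(self._ignore_missing_llm_keys)

        @staticmethod
        def _drop_llm_keys(module, state_dict, prefix, metadata):
            # Runs whenever state_dict() is built: strip the LLM weights.
            for key in list(state_dict):
                if re.match('^text_backbone', key):
                    state_dict.pop(key)

        @staticmethod
        def _ignore_missing_llm_keys(module, incompatible_keys):
            # Runs after load_state_dict(): clear the expected missing keys
            # so that strict loading does not raise.
            for key in list(incompatible_keys.missing_keys):
                if re.match('^text_backbone', key):
                    incompatible_keys.missing_keys.remove(key)

    ckpt = ToyCaptioner().state_dict()    # contains no 'text_backbone.*' entries
    ToyCaptioner().load_state_dict(ckpt)  # loads cleanly despite the dropped keys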