diff --git a/configs/blip2/blip2-opt2.7b_8xb32_caption.py b/configs/blip2/blip2-opt2.7b_8xb32_caption.py
index 9fadd2fc..52d0a632 100644
--- a/configs/blip2/blip2-opt2.7b_8xb32_caption.py
+++ b/configs/blip2/blip2-opt2.7b_8xb32_caption.py
@@ -57,7 +57,7 @@ param_scheduler = [
     )
 ]

-train_cfg = dict(max_epochs=10)
+train_cfg = dict(by_epoch=True, max_epochs=10)
 val_cfg = dict()
 test_cfg = dict()

diff --git a/mmpretrain/models/multimodal/blip2/Qformer.py b/mmpretrain/models/multimodal/blip2/Qformer.py
index 2b85f9ee..4b1c7d1e 100644
--- a/mmpretrain/models/multimodal/blip2/Qformer.py
+++ b/mmpretrain/models/multimodal/blip2/Qformer.py
@@ -598,7 +598,8 @@ class BertLMHeadModel(BertPreTrainedModel):
         self.init_weights()

     def get_output_embeddings(self):
-        return self.cls.predictions.decoder
+        if self.cls is not None:
+            return self.cls.predictions.decoder

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
diff --git a/mmpretrain/models/multimodal/blip2/blip2_caption.py b/mmpretrain/models/multimodal/blip2/blip2_caption.py
index 7b409b07..acf69482 100644
--- a/mmpretrain/models/multimodal/blip2/blip2_caption.py
+++ b/mmpretrain/models/multimodal/blip2/blip2_caption.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Optional
+from typing import Dict, List, Optional

 import torch
 from mmengine.model import BaseModel
@@ -93,14 +93,16 @@ class Blip2Caption(BaseModel):
             param.requires_grad = False

         if hasattr(self, 'register_load_state_dict_post_hook'):
-            self.register_load_state_dict_post_hook(self._ignore_llm_keys_hook)
+            self.register_load_state_dict_post_hook(
+                self._ignore_loading_llm_keys_hook)

-    def forward(
-        self,
-        images: torch.Tensor,
-        data_samples: Optional[List] = None,
-        mode: str = 'loss',
-    ) -> List[DataSample]:
+        if hasattr(self, '_register_state_dict_hook'):
+            self._register_state_dict_hook(self._ignore_saving_llm_keys_hook)
+
+    def forward(self,
+                images: torch.Tensor,
+                data_samples: Optional[List] = None,
+                mode: str = 'loss'):
         """The unified entry for a forward process in both training and test.

         The method should accept two modes: "predict" and "loss":
@@ -120,6 +122,8 @@ class Blip2Caption(BaseModel):
         Returns:
             The return type depends on ``mode``.
             - If ``mode="loss"``, return a dict of tensor.
+            - If ``mode="predict"``, return a list of
+              :obj:`mmpretrain.structures.DataSample`.
         """
         if mode == 'loss':
             return self.loss(images, data_samples)
@@ -128,6 +132,85 @@ class Blip2Caption(BaseModel):
         else:
             raise RuntimeError(f'Invalid mode "{mode}".')

+    def loss(self,
+             images: torch.Tensor,
+             data_samples: Optional[list] = None,
+             **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
+
+        Args:
+            images (torch.Tensor): The input tensor with shape
+                (N, C, ...) in general.
+            data_samples (List[DataSample], optional): The annotation
+                data of every sample. Defaults to None.
+            **kwargs: Other keyword arguments accepted by the ``loss``
+                method of :attr:`head`.
+
+        Returns:
+            Dict[str, torch.Tensor]: A dictionary of loss components.
+        """
+
+        # extract image features
+        image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0])
+        image_atts = torch.ones(
+            image_embeds.size()[:-1],
+            dtype=torch.long,
+        ).to(images.device)
+
+        # distill image features to query tokens
+        query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1)
+        query_outputs = self.multimodal_backbone.bert(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+        inputs_opt = self.vision_neck([query_outputs.last_hidden_state])
+        attns_opt = torch.ones(
+            inputs_opt.size()[:-1], dtype=torch.long).to(images.device)
+
+        self.tokenizer.padding_side = 'right'
+
+        prompt = [
+            self.prompt + data_sample.gt_caption + '\n'
+            for data_sample in data_samples
+        ]
+
+        opt_tokens = self.tokenizer(
+            prompt,
+            return_tensors='pt',
+            padding='longest',
+            truncation=True,
+            max_length=self.max_txt_len,
+        ).to(images.device)
+
+        targets = opt_tokens.input_ids.masked_fill(
+            opt_tokens.input_ids == self.tokenizer.pad_token_id, -100)
+        if self.prompt:
+            targets[:, :self.prompt_length] = -100
+
+        empty_targets = (
+            torch.ones(attns_opt.size(),
+                       dtype=torch.long).to(images.device).fill_(-100))
+        targets = torch.cat([empty_targets, targets], dim=1)
+
+        inputs_embeds = (
+            self.text_backbone.model.decoder.embed_tokens(
+                opt_tokens.input_ids))
+        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
+        attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask],
+                                   dim=1)
+
+        outputs = self.text_backbone(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            return_dict=True,
+            labels=targets,
+        )
+        loss = outputs.loss
+
+        return {'loss': loss}
+
     def predict(self,
                 images: torch.Tensor,
                 data_samples: Optional[list] = None,
@@ -146,7 +229,7 @@ class Blip2Caption(BaseModel):
             List[DataSample]: Return list of data samples.
         """

-        # extract image features from
+        # extract image features
         image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0])
         image_atts = torch.ones(
             image_embeds.size()[:-1],
@@ -168,16 +251,21 @@ class Blip2Caption(BaseModel):
         prompt = [self.prompt] * image_embeds.size(0)

         opt_tokens = self.tokenizer(
-            prompt, return_tensors='pt').to(images.device)
-        input_ids = opt_tokens.input_ids
+            prompt,
+            return_tensors='pt',
+            padding='longest',
+            truncation=True,
+            max_length=self.max_txt_len,
+        ).to(images.device)
         attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask],
                                    dim=1)

-        query_embeds = inputs_opt
+        inputs_embeds = (
+            self.text_backbone.get_input_embeddings()(opt_tokens.input_ids))
+        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)

         outputs = self.text_backbone.generate(
-            input_ids=input_ids,
-            query_embeds=query_embeds,
+            inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             do_sample=False,
             top_p=0.9,
@@ -192,7 +280,7 @@ class Blip2Caption(BaseModel):
         )

         output_text = self.tokenizer.batch_decode(
-            outputs[:, self.prompt_length:], skip_special_tokens=True)
+            outputs, skip_special_tokens=True)
         output_text = [text.strip() for text in output_text]

         out_data_samples = []
@@ -208,10 +296,20 @@ class Blip2Caption(BaseModel):
         return out_data_samples

     @staticmethod
-    def _ignore_llm_keys_hook(module, incompatible_keys):
+    def _ignore_loading_llm_keys_hook(module, incompatible_keys):
         """Avoid warning missing keys of the LLM model."""
         import re
         llm_pattern = '^text_backbone'
         for key in list(incompatible_keys.missing_keys):
             if re.match(llm_pattern, key):
                 incompatible_keys.missing_keys.remove(key)
+
+    @staticmethod
+    def _ignore_saving_llm_keys_hook(module, state_dict, prefix, metadata):
+        """Avoid saving the LLM state dict."""
+        import re
+        llm_pattern = '^text_backbone'
+        keys = [k for k, _ in state_dict.items()]
+        for key in keys:
+            if re.match(llm_pattern, key):
+                state_dict.pop(key)