From 29d706248c17f77da093e5bb342e667fdf2ce0a9 Mon Sep 17 00:00:00 2001
From: fanqiNO1 <75657629+fanqiNO1@users.noreply.github.com>
Date: Thu, 10 Aug 2023 11:15:38 +0800
Subject: [PATCH] [Enhancement] Support training of BLIP2 (#1700)

* [Fix] Fix BEiT pre_norm

* [Enhancement] Support BLIP2 training

* [Fix] Fix quoted strings

* [Fix] Fix init_weights

* [Fix] Fix with_cls_token

* [Fix] Fix tokenizer

* [Fix] Fix quoted strings

* [Fix] Fix predict

* [Fix] Cancel changing BEiT

* [Fix] Add loading hook

* [Fix] Reformat with yapf

* [Fix] Fix prompt

* [Fix] Fix typo
---
 configs/blip2/blip2-opt2.7b_8xb32_caption.py  |   2 +-
 mmpretrain/models/multimodal/blip2/Qformer.py |   3 +-
 .../models/multimodal/blip2/blip2_caption.py  | 130 +++++++++++++++---
 3 files changed, 117 insertions(+), 18 deletions(-)

diff --git a/configs/blip2/blip2-opt2.7b_8xb32_caption.py b/configs/blip2/blip2-opt2.7b_8xb32_caption.py
index 9fadd2fc..52d0a632 100644
--- a/configs/blip2/blip2-opt2.7b_8xb32_caption.py
+++ b/configs/blip2/blip2-opt2.7b_8xb32_caption.py
@@ -57,7 +57,7 @@ param_scheduler = [
     )
 ]
 
-train_cfg = dict(max_epochs=10)
+train_cfg = dict(by_epoch=True, max_epochs=10)
 val_cfg = dict()
 test_cfg = dict()
 
diff --git a/mmpretrain/models/multimodal/blip2/Qformer.py b/mmpretrain/models/multimodal/blip2/Qformer.py
index 2b85f9ee..4b1c7d1e 100644
--- a/mmpretrain/models/multimodal/blip2/Qformer.py
+++ b/mmpretrain/models/multimodal/blip2/Qformer.py
@@ -598,7 +598,8 @@ class BertLMHeadModel(BertPreTrainedModel):
         self.init_weights()
 
     def get_output_embeddings(self):
-        return self.cls.predictions.decoder
+        if self.cls is not None:
+            return self.cls.predictions.decoder
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
diff --git a/mmpretrain/models/multimodal/blip2/blip2_caption.py b/mmpretrain/models/multimodal/blip2/blip2_caption.py
index 7b409b07..acf69482 100644
--- a/mmpretrain/models/multimodal/blip2/blip2_caption.py
+++ b/mmpretrain/models/multimodal/blip2/blip2_caption.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import torch
 from mmengine.model import BaseModel
@@ -93,14 +93,16 @@ class Blip2Caption(BaseModel):
             param.requires_grad = False
 
         if hasattr(self, 'register_load_state_dict_post_hook'):
-            self.register_load_state_dict_post_hook(self._ignore_llm_keys_hook)
+            self.register_load_state_dict_post_hook(
+                self._ignore_loading_llm_keys_hook)
 
-    def forward(
-        self,
-        images: torch.Tensor,
-        data_samples: Optional[List] = None,
-        mode: str = 'loss',
-    ) -> List[DataSample]:
+        if hasattr(self, '_register_state_dict_hook'):
+            self._register_state_dict_hook(self._igonre_saving_llm_keys_hook)
+
+    def forward(self,
+                images: torch.Tensor,
+                data_samples: Optional[List] = None,
+                mode: str = 'loss'):
         """The unified entry for a forward process in both training and test.
         The method should accept two modes: "predict" and "loss":
 
@@ -120,6 +122,8 @@ class Blip2Caption(BaseModel):
         Returns:
             The return type depends on ``mode``.
             - If ``mode="loss"``, return a dict of tensor.
+            - If ``mode="predict"``, return a list of
+              :obj:`mmpretrain.structures.DataSample`.
         """
         if mode == 'loss':
             return self.loss(images, data_samples)
@@ -128,6 +132,85 @@ class Blip2Caption(BaseModel):
         else:
             raise RuntimeError(f'Invalid mode "{mode}".')
 
+    def loss(self,
+             images: torch.Tensor,
+             data_samples: Optional[list] = None,
+             **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
+
+        Args:
+            images (torch.Tensor): The input tensor with shape
+                (N, C, ...) in general.
+            data_samples (List[DataSample], optional): The annotation
+                data of every samples. Defaults to None.
+            **kwargs: Other keyword arguments accepted by the ``loss``
+                method of :attr:`head`.
+
+        Returns:
+            Dict[str, torch.Tensor]: A dictionary of loss components.
+        """
+
+        # extract image features
+        image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0])
+        image_atts = torch.ones(
+            image_embeds.size()[:-1],
+            dtype=torch.long,
+        ).to(images.device)
+
+        # distill image features to query tokens
+        query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1)
+        query_outputs = self.multimodal_backbone.bert(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+        inputs_opt = self.vision_neck([query_outputs.last_hidden_state])
+        attns_opt = torch.ones(
+            inputs_opt.size()[:-1], dtype=torch.long).to(images.device)
+
+        self.tokenizer.padding_side = 'right'
+
+        prompt = [
+            self.prompt + data_sample.gt_caption + '\n'
+            for data_sample in data_samples
+        ]
+
+        opt_tokens = self.tokenizer(
+            prompt,
+            return_tensors='pt',
+            padding='longest',
+            truncation=True,
+            max_length=self.max_txt_len,
+        ).to(images.device)
+
+        targets = opt_tokens.input_ids.masked_fill(
+            opt_tokens.input_ids == self.tokenizer.pad_token_id, -100)
+        if self.prompt:
+            targets[:, :self.prompt_length] = -100
+
+        empty_targets = (
+            torch.ones(attns_opt.size(),
+                       dtype=torch.long).to(images.device).fill_(-100))
+        targets = torch.cat([empty_targets, targets], dim=1)
+
+        inputs_embeds = (
+            self.text_backbone.model.decoder.embed_tokens(
+                opt_tokens.input_ids))
+        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
+        attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask],
+                                   dim=1)
+
+        outputs = self.text_backbone(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            return_dict=True,
+            labels=targets,
+        )
+        loss = outputs.loss
+
+        return {'loss': loss}
+
     def predict(self,
                 images: torch.Tensor,
                 data_samples: Optional[list] = None,
@@ -146,7 +229,7 @@ class Blip2Caption(BaseModel):
             List[DataSample]: Return list of data samples.
         """
 
-        # extract image features from
+        # extract image features
         image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0])
         image_atts = torch.ones(
             image_embeds.size()[:-1],
@@ -168,16 +251,21 @@ class Blip2Caption(BaseModel):
         prompt = [self.prompt] * image_embeds.size(0)
 
         opt_tokens = self.tokenizer(
-            prompt, return_tensors='pt').to(images.device)
-        input_ids = opt_tokens.input_ids
+            prompt,
+            return_tensors='pt',
+            padding='longest',
+            truncation=True,
+            max_length=self.max_txt_len,
+        ).to(images.device)
         attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask],
                                    dim=1)
 
-        query_embeds = inputs_opt
+        inputs_embeds = (
+            self.text_backbone.get_input_embeddings()(opt_tokens.input_ids))
+        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
 
         outputs = self.text_backbone.generate(
-            input_ids=input_ids,
-            query_embeds=query_embeds,
+            inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             do_sample=False,
             top_p=0.9,
@@ -192,7 +280,7 @@ class Blip2Caption(BaseModel):
         )
 
         output_text = self.tokenizer.batch_decode(
-            outputs[:, self.prompt_length:], skip_special_tokens=True)
+            outputs, skip_special_tokens=True)
         output_text = [text.strip() for text in output_text]
 
         out_data_samples = []
@@ -208,10 +296,20 @@ class Blip2Caption(BaseModel):
         return out_data_samples
 
     @staticmethod
-    def _ignore_llm_keys_hook(module, incompatible_keys):
+    def _ignore_loading_llm_keys_hook(module, incompatible_keys):
         """Avoid warning missing keys of the LLM model."""
         import re
         llm_pattern = '^text_backbone'
         for key in list(incompatible_keys.missing_keys):
             if re.match(llm_pattern, key):
                 incompatible_keys.missing_keys.remove(key)
+
+    @staticmethod
+    def _igonre_saving_llm_keys_hook(module, state_dict, prefix, metadata):
+        """Avoid saving llm state dict."""
+        import re
+        llm_pattern = '^text_backbone'
+        keys = [k for k, _ in state_dict.items()]
+        for key in keys:
+            if re.match(llm_pattern, key):
+                state_dict.pop(key)