From b1cd05caf28540e79ebca5cdc3029712286b4fc3 Mon Sep 17 00:00:00 2001
From: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Date: Fri, 28 Jul 2023 15:28:07 +0800
Subject: [PATCH] [Enhance] Set 'is_init' in some multimodal methods (#1718)

* update is_init of multimodal

* Update minigpt4.py

---------

Co-authored-by: Ma Zerun
---
 mmpretrain/models/multimodal/flamingo/flamingo.py | 1 +
 mmpretrain/models/multimodal/llava/llava.py       | 1 +
 mmpretrain/models/multimodal/minigpt4/minigpt4.py | 4 ++++
 mmpretrain/models/multimodal/otter/otter.py       | 1 +
 mmpretrain/models/utils/huggingface.py            | 4 +++-
 5 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/mmpretrain/models/multimodal/flamingo/flamingo.py b/mmpretrain/models/multimodal/flamingo/flamingo.py
index abdd0332..039f9ff3 100644
--- a/mmpretrain/models/multimodal/flamingo/flamingo.py
+++ b/mmpretrain/models/multimodal/flamingo/flamingo.py
@@ -96,6 +96,7 @@ class Flamingo(BaseModel):
                 map_location='cpu',
                 revise_keys=[(r'^backbone\.', '')],
             )
+            self.vision_encoder.is_init = True
 
         self.perceiver = PerceiverResampler(
             dim=self.vision_encoder.embed_dims)
diff --git a/mmpretrain/models/multimodal/llava/llava.py b/mmpretrain/models/multimodal/llava/llava.py
index 1c300fdc..103d8129 100644
--- a/mmpretrain/models/multimodal/llava/llava.py
+++ b/mmpretrain/models/multimodal/llava/llava.py
@@ -94,6 +94,7 @@ class Llava(BaseModel):
                 map_location='cpu',
                 revise_keys=[(r'^backbone\.', '')],
             )
+            vision_encoder.is_init = True
 
         # init language encoder related modules
         if load_lang_pretrained:
diff --git a/mmpretrain/models/multimodal/minigpt4/minigpt4.py b/mmpretrain/models/multimodal/minigpt4/minigpt4.py
index d2320360..4bbd5aaa 100644
--- a/mmpretrain/models/multimodal/minigpt4/minigpt4.py
+++ b/mmpretrain/models/multimodal/minigpt4/minigpt4.py
@@ -79,6 +79,7 @@ class MiniGPT4(BaseModel):
         if vision_encoder_weight is not None:
             from mmengine.runner.checkpoint import load_checkpoint
             load_checkpoint(self.vision_encoder, vision_encoder_weight)
+            self.vision_encoder.is_init = True
         if freeze_vit:
             for name, param in self.ln_vision.named_parameters():
                 param.requires_grad = False
@@ -108,6 +109,9 @@ class MiniGPT4(BaseModel):
             state_dict = CheckpointLoader.load_checkpoint(
                 q_former_model_weight)['state_dict']
             self.load_state_dict(state_dict, strict=False)
+            # The ln_vision weights are also in the q-former checkpoint.
+            setattr(self.ln_vision, 'is_init', True)
+            setattr(self.q_former, 'is_init', True)
 
         if freeze_q_former:
             for name, param in self.q_former.named_parameters():
diff --git a/mmpretrain/models/multimodal/otter/otter.py b/mmpretrain/models/multimodal/otter/otter.py
index 2fed1a4d..5065c58c 100644
--- a/mmpretrain/models/multimodal/otter/otter.py
+++ b/mmpretrain/models/multimodal/otter/otter.py
@@ -98,6 +98,7 @@ class Otter(Flamingo):
                 map_location='cpu',
                 revise_keys=[(r'^backbone\.', '')],
             )
+            self.vision_encoder.is_init = True
 
         self.perceiver = PerceiverResampler(
             dim=self.vision_encoder.embed_dims)
diff --git a/mmpretrain/models/utils/huggingface.py b/mmpretrain/models/utils/huggingface.py
index e527315b..a44d6daa 100644
--- a/mmpretrain/models/utils/huggingface.py
+++ b/mmpretrain/models/utils/huggingface.py
@@ -86,7 +86,9 @@ def register_hf_model(
                                   kwargs.pop('name_or_path'))
 
         if kwargs.pop('load_pretrained', True) and _load_hf_pretrained_model:
-            return cls.from_pretrained(name_or_path, **kwargs)
+            model = cls.from_pretrained(name_or_path, **kwargs)
+            setattr(model, 'is_init', True)
+            return model
         else:
             cfg = get_config(name_or_path, **kwargs)
             return from_config(cfg)
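
Context for this patch: mmengine's BaseModule.init_weights() recurses into
child modules but skips any whose is_init flag is already set, so weights
loaded by hand through load_checkpoint() (or transformers'
from_pretrained()) are not overwritten by a later init_weights() pass;
setattr() is used on the HuggingFace model because it is a plain torch
module with no is_init attribute of its own. The sketch below is a
simplified approximation of that mechanism, not mmengine's actual source;
BaseModuleSketch and Toy are hypothetical names.

import torch.nn as nn


class BaseModuleSketch(nn.Module):
    """Rough stand-in for mmengine.model.BaseModule."""

    def __init__(self):
        super().__init__()
        self.is_init = False

    def init_weights(self):
        for child in self.children():
            # Children already flagged as initialized are skipped; this
            # is the check the patch relies on after load_checkpoint().
            if getattr(child, 'is_init', False):
                continue
            if hasattr(child, 'init_weights'):
                child.init_weights()
        self.is_init = True


class Toy(BaseModuleSketch):

    def __init__(self):
        super().__init__()
        self.vision_encoder = BaseModuleSketch()
        # Stand-in for load_checkpoint(self.vision_encoder, weight):
        # mark the module so a later init_weights() leaves it alone.
        self.vision_encoder.is_init = True


model = Toy()
model.init_weights()  # vision_encoder is skipped; loaded weights survive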