[Enhance] Set 'is_init' in some multimodal methods ()

* update is_init of multimodal

* Update minigpt4.py

---------

Co-authored-by: Ma Zerun <mzr1996@163.com>
pull/1670/head
Yixiao Fang 2023-07-28 15:28:07 +08:00 committed by GitHub
parent e7fc25cf64
commit b1cd05caf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 10 additions and 1 deletion
mmpretrain/models
multimodal
flamingo
minigpt4

View File

@@ -96,6 +96,7 @@ class Flamingo(BaseModel):
map_location='cpu',
revise_keys=[(r'^backbone\.', '')],
)
self.vision_encoder.is_init = True
self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)

View File

@@ -94,6 +94,7 @@ class Llava(BaseModel):
map_location='cpu',
revise_keys=[(r'^backbone\.', '')],
)
vision_encoder.is_init = True
# init language encoder related modules
if load_lang_pretrained:

View File

@@ -79,6 +79,7 @@ class MiniGPT4(BaseModel):
if vision_encoder_weight is not None:
from mmengine.runner.checkpoint import load_checkpoint
load_checkpoint(self.vision_encoder, vision_encoder_weight)
self.vision_encoder.is_init = True
if freeze_vit:
for name, param in self.ln_vision.named_parameters():
param.requires_grad = False
@@ -108,6 +109,9 @@ class MiniGPT4(BaseModel):
state_dict = CheckpointLoader.load_checkpoint(
q_former_model_weight)['state_dict']
self.load_state_dict(state_dict, strict=False)
# The ln_vision weights are also in the q-former checkpoint.
setattr(self.ln_vision, 'is_init', True)
setattr(self.q_former, 'is_init', True)
if freeze_q_former:
for name, param in self.q_former.named_parameters():

View File

@@ -98,6 +98,7 @@ class Otter(Flamingo):
map_location='cpu',
revise_keys=[(r'^backbone\.', '')],
)
self.vision_encoder.is_init = True
self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)

View File

@@ -86,7 +86,9 @@ def register_hf_model(
kwargs.pop('name_or_path'))
if kwargs.pop('load_pretrained', True) and _load_hf_pretrained_model:
return cls.from_pretrained(name_or_path, **kwargs)
model = cls.from_pretrained(name_or_path, **kwargs)
setattr(model, 'is_init', True)
return model
else:
cfg = get_config(name_or_path, **kwargs)
return from_config(cfg)