[Enhance] Set 'is_init' in some multimodal methods (#1718)
* update is_init of multimodal * Update minigpt4.py --------- Co-authored-by: Ma Zerun <mzr1996@163.com>pull/1670/head
parent
e7fc25cf64
commit
b1cd05caf2
mmpretrain/models
multimodal
utils
|
@ -96,6 +96,7 @@ class Flamingo(BaseModel):
|
|||
map_location='cpu',
|
||||
revise_keys=[(r'^backbone\.', '')],
|
||||
)
|
||||
self.vision_encoder.is_init = True
|
||||
|
||||
self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)
|
||||
|
||||
|
|
|
@ -94,6 +94,7 @@ class Llava(BaseModel):
|
|||
map_location='cpu',
|
||||
revise_keys=[(r'^backbone\.', '')],
|
||||
)
|
||||
vision_encoder.is_init = True
|
||||
|
||||
# init language encoder related modules
|
||||
if load_lang_pretrained:
|
||||
|
|
|
@ -79,6 +79,7 @@ class MiniGPT4(BaseModel):
|
|||
if vision_encoder_weight is not None:
|
||||
from mmengine.runner.checkpoint import load_checkpoint
|
||||
load_checkpoint(self.vision_encoder, vision_encoder_weight)
|
||||
self.vision_encoder.is_init = True
|
||||
if freeze_vit:
|
||||
for name, param in self.ln_vision.named_parameters():
|
||||
param.requires_grad = False
|
||||
|
@ -108,6 +109,9 @@ class MiniGPT4(BaseModel):
|
|||
state_dict = CheckpointLoader.load_checkpoint(
|
||||
q_former_model_weight)['state_dict']
|
||||
self.load_state_dict(state_dict, strict=False)
|
||||
# The ln_vision weights are also in the q-former checkpoint.
|
||||
setattr(self.ln_vision, 'is_init', True)
|
||||
setattr(self.q_former, 'is_init', True)
|
||||
|
||||
if freeze_q_former:
|
||||
for name, param in self.q_former.named_parameters():
|
||||
|
|
|
@ -98,6 +98,7 @@ class Otter(Flamingo):
|
|||
map_location='cpu',
|
||||
revise_keys=[(r'^backbone\.', '')],
|
||||
)
|
||||
self.vision_encoder.is_init = True
|
||||
|
||||
self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)
|
||||
|
||||
|
|
|
@ -86,7 +86,9 @@ def register_hf_model(
|
|||
kwargs.pop('name_or_path'))
|
||||
|
||||
if kwargs.pop('load_pretrained', True) and _load_hf_pretrained_model:
|
||||
return cls.from_pretrained(name_or_path, **kwargs)
|
||||
model = cls.from_pretrained(name_or_path, **kwargs)
|
||||
setattr(model, 'is_init', True)
|
||||
return model
|
||||
else:
|
||||
cfg = get_config(name_or_path, **kwargs)
|
||||
return from_config(cfg)
|
||||
|
|
Loading…
Reference in New Issue