[Enhance] Set 'is_init' in some multimodal methods (#1718)
* update is_init of multimodal
* Update minigpt4.py

---------

Co-authored-by: Ma Zerun <mzr1996@163.com>
pull/1670/head
parent
e7fc25cf64
commit
b1cd05caf2
|
@ -96,6 +96,7 @@ class Flamingo(BaseModel):
|
||||||
map_location='cpu',
|
map_location='cpu',
|
||||||
revise_keys=[(r'^backbone\.', '')],
|
revise_keys=[(r'^backbone\.', '')],
|
||||||
)
|
)
|
||||||
|
self.vision_encoder.is_init = True
|
||||||
|
|
||||||
self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)
|
self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)
|
||||||
|
|
||||||
|
|
|
@ -94,6 +94,7 @@ class Llava(BaseModel):
|
||||||
map_location='cpu',
|
map_location='cpu',
|
||||||
revise_keys=[(r'^backbone\.', '')],
|
revise_keys=[(r'^backbone\.', '')],
|
||||||
)
|
)
|
||||||
|
vision_encoder.is_init = True
|
||||||
|
|
||||||
# init language encoder related modules
|
# init language encoder related modules
|
||||||
if load_lang_pretrained:
|
if load_lang_pretrained:
|
||||||
|
|
|
@ -79,6 +79,7 @@ class MiniGPT4(BaseModel):
|
||||||
if vision_encoder_weight is not None:
|
if vision_encoder_weight is not None:
|
||||||
from mmengine.runner.checkpoint import load_checkpoint
|
from mmengine.runner.checkpoint import load_checkpoint
|
||||||
load_checkpoint(self.vision_encoder, vision_encoder_weight)
|
load_checkpoint(self.vision_encoder, vision_encoder_weight)
|
||||||
|
self.vision_encoder.is_init = True
|
||||||
if freeze_vit:
|
if freeze_vit:
|
||||||
for name, param in self.ln_vision.named_parameters():
|
for name, param in self.ln_vision.named_parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
@ -108,6 +109,9 @@ class MiniGPT4(BaseModel):
|
||||||
state_dict = CheckpointLoader.load_checkpoint(
|
state_dict = CheckpointLoader.load_checkpoint(
|
||||||
q_former_model_weight)['state_dict']
|
q_former_model_weight)['state_dict']
|
||||||
self.load_state_dict(state_dict, strict=False)
|
self.load_state_dict(state_dict, strict=False)
|
||||||
|
# The ln_vision weights are also in the q-former checkpoint.
|
||||||
|
setattr(self.ln_vision, 'is_init', True)
|
||||||
|
setattr(self.q_former, 'is_init', True)
|
||||||
|
|
||||||
if freeze_q_former:
|
if freeze_q_former:
|
||||||
for name, param in self.q_former.named_parameters():
|
for name, param in self.q_former.named_parameters():
|
||||||
|
|
|
@ -98,6 +98,7 @@ class Otter(Flamingo):
|
||||||
map_location='cpu',
|
map_location='cpu',
|
||||||
revise_keys=[(r'^backbone\.', '')],
|
revise_keys=[(r'^backbone\.', '')],
|
||||||
)
|
)
|
||||||
|
self.vision_encoder.is_init = True
|
||||||
|
|
||||||
self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)
|
self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)
|
||||||
|
|
||||||
|
|
|
@ -86,7 +86,9 @@ def register_hf_model(
|
||||||
kwargs.pop('name_or_path'))
|
kwargs.pop('name_or_path'))
|
||||||
|
|
||||||
if kwargs.pop('load_pretrained', True) and _load_hf_pretrained_model:
|
if kwargs.pop('load_pretrained', True) and _load_hf_pretrained_model:
|
||||||
return cls.from_pretrained(name_or_path, **kwargs)
|
model = cls.from_pretrained(name_or_path, **kwargs)
|
||||||
|
setattr(model, 'is_init', True)
|
||||||
|
return model
|
||||||
else:
|
else:
|
||||||
cfg = get_config(name_or_path, **kwargs)
|
cfg = get_config(name_or_path, **kwargs)
|
||||||
return from_config(cfg)
|
return from_config(cfg)
|
||||||
|
|
Loading…
Reference in New Issue