add mplugowl model

pull/1775/head
qingtian5 2023-08-28 00:51:11 +08:00
parent fc2754f583
commit 16ac01aa88
1 changed file with 8 additions and 14 deletions
mmpretrain/models/multimodal/mplugowl

@@ -562,9 +562,9 @@ class MplugOwlVisualAbstractorAttention(BaseModel):
 class MplugOwlVisualAbstractorLayer(BaseModel):
-    def __init__(self, layer_idx, hidden_size=1024, num_attention_heads=16, intermediate_size=4096, attention_probs_dropout_prob=0.1, layer_norm_eps=1e-6, encoder_hidden_size=1024):
+    def __init__(self, layer_idx, hidden_size=1024, num_attention_heads=16, intermediate_size=4096, attention_probs_dropout_prob=0.1, layer_norm_eps=1e-6, encoder_hidden_size=1024, chunk_size_feed_forward=None):
         super().__init__()
-        self.chunk_size_feed_forward = None
+        self.chunk_size_feed_forward = chunk_size_feed_forward
         self.seq_len_dim = 1
         self.layer_idx = layer_idx
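
This hunk threads chunk_size_feed_forward through the constructor instead of hard-coding it to None. A minimal call-site sketch under the new signature; the import path is assumed from the file path shown in this commit, and the values simply mirror the defaults:

from mmpretrain.models.multimodal.mplugowl import MplugOwlVisualAbstractorLayer

# chunk_size_feed_forward is now configurable; None (the default) keeps
# the old behavior of running the feed-forward over the full sequence.
layer = MplugOwlVisualAbstractorLayer(
    layer_idx=0,
    hidden_size=1024,
    num_attention_heads=16,
    chunk_size_feed_forward=None,  # or an int to chunk the feed-forward pass
)
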
@@ -661,12 +661,11 @@ class MplugOwlVisualAbstractorEncoder(BaseModel):
 class MplugOwlVisualAbstractorModel(BaseModel):
-    def __init__(self, config: MplugOwlVisualAbstractorConfig, language_hidden_size):
-        super().__init__(config)
-        self.config = config
+    def __init__(self, language_hidden_size, num_hidden_layers=6, hidden_size=1024, num_attention_heads=16, intermediate_size=4096, attention_probs_dropout_prob=0.1, layer_norm_eps=1e-6, encoder_hidden_size=1024):
+        super().__init__()
-        self.encoder = MplugOwlVisualAbstractorEncoder(config)
-        self.visual_fc = torch.nn.Linear(config.hidden_size, language_hidden_size)
+        self.encoder = MplugOwlVisualAbstractorEncoder(num_hidden_layers, hidden_size, num_attention_heads, intermediate_size, attention_probs_dropout_prob, layer_norm_eps, encoder_hidden_size)
+        self.visual_fc = torch.nn.Linear(hidden_size, language_hidden_size)
         self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
         nn.init.trunc_normal_(self.vit_eos, mean=0.0, std=self.config.initializer_range)
         self.post_init()
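
The abstractor no longer takes a MplugOwlVisualAbstractorConfig; its hyperparameters become plain keyword arguments. Note that the unchanged context line above still reads std=self.config.initializer_range, which no longer resolves once self.config is removed. A minimal construction sketch under the new signature; the import path and the language hidden size are assumptions, not confirmed by this diff:

from mmpretrain.models.multimodal.mplugowl import MplugOwlVisualAbstractorModel

# All hyperparameters are now explicit kwargs; only language_hidden_size
# is required. 4096 is a hypothetical value (e.g. a LLaMA-7B hidden size);
# the remaining values mirror the defaults in the new signature.
abstractor = MplugOwlVisualAbstractorModel(
    language_hidden_size=4096,
    num_hidden_layers=6,
    hidden_size=1024,
    encoder_hidden_size=1024,
)
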
@@ -824,8 +823,8 @@ class MplugOwlVisualAbstractorModel(BaseModel):
     )
-class MplugOwlModel(MplugOwlPreTrainedModel):
-    config_class = MplugOwlConfig
+@MODELS.register_module()
+class MplugOwlModel(BaseModel):
     main_input_name = "pixel_values"
     def __init__(self, config: MplugOwlConfig, *inputs, **kwargs):
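
With @MODELS.register_module() on MplugOwlModel, replacing the transformers-style MplugOwlPreTrainedModel base and config_class, the model becomes buildable through mmpretrain's registry. A sketch of that usage; the cfg keys are illustrative, since this diff does not show the full constructor signature:

from mmpretrain.registry import MODELS

# Registration makes the class discoverable by name, so a standard
# mmpretrain config dict can instantiate it. Any required constructor
# arguments (not fully shown in this diff) go in the same dict.
cfg = dict(type='MplugOwlModel')  # plus model-specific kwargs
model = MODELS.build(cfg)
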
@@ -1080,9 +1079,4 @@ def bloom_forward(
         attentions=all_self_attentions,
     )
-
-@MODELS.register_module()
-class mPLUGOwl(BaseModel):
-    def __init__(self,):
-        pass