add mplugowl model

pull/1775/head
qingtian5 2023-08-28 00:51:11 +08:00
parent fc2754f583
commit 16ac01aa88
1 changed file with 8 additions and 14 deletions

View File

@@ -562,9 +562,9 @@ class MplugOwlVisualAbstractorAttention(BaseModel):
class MplugOwlVisualAbstractorLayer(BaseModel): class MplugOwlVisualAbstractorLayer(BaseModel):
def __init__(self,layer_idx, hidden_size=1024,num_attention_heads=16,intermediate_size=4096,attention_probs_dropout_prob=0.1,layer_norm_eps=1e-6,encoder_hidden_size=1024): def __init__(self,layer_idx, hidden_size=1024,num_attention_heads=16,intermediate_size=4096,attention_probs_dropout_prob=0.1,layer_norm_eps=1e-6,encoder_hidden_size=1024,chunk_size_feed_forward=None):
super().__init__() super().__init__()
self.chunk_size_feed_forward = None self.chunk_size_feed_forward = chunk_size_feed_forward
self.seq_len_dim = 1 self.seq_len_dim = 1
self.layer_idx = layer_idx self.layer_idx = layer_idx
@@ -661,12 +661,11 @@ class MplugOwlVisualAbstractorEncoder(BaseModel):
class MplugOwlVisualAbstractorModel(BaseModel): class MplugOwlVisualAbstractorModel(BaseModel):
def __init__(self, config: MplugOwlVisualAbstractorConfig, language_hidden_size): def __init__(self, language_hidden_size, num_hidden_layers=6, hidden_size=1024,num_attention_heads=16,intermediate_size=4096,attention_probs_dropout_prob=0.1,layer_norm_eps=1e-6,encoder_hidden_size=1024):
super().__init__(config) super().__init__()
self.config = config
self.encoder = MplugOwlVisualAbstractorEncoder(config) self.encoder = MplugOwlVisualAbstractorEncoder(num_hidden_layers, hidden_size,num_attention_heads,intermediate_size,attention_probs_dropout_prob,layer_norm_eps,encoder_hidden_size)
self.visual_fc = torch.nn.Linear(config.hidden_size, language_hidden_size) self.visual_fc = torch.nn.Linear(hidden_size, language_hidden_size)
self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size)) self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
nn.init.trunc_normal_(self.vit_eos, mean=0.0, std=self.config.initializer_range) nn.init.trunc_normal_(self.vit_eos, mean=0.0, std=self.config.initializer_range)
self.post_init() self.post_init()
@@ -824,8 +823,8 @@ class MplugOwlVisualAbstractorModel(BaseModel):
) )
class MplugOwlModel(MplugOwlPreTrainedModel): @MODELS.register_module()
config_class = MplugOwlConfig class MplugOwlModel(BaseModel):
main_input_name = "pixel_values" main_input_name = "pixel_values"
def __init__(self, config: MplugOwlConfig, *inputs, **kwargs): def __init__(self, config: MplugOwlConfig, *inputs, **kwargs):
@@ -1080,9 +1079,4 @@ def bloom_forward(
attentions=all_self_attentions, attentions=all_self_attentions,
) )
@MODELS.register_module()
class mPLUGOwl(BaseModel):
def __init__(self,):
pass