add mplugowl model
parent
fc2754f583
commit
16ac01aa88
|
@ -562,9 +562,9 @@ class MplugOwlVisualAbstractorAttention(BaseModel):
|
|||
|
||||
|
||||
class MplugOwlVisualAbstractorLayer(BaseModel):
|
||||
def __init__(self,layer_idx, hidden_size=1024,num_attention_heads=16,intermediate_size=4096,attention_probs_dropout_prob=0.1,layer_norm_eps=1e-6,encoder_hidden_size=1024):
|
||||
def __init__(self,layer_idx, hidden_size=1024,num_attention_heads=16,intermediate_size=4096,attention_probs_dropout_prob=0.1,layer_norm_eps=1e-6,encoder_hidden_size=1024,chunk_size_feed_forward=None):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = None
|
||||
self.chunk_size_feed_forward = chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
@ -661,12 +661,11 @@ class MplugOwlVisualAbstractorEncoder(BaseModel):
|
|||
|
||||
|
||||
class MplugOwlVisualAbstractorModel(BaseModel):
|
||||
def __init__(self, config: MplugOwlVisualAbstractorConfig, language_hidden_size):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
def __init__(self, language_hidden_size, num_hidden_layers=6, hidden_size=1024,num_attention_heads=16,intermediate_size=4096,attention_probs_dropout_prob=0.1,layer_norm_eps=1e-6,encoder_hidden_size=1024):
|
||||
super().__init__()
|
||||
|
||||
self.encoder = MplugOwlVisualAbstractorEncoder(config)
|
||||
self.visual_fc = torch.nn.Linear(config.hidden_size, language_hidden_size)
|
||||
self.encoder = MplugOwlVisualAbstractorEncoder(num_hidden_layers, hidden_size,num_attention_heads,intermediate_size,attention_probs_dropout_prob,layer_norm_eps,encoder_hidden_size)
|
||||
self.visual_fc = torch.nn.Linear(hidden_size, language_hidden_size)
|
||||
self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
|
||||
nn.init.trunc_normal_(self.vit_eos, mean=0.0, std=self.config.initializer_range)
|
||||
self.post_init()
|
||||
|
@ -824,8 +823,8 @@ class MplugOwlVisualAbstractorModel(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class MplugOwlModel(MplugOwlPreTrainedModel):
|
||||
config_class = MplugOwlConfig
|
||||
@MODELS.register_module()
|
||||
class MplugOwlModel(BaseModel):
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(self, config: MplugOwlConfig, *inputs, **kwargs):
|
||||
|
@ -1080,9 +1079,4 @@ def bloom_forward(
|
|||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
@MODELS.register_module()
|
||||
class mPLUGOwl(BaseModel):
|
||||
def __init__(self,):
|
||||
pass
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue