add mplugowl model
parent fc2754f583
commit 16ac01aa88
@@ -562,9 +562,9 @@ class MplugOwlVisualAbstractorAttention(BaseModel):
 
 
 class MplugOwlVisualAbstractorLayer(BaseModel):
-    def __init__(self,layer_idx, hidden_size=1024,num_attention_heads=16,intermediate_size=4096,attention_probs_dropout_prob=0.1,layer_norm_eps=1e-6,encoder_hidden_size=1024):
+    def __init__(self,layer_idx, hidden_size=1024,num_attention_heads=16,intermediate_size=4096,attention_probs_dropout_prob=0.1,layer_norm_eps=1e-6,encoder_hidden_size=1024,chunk_size_feed_forward=None):
         super().__init__()
-        self.chunk_size_feed_forward = None
+        self.chunk_size_feed_forward = chunk_size_feed_forward
         self.seq_len_dim = 1
 
         self.layer_idx = layer_idx
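The new chunk_size_feed_forward argument follows the convention of Hugging Face transformers' apply_chunking_to_forward helper, which layers that carry chunk_size_feed_forward and seq_len_dim typically call in their forward pass (the forward itself is outside this hunk, and the default of None suggests the caller guards on it or substitutes 0). A minimal, self-contained sketch of that chunking pattern, with the toy FFN and shapes chosen only for illustration:

import torch
from torch import nn
from transformers.pytorch_utils import apply_chunking_to_forward

# Toy feed-forward block standing in for the abstractor layer's FFN (illustrative only).
ffn = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

def feed_forward_chunk(hidden_states):
    return ffn(hidden_states)

hidden_states = torch.randn(2, 256, 1024)   # (batch, seq_len, hidden_size)
chunk_size_feed_forward = 64                # positive: split the work along seq_len_dim
seq_len_dim = 1

# Runs feed_forward_chunk on 64-token slices and concatenates the results along
# dim 1, trading a little speed for lower peak activation memory; with chunk
# size 0 the helper calls the function once on the full tensor.
out = apply_chunking_to_forward(feed_forward_chunk, chunk_size_feed_forward, seq_len_dim, hidden_states)
assert out.shape == hidden_states.shape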
@@ -661,12 +661,11 @@ class MplugOwlVisualAbstractorEncoder(BaseModel):
 
 
 class MplugOwlVisualAbstractorModel(BaseModel):
-    def __init__(self, config: MplugOwlVisualAbstractorConfig, language_hidden_size):
-        super().__init__(config)
-        self.config = config
+    def __init__(self, language_hidden_size, num_hidden_layers=6, hidden_size=1024,num_attention_heads=16,intermediate_size=4096,attention_probs_dropout_prob=0.1,layer_norm_eps=1e-6,encoder_hidden_size=1024):
+        super().__init__()
 
-        self.encoder = MplugOwlVisualAbstractorEncoder(config)
-        self.visual_fc = torch.nn.Linear(config.hidden_size, language_hidden_size)
+        self.encoder = MplugOwlVisualAbstractorEncoder(num_hidden_layers, hidden_size,num_attention_heads,intermediate_size,attention_probs_dropout_prob,layer_norm_eps,encoder_hidden_size)
+        self.visual_fc = torch.nn.Linear(hidden_size, language_hidden_size)
         self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
         nn.init.trunc_normal_(self.vit_eos, mean=0.0, std=self.config.initializer_range)
 
         self.post_init()
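With this hunk MplugOwlVisualAbstractorModel no longer takes a MplugOwlVisualAbstractorConfig; its hyperparameters are passed directly. A rough sketch of the new construction, using the defaults from the signature and an illustrative language_hidden_size (note that the unchanged trunc_normal_ line still reads self.config.initializer_range, so a config attribute, or a literal std, is still needed for this __init__ to run as shown):

# Hypothetical usage from within this module; language_hidden_size=4096 is only
# an illustrative choice for the language model's width.
abstractor = MplugOwlVisualAbstractorModel(
    language_hidden_size=4096,
    num_hidden_layers=6,
    hidden_size=1024,
    num_attention_heads=16,
    intermediate_size=4096,
    attention_probs_dropout_prob=0.1,
    layer_norm_eps=1e-6,
    encoder_hidden_size=1024,
)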
@@ -824,8 +823,8 @@ class MplugOwlVisualAbstractorModel(BaseModel):
         )
 
 
-class MplugOwlModel(MplugOwlPreTrainedModel):
-    config_class = MplugOwlConfig
+@MODELS.register_module()
+class MplugOwlModel(BaseModel):
     main_input_name = "pixel_values"
 
     def __init__(self, config: MplugOwlConfig, *inputs, **kwargs):
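Swapping the base class to BaseModel and adding @MODELS.register_module() lets MplugOwlModel be built from a plain config dict, MMEngine-style. A sketch under the assumption that MODELS is an MMEngine registry (the registry import path and the MplugOwlConfig construction are placeholders; the config class is referenced elsewhere in this file):

from mmengine.registry import MODELS  # or the project's own MODELS registry

cfg = dict(
    type='MplugOwlModel',
    config=MplugOwlConfig(),  # remaining keys are forwarded to MplugOwlModel.__init__
)
model = MODELS.build(cfg)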
@@ -1080,9 +1079,4 @@ def bloom_forward(
         attentions=all_self_attentions,
     )
 
 
-@MODELS.register_module()
-class mPLUGOwl(BaseModel):
-    def __init__(self,):
-        pass
-