mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update metaformers.py
This commit is contained in:
parent
d17bcb10a5
commit
01f671ed08
@ -347,18 +347,16 @@ class LayerNormGeneral(nn.Module):
|
|||||||
self.normalized_dim = normalized_dim
|
self.normalized_dim = normalized_dim
|
||||||
self.use_scale = scale
|
self.use_scale = scale
|
||||||
self.use_bias = bias
|
self.use_bias = bias
|
||||||
self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
|
self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else torch.ones(affine_shape)
|
||||||
self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
|
self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else torch.zeros(affine_shape)
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
c = x - x.mean(self.normalized_dim, keepdim=True)
|
c = x - x.mean(self.normalized_dim, keepdim=True)
|
||||||
s = c.pow(2).mean(self.normalized_dim, keepdim=True)
|
s = c.pow(2).mean(self.normalized_dim, keepdim=True)
|
||||||
x = c / torch.sqrt(s + self.eps)
|
x = c / torch.sqrt(s + self.eps)
|
||||||
if self.use_scale:
|
x = x * self.weight
|
||||||
x = x * self.weight
|
x = x + self.bias
|
||||||
if self.use_bias:
|
|
||||||
x = x + self.bias
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user