mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix spacing misalignment for fast norm path in LayerNorm modules
This commit is contained in:
parent
475ecdfa3d
commit
803254bb40
@ -50,7 +50,7 @@ class LayerNorm(nn.LayerNorm):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self._fast_norm:
|
||||
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
else:
|
||||
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
return x
|
||||
@ -65,7 +65,7 @@ class LayerNorm2d(nn.LayerNorm):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
if self._fast_norm:
|
||||
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
else:
|
||||
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user