mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix FX breaking assert in evonorm
This commit is contained in:
parent
f83b0b01e3
commit
480c676ffa
@ -34,8 +34,9 @@ class EvoNormBatch2d(nn.Module):
|
||||
nn.init.ones_(self.v)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.dim() == 4, 'expected 4D input'
|
||||
_assert(x.dim() == 4, 'expected 4D input')
|
||||
x_type = x.dtype
|
||||
if self.v is not None:
|
||||
running_var = self.running_var.view(1, -1, 1, 1)
|
||||
if self.training:
|
||||
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
|
||||
@ -44,8 +45,6 @@ class EvoNormBatch2d(nn.Module):
|
||||
self.running_var.copy_(running_var.view(self.running_var.shape))
|
||||
else:
|
||||
var = running_var
|
||||
|
||||
if self.v is not None:
|
||||
v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
|
||||
d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
|
||||
d = d.max((var + self.eps).sqrt().to(dtype=x_type))
|
||||
|
Loading…
x
Reference in New Issue
Block a user