Fix FX breaking assert in evonorm
parent
f83b0b01e3
commit
480c676ffa
|
@ -34,18 +34,17 @@ class EvoNormBatch2d(nn.Module):
|
|||
nn.init.ones_(self.v)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.dim() == 4, 'expected 4D input'
|
||||
_assert(x.dim() == 4, 'expected 4D input')
|
||||
x_type = x.dtype
|
||||
running_var = self.running_var.view(1, -1, 1, 1)
|
||||
if self.training:
|
||||
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
|
||||
n = x.numel() / x.shape[1]
|
||||
running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
|
||||
self.running_var.copy_(running_var.view(self.running_var.shape))
|
||||
else:
|
||||
var = running_var
|
||||
|
||||
if self.v is not None:
|
||||
running_var = self.running_var.view(1, -1, 1, 1)
|
||||
if self.training:
|
||||
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
|
||||
n = x.numel() / x.shape[1]
|
||||
running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
|
||||
self.running_var.copy_(running_var.view(self.running_var.shape))
|
||||
else:
|
||||
var = running_var
|
||||
v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
|
||||
d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
|
||||
d = d.max((var + self.eps).sqrt().to(dtype=x_type))
|
||||
|
|
Loading…
Reference in New Issue