pytorch-image-models/timm/models/layers/evo_norm.py

82 lines
3.4 KiB
Python
Raw Normal View History

"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch
An attempt at getting decent performing EvoNorms running in PyTorch.
While currently faster than other impl, still quite a ways off the built-in BN
in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed).
Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
2021-11-13 04:42:45 +08:00
from .trace_utils import _assert
class EvoNormBatch2d(nn.Module):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
super(EvoNormBatch2d, self).__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.momentum = momentum
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.apply_act:
nn.init.ones_(self.v)
def forward(self, x):
2021-11-25 01:24:47 +08:00
_assert(x.dim() == 4, 'expected 4D input')
x_type = x.dtype
if self.v is not None:
2021-11-25 01:24:47 +08:00
running_var = self.running_var.view(1, -1, 1, 1)
if self.training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
n = x.numel() / x.shape[1]
running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
self.running_var.copy_(running_var.view(self.running_var.shape))
else:
var = running_var
v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
d = d.max((var + self.eps).sqrt().to(dtype=x_type))
x = x / d
return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
class EvoNormSample2d(nn.Module):
def __init__(self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None):
super(EvoNormSample2d, self).__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.groups = groups
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.apply_act:
nn.init.ones_(self.v)
def forward(self, x):
2021-11-13 04:42:45 +08:00
_assert(x.dim() == 4, 'expected 4D input')
B, C, H, W = x.shape
2021-11-13 04:42:45 +08:00
_assert(C % self.groups == 0, '')
if self.v is not None:
n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid()
x = x.reshape(B, self.groups, -1)
x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
x = x.reshape(B, C, H, W)
return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)