Updated EvoNorm implementations with some experimentation. Added FilterResponseNorm. Updated RegNetZ and ResNetV2 model defs for trials.
parent 55adfbeb8d
commit 78912b6375
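The new trial defs are registered like any other timm model, so they can be instantiated through timm.create_model. A minimal smoke-test sketch (none of the new *_evos / *_evob / *_frn variants ship pretrained weights yet, their default_cfgs have url='', so pretrained must stay False):

    import timm
    import torch

    # quick shape check of one of the new trial defs from this commit
    model = timm.create_model('regnetz_b16_evos', pretrained=False)
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])
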
@@ -35,7 +35,8 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg, named_apply
 from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
-    create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNormSample2d
+    create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\
+    EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d
 from .registry import register_model

 __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']

@@ -152,6 +153,12 @@ default_cfgs = {
     'regnetz_e8': _cfgr(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_e8_bh-aace8e6e.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=1.0),
+
+    'regnetz_b16_evos': _cfgr(
+        url='',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv',
+        crop_pct=0.94),
     'regnetz_d8_evob': _cfgr(
         url='',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),

@@ -597,6 +604,23 @@ model_cfgs = dict(
     ),

     # experimental EvoNorm configs
+    regnetz_b16_evos=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
+            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
+            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3),
+            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3),
+        ),
+        stem_chs=32,
+        stem_pool='',
+        downsample='',
+        num_features=1536,
+        act_layer='silu',
+        norm_layer=partial(EvoNorm2dS0a, group_size=16),
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+        block_kwargs=dict(bottle_in=True, linear_out=True),
+    ),
     regnetz_d8_evob=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),

@@ -610,7 +634,7 @@ model_cfgs = dict(
         downsample='',
         num_features=1792,
         act_layer='silu',
-        norm_layer='evonormbatch',
+        norm_layer='evonormb0',
         attn_layer='se',
         attn_kwargs=dict(rd_ratio=0.25),
         block_kwargs=dict(bottle_in=True, linear_out=True),

@@ -628,7 +652,7 @@ model_cfgs = dict(
         downsample='',
         num_features=1792,
         act_layer='silu',
-        norm_layer=partial(EvoNormSample2d, groups=32),
+        norm_layer=partial(EvoNorm2dS0a, group_size=16),
         attn_layer='se',
         attn_kwargs=dict(rd_ratio=0.25),
         block_kwargs=dict(bottle_in=True, linear_out=True),

@@ -856,6 +880,13 @@ def regnetz_e8(pretrained=False, **kwargs):
     return _create_byobnet('regnetz_e8', pretrained=pretrained, **kwargs)


+@register_model
+def regnetz_b16_evos(pretrained=False, **kwargs):
+    """
+    """
+    return _create_byobnet('regnetz_b16_evos', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def regnetz_d8_evob(pretrained=False, **kwargs):
     """

@@ -14,6 +14,8 @@ except ImportError:

 # Layers we want to treat as leaf modules
 from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
+from .layers import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2
+from .layers import EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
 from .layers.non_local_attn import BilinearAttnTransform
 from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame

@@ -24,9 +26,12 @@ _leaf_modules = {
     BilinearAttnTransform,  # reason: flow control t <= 1
     BlurPool2d,  # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
     # Reason: get_same_padding has a max which raises a control flow error
     Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
     CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
     DropPath,  # reason: TypeError: rand received Proxy in `size` argument
+    EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,  # to(dtype) use that causes tracing failure (on scripted models only?)
+    EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a,
 }

 try:

@@ -14,7 +14,9 @@ from .create_conv2d import create_conv2d
 from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
-from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
+    EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
+from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
 from .gather_excite import GatherExcite
 from .global_context import GlobalContext
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible

@@ -116,9 +116,6 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
         # custom autograd, then fallback
         if name in _ACT_FN_ME:
             return _ACT_FN_ME[name]
-    if is_exportable() and name in ('silu', 'swish'):
-        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
-        return swish
     if not (is_no_jit() or is_exportable()):
         if name in _ACT_FN_JIT:
             return _ACT_FN_JIT[name]

@@ -132,14 +129,12 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
     """
     if not name:
         return None
-    if isinstance(name, type):
+    if not isinstance(name, str):
+        # callable, module, etc
         return name
     if not (is_no_jit() or is_exportable() or is_scriptable()):
         if name in _ACT_LAYER_ME:
             return _ACT_LAYER_ME[name]
-    if is_exportable() and name in ('silu', 'swish'):
-        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
-        return Swish
     if not (is_no_jit() or is_exportable()):
         if name in _ACT_LAYER_JIT:
             return _ACT_LAYER_JIT[name]
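A minimal sketch of get_act_layer behaviour after this change (only the call pattern below is assumed, the names come from the diff): strings are still resolved through the registered tables, while any non-string input such as an nn.Module class is now passed through untouched:

    import torch.nn as nn
    from timm.models.layers import get_act_layer

    act_cls = get_act_layer('silu')    # string -> looked up in the _ACT_LAYER_* tables
    same_cls = get_act_layer(nn.GELU)  # already a type/callable -> returned as-is
    act = act_cls() if act_cls is not None else nn.Identity()
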
@@ -9,36 +9,42 @@ Hacked together by / Copyright 2020 Ross Wightman
 import types
 import functools

 import torch
 import torch.nn as nn

-from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .evo_norm import *
+from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
 from .norm_act import BatchNormAct2d, GroupNormAct
 from .inplace_abn import InplaceAbn

-_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn}
-_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn}  # requires act_layer arg to define act type
+_NORM_ACT_MAP = dict(
+    batchnorm=BatchNormAct2d,
+    groupnorm=GroupNormAct,
+    evonormb0=EvoNorm2dB0,
+    evonormb1=EvoNorm2dB1,
+    evonormb2=EvoNorm2dB2,
+    evonorms0=EvoNorm2dS0,
+    evonorms0a=EvoNorm2dS0a,
+    evonorms1=EvoNorm2dS1,
+    evonorms1a=EvoNorm2dS1a,
+    evonorms2=EvoNorm2dS2,
+    evonorms2a=EvoNorm2dS2a,
+    frn=FilterResponseNormAct2d,
+    frntlu=FilterResponseNormTlu2d,
+    inplaceabn=InplaceAbn,
+    iabn=InplaceAbn,
+)
+_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
+# has act_layer arg to define act type
+_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, FilterResponseNormAct2d, InplaceAbn}


-def get_norm_act_layer(layer_class):
-    layer_class = layer_class.replace('_', '').lower()
-    if layer_class.startswith("batchnorm"):
-        layer = BatchNormAct2d
-    elif layer_class.startswith("groupnorm"):
-        layer = GroupNormAct
-    elif layer_class == "evonormbatch":
-        layer = EvoNormBatch2d
-    elif layer_class == "evonormsample":
-        layer = EvoNormSample2d
-    elif layer_class == "iabn" or layer_class == "inplaceabn":
-        layer = InplaceAbn
-    else:
-        assert False, "Invalid norm_act layer (%s)" % layer_class
+def get_norm_act_layer(layer_name):
+    layer_name = layer_name.replace('_', '').lower().split('-')[0]
+    layer = _NORM_ACT_MAP.get(layer_name, None)
+    assert layer is not None, "Invalid norm_act layer (%s)" % layer_name
     return layer


-def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs):
-    layer_parts = layer_type.split('-')  # e.g. batchnorm-leaky_relu
+def create_norm_act(layer_name, num_features, apply_act=True, jit=False, **kwargs):
+    layer_parts = layer_name.split('-')  # e.g. batchnorm-leaky_relu
     assert len(layer_parts) in (1, 2)
     layer = get_norm_act_layer(layer_parts[0])
     #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else ''  # FIXME support string act selection?
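A short sketch of the new name-based lookup (the keys come from _NORM_ACT_MAP above; the exact keyword handling inside create_norm_act is an assumption):

    from timm.models.layers import get_norm_act_layer, create_norm_act

    norm_act_cls = get_norm_act_layer('evonorms0a')    # -> EvoNorm2dS0a
    frn = create_norm_act('frn', 64, apply_act=True)   # FilterResponseNormAct2d over 64 channels
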
@@ -1,81 +1,332 @@
-"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch
+""" EvoNorm in PyTorch

 Based on `Evolving Normalization-Activation Layers` - https://arxiv.org/abs/2004.02967
+@inproceedings{NEURIPS2020,
+ author = {Liu, Hanxiao and Brock, Andy and Simonyan, Karen and Le, Quoc},
+ booktitle = {Advances in Neural Information Processing Systems},
+ editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
+ pages = {13539--13550},
+ publisher = {Curran Associates, Inc.},
+ title = {Evolving Normalization-Activation Layers},
+ url = {https://proceedings.neurips.cc/paper/2020/file/9d4c03631b8b0c85ae08bf05eda37d0f-Paper.pdf},
+ volume = {33},
+ year = {2020}
+}

 An attempt at getting decent performing EvoNorms running in PyTorch.
-While currently faster than other impl, still quite a ways off the built-in BN
-in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed).
+While faster than other PyTorch impl, still quite a ways off the built-in BatchNorm
+in terms of memory usage and throughput on GPUs.

-Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts.
+I'm testing these modules on TPU w/ PyTorch XLA. Promising start but
+currently working around some issues with builtin torch/tensor.var/std. Unlike
+GPU, similar train speeds for EvoNormS variants and BatchNorm.

 Hacked together by / Copyright 2020 Ross Wightman
 """

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

+from .create_act import create_act_layer
 from .trace_utils import _assert


-class EvoNormBatch2d(nn.Module):
-    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
-        super(EvoNormBatch2d, self).__init__()
+def instance_std(x, eps: float = 1e-5):
+    rms = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype)
+    return rms.expand(x.shape)
+
+
+def instance_rms(x, eps: float = 1e-5):
+    rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(dtype=x.dtype)
+    return rms.expand(x.shape)
+
+
+def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
+    B, C, H, W = x.shape
+    x_dtype = x.dtype
+    _assert(C % groups == 0, '')
+    # x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
+    # std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt()
+    x = x.reshape(B, groups, C // groups, H, W)
+    std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt()
+    return std.expand(x.shape).reshape(B, C, H, W).to(x_dtype)
+
+
+def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False):
+    # This is a workaround for some stability / odd behaviour of .var and .std
+    # running on PyTorch XLA w/ TPUs. These manual var impl are producing much better results
+    B, C, H, W = x.shape
+    _assert(C % groups == 0, '')
+    x_dtype = x.dtype
+    x = x.float().reshape(B, groups, C // groups, H, W)
+    xm = x.mean(dim=(2, 3, 4), keepdim=True)
+    if diff_sqm:
+        # difference of squared mean and mean squared, faster on TPU
+        var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0)
+    else:
+        var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True)
+    return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W).to(x_dtype)
+# group_std = group_std_tpu  # temporary, for TPU / PT XLA
+
+
+def group_rms(x, groups: int = 32, eps: float = 1e-5):
+    B, C, H, W = x.shape
+    _assert(C % groups == 0, '')
+    x_dtype = x.dtype
+    x = x.reshape(B, groups, C // groups, H, W)
+    sqm = x.square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
+    return sqm.expand(x.shape).reshape(B, C, H, W)
+
+
+class EvoNorm2dB0(nn.Module):
+    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
+        super().__init__()
         self.apply_act = apply_act  # apply activation (non-linearity)
         self.momentum = momentum
         self.eps = eps
-        self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
-        self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
-        self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
+        self.weight = nn.Parameter(torch.ones(num_features))
+        self.bias = nn.Parameter(torch.zeros(num_features))
+        self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
         self.register_buffer('running_var', torch.ones(num_features))
         self.reset_parameters()

     def reset_parameters(self):
         nn.init.ones_(self.weight)
         nn.init.zeros_(self.bias)
-        if self.apply_act:
+        if self.v is not None:
             nn.init.ones_(self.v)

     def forward(self, x):
         _assert(x.dim() == 4, 'expected 4D input')
-        x_type = x.dtype
+        x_dtype = x.dtype
+        v_shape = (1, -1, 1, 1)
         if self.v is not None:
-            running_var = self.running_var.view(1, -1, 1, 1)
             if self.training:
-                var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
+                var = x.float().var(dim=(0, 2, 3), unbiased=False)
                 n = x.numel() / x.shape[1]
-                running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
-                self.running_var.copy_(running_var.view(self.running_var.shape))
+                self.running_var.copy_(
+                    self.running_var * (1 - self.momentum) +
+                    var.detach() * self.momentum * (n / (n - 1)))
             else:
-                var = running_var
-            v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
-            d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
-            d = d.max((var + self.eps).sqrt().to(dtype=x_type))
-            x = x / d
-        return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
+                var = self.running_var
+            left = var.add(self.eps).sqrt_().to(x_dtype).view(v_shape).expand_as(x)
+            v = self.v.to(x_dtype).view(v_shape)
+            right = x * v + instance_std(x, self.eps)
+            x = x / left.max(right)
+        return x * self.weight.to(x_dtype).view(v_shape) + self.bias.to(x_dtype).view(v_shape)


-class EvoNormSample2d(nn.Module):
-    def __init__(self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None):
-        super(EvoNormSample2d, self).__init__()
+class EvoNorm2dB1(nn.Module):
+    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
+        super().__init__()
         self.apply_act = apply_act  # apply activation (non-linearity)
-        self.groups = groups
+        self.momentum = momentum
         self.eps = eps
-        self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
-        self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
-        self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
+        self.weight = nn.Parameter(torch.ones(num_features))
+        self.bias = nn.Parameter(torch.zeros(num_features))
+        self.register_buffer('running_var', torch.ones(num_features))
         self.reset_parameters()

     def reset_parameters(self):
         nn.init.ones_(self.weight)
         nn.init.zeros_(self.bias)

     def forward(self, x):
         _assert(x.dim() == 4, 'expected 4D input')
+        x_dtype = x.dtype
+        v_shape = (1, -1, 1, 1)
+        if self.apply_act:
+            if self.training:
+                var = x.float().var(dim=(0, 2, 3), unbiased=False)
+                n = x.numel() / x.shape[1]
+                self.running_var.copy_(
+                    self.running_var * (1 - self.momentum) +
+                    var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
+            else:
+                var = self.running_var
+            var = var.to(dtype=x_dtype).view(v_shape)
+            left = var.add(self.eps).sqrt_()
+            right = (x + 1) * instance_rms(x, self.eps)
+            x = x / left.max(right)
+        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+
+
+class EvoNorm2dB2(nn.Module):
+    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
+        super().__init__()
+        self.apply_act = apply_act  # apply activation (non-linearity)
+        self.momentum = momentum
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(num_features))
+        self.bias = nn.Parameter(torch.zeros(num_features))
+        self.register_buffer('running_var', torch.ones(num_features))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.ones_(self.weight)
+        nn.init.zeros_(self.bias)
+
+    def forward(self, x):
+        _assert(x.dim() == 4, 'expected 4D input')
+        x_dtype = x.dtype
+        v_shape = (1, -1, 1, 1)
+        if self.apply_act:
+            if self.training:
+                var = x.float().var(dim=(0, 2, 3), unbiased=False)
+                n = x.numel() / x.shape[1]
+                self.running_var.copy_(
+                    self.running_var * (1 - self.momentum) +
+                    var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
+            else:
+                var = self.running_var
+            var = var.to(dtype=x_dtype).view(v_shape)
+            left = var.add(self.eps).sqrt_()
+            right = instance_rms(x, self.eps) - x
+            x = x / left.max(right)
+        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+
+
+class EvoNorm2dS0(nn.Module):
+    def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_):
+        super().__init__()
+        self.apply_act = apply_act  # apply activation (non-linearity)
+        if group_size:
+            assert num_features % group_size == 0
+            self.groups = num_features // group_size
+        else:
+            self.groups = groups
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(num_features))
+        self.bias = nn.Parameter(torch.zeros(num_features))
+        self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.ones_(self.weight)
+        nn.init.zeros_(self.bias)
+        if self.v is not None:
+            nn.init.ones_(self.v)
+
+    def forward(self, x):
+        _assert(x.dim() == 4, 'expected 4D input')
-        B, C, H, W = x.shape
-        _assert(C % self.groups == 0, '')
+        x_dtype = x.dtype
+        v_shape = (1, -1, 1, 1)
         if self.v is not None:
-            n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid()
-            x = x.reshape(B, self.groups, -1)
-            x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
-            x = x.reshape(B, C, H, W)
-        return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
+            v = self.v.view(v_shape).to(dtype=x_dtype)
+            x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
+        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+
+
+class EvoNorm2dS0a(EvoNorm2dS0):
+    def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_):
+        super().__init__(
+            num_features, groups=groups, group_size=group_size, apply_act=apply_act, eps=eps)
+
+    def forward(self, x):
+        _assert(x.dim() == 4, 'expected 4D input')
+        x_dtype = x.dtype
+        v_shape = (1, -1, 1, 1)
+        d = group_std(x, self.groups, self.eps)
+        if self.v is not None:
+            v = self.v.view(v_shape).to(dtype=x_dtype)
+            x = x * (x * v).sigmoid_()
+        x = x / d
+        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+
+
+class EvoNorm2dS1(nn.Module):
+    def __init__(
+            self, num_features, groups=32, group_size=None,
+            apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
+        super().__init__()
+        self.apply_act = apply_act  # apply activation (non-linearity)
+        if act_layer is not None and apply_act:
+            self.act = create_act_layer(act_layer)
+        else:
+            self.act = nn.Identity()
+        if group_size:
+            assert num_features % group_size == 0
+            self.groups = num_features // group_size
+        else:
+            self.groups = groups
+        self.eps = eps
+        self.pre_act_norm = False
+        self.weight = nn.Parameter(torch.ones(num_features))
+        self.bias = nn.Parameter(torch.zeros(num_features))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.ones_(self.weight)
+        nn.init.zeros_(self.bias)
+
+    def forward(self, x):
+        _assert(x.dim() == 4, 'expected 4D input')
+        x_dtype = x.dtype
+        v_shape = (1, -1, 1, 1)
+        if self.apply_act:
+            x = self.act(x) / group_std(x, self.groups, self.eps)
+        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+
+
+class EvoNorm2dS1a(EvoNorm2dS1):
+    def __init__(
+            self, num_features, groups=32, group_size=None,
+            apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
+        super().__init__(
+            num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
+
+    def forward(self, x):
+        _assert(x.dim() == 4, 'expected 4D input')
+        x_dtype = x.dtype
+        v_shape = (1, -1, 1, 1)
+        x = self.act(x) / group_std(x, self.groups, self.eps)
+        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+
+
+class EvoNorm2dS2(nn.Module):
+    def __init__(
+            self, num_features, groups=32, group_size=None,
+            apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
+        super().__init__()
+        self.apply_act = apply_act  # apply activation (non-linearity)
+        if act_layer is not None and apply_act:
+            self.act = create_act_layer(act_layer)
+        else:
+            self.act = nn.Identity()
+        if group_size:
+            assert num_features % group_size == 0
+            self.groups = num_features // group_size
+        else:
+            self.groups = groups
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(num_features))
+        self.bias = nn.Parameter(torch.zeros(num_features))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.ones_(self.weight)
+        nn.init.zeros_(self.bias)
+
+    def forward(self, x):
+        _assert(x.dim() == 4, 'expected 4D input')
+        x_dtype = x.dtype
+        v_shape = (1, -1, 1, 1)
+        if self.apply_act:
+            x = self.act(x) / group_rms(x, self.groups, self.eps)
+        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+
+
+class EvoNorm2dS2a(EvoNorm2dS2):
+    def __init__(
+            self, num_features, groups=32, group_size=None,
+            apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
+        super().__init__(
+            num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
+
+    def forward(self, x):
+        _assert(x.dim() == 4, 'expected 4D input')
+        x_dtype = x.dtype
+        v_shape = (1, -1, 1, 1)
+        x = self.act(x) / group_rms(x, self.groups, self.eps)
+        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
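As a rough usage sketch (module and argument names are from the diff above, the rest is illustrative): the S-variants fuse normalization and non-linearity into one layer, and group_size is a convenience for deriving groups from the channel count:

    import torch
    from timm.models.layers import EvoNorm2dS0a

    # EvoNorm-S0: y = weight * x * sigmoid(v * x) / group_std(x) + bias, as implemented above
    norm_act = EvoNorm2dS0a(64, group_size=16)   # groups = 64 // 16 = 4
    y = norm_act(torch.randn(2, 64, 56, 56))     # same shape out, no separate act layer needed
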
@@ -0,0 +1,68 @@
+""" Filter Response Norm in PyTorch
+
+Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+
+from .create_act import create_act_layer
+from .trace_utils import _assert
+
+
+def inv_instance_rms(x, eps: float = 1e-5):
+    rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype)
+    return rms.expand(x.shape)
+
+
+class FilterResponseNormTlu2d(nn.Module):
+    def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_):
+        super(FilterResponseNormTlu2d, self).__init__()
+        self.apply_act = apply_act  # apply activation (non-linearity)
+        self.rms = rms
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(num_features))
+        self.bias = nn.Parameter(torch.zeros(num_features))
+        self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.ones_(self.weight)
+        nn.init.zeros_(self.bias)
+        if self.tau is not None:
+            nn.init.zeros_(self.tau)
+
+    def forward(self, x):
+        _assert(x.dim() == 4, 'expected 4D input')
+        x_dtype = x.dtype
+        v_shape = (1, -1, 1, 1)
+        x = x * inv_instance_rms(x, self.eps)
+        x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x
+
+
+class FilterResponseNormAct2d(nn.Module):
+    def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_):
+        super(FilterResponseNormAct2d, self).__init__()
+        if act_layer is not None and apply_act:
+            self.act = create_act_layer(act_layer, inplace=inplace)
+        else:
+            self.act = nn.Identity()
+        self.rms = rms
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(num_features))
+        self.bias = nn.Parameter(torch.zeros(num_features))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.ones_(self.weight)
+        nn.init.zeros_(self.bias)
+
+    def forward(self, x):
+        _assert(x.dim() == 4, 'expected 4D input')
+        x_dtype = x.dtype
+        v_shape = (1, -1, 1, 1)
+        x = x * inv_instance_rms(x, self.eps)
+        x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return self.act(x)
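For reference, the layers above normalize by the per-channel RMS over the spatial dims and then apply either a learned threshold (TLU) or a regular activation; a small usage sketch (only the call pattern below is assumed):

    import torch
    from timm.models.layers import FilterResponseNormTlu2d

    # y = max(weight * x / rms(x) + bias, tau), with rms taken over H, W per channel
    frn = FilterResponseNormTlu2d(64)
    y = frn(torch.randn(2, 64, 56, 56))
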
@@ -38,7 +38,8 @@ from functools import partial
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
 from .registry import register_model
-from .layers import GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d,\
+from .layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0,\
+    EvoNorm2dS1, EvoNorm2dS2, FilterResponseNormTlu2d, FilterResponseNormAct2d,\
     ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d

@@ -125,7 +126,11 @@ default_cfgs = {
         interpolation='bicubic', first_conv='stem.conv1'),
     'resnetv2_50d_evob': _cfg(
         interpolation='bicubic', first_conv='stem.conv1'),
-    'resnetv2_50d_evos': _cfg(
+    'resnetv2_50d_evos0': _cfg(
         interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_50d_evos1': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_50d_frn': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
 }

@@ -660,13 +665,29 @@ def resnetv2_50d_gn(pretrained=False, **kwargs):
 def resnetv2_50d_evob(pretrained=False, **kwargs):
     return _create_resnetv2(
         'resnetv2_50d_evob', pretrained=pretrained,
-        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0,
         stem_type='deep', avg_down=True, zero_init_last=True, **kwargs)


+@register_model
+def resnetv2_50d_evos0(pretrained=False, **kwargs):
+    return _create_resnetv2(
+        'resnetv2_50d_evos0', pretrained=pretrained,
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0,
+        stem_type='deep', avg_down=True, **kwargs)
+
+
 @register_model
-def resnetv2_50d_evos(pretrained=False, **kwargs):
+def resnetv2_50d_evos1(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_50d_evos', pretrained=pretrained,
-        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
+        'resnetv2_50d_evos1', pretrained=pretrained,
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=partial(EvoNorm2dS1, group_size=16),
         stem_type='deep', avg_down=True, **kwargs)
+
+
+@register_model
+def resnetv2_50d_frn(pretrained=False, **kwargs):
+    return _create_resnetv2(
+        'resnetv2_50d_frn', pretrained=pretrained,
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d,
+        stem_type='deep', avg_down=True, **kwargs)
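The trial ResNetV2 defs above share the same 50d backbone and differ only in the norm_layer plugged in, so they can be compared with an identical train/eval loop; a hypothetical comparison sketch:

    import timm

    for name in ('resnetv2_50d_evob', 'resnetv2_50d_evos0', 'resnetv2_50d_evos1', 'resnetv2_50d_frn'):
        m = timm.create_model(name, pretrained=False)
        print(name, sum(p.numel() for p in m.parameters()))
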
@@ -395,7 +395,7 @@ def eca_vovnet39b(pretrained=False, **kwargs):
 @register_model
 def ese_vovnet39b_evos(pretrained=False, **kwargs):
     def norm_act_fn(num_features, **nkwargs):
-        return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs)
+        return create_norm_act('evonorms0', num_features, jit=False, **nkwargs)
     return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)