Fix up a few details in NFResNet models, managed stable training. Add support for gamma gain to be applied in activation or ScaleStdConv. Some tweaks to ScaledStdConv.
parent
5a8e1e643e
commit
90980de4a9
|
@ -1,7 +1,6 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from .padding import get_padding
|
||||
from .conv2d_same import conv2d_same
|
||||
|
@ -69,20 +68,24 @@ class ScaledStdConv2d(nn.Conv2d):
|
|||
https://arxiv.org/abs/2101.08692
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size,
|
||||
stride=1, padding=None, dilation=1, groups=1, bias=True, gain=True, gamma=1.0, eps=1e-5):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
|
||||
bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
|
||||
if padding is None:
|
||||
padding = get_padding(kernel_size, stride, dilation)
|
||||
super().__init__(
|
||||
in_channels, out_channels, kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, groups=groups, bias=bias)
|
||||
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
|
||||
self.gamma = gamma * self.weight[0].numel() ** 0.5 # gamma * sqrt(fan-in)
|
||||
self.eps = eps
|
||||
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
|
||||
self.eps = eps ** 2 if use_layernorm else eps
|
||||
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory use
|
||||
|
||||
def get_weight(self):
|
||||
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
weight = (self.weight - mean) / (self.gamma * std + self.eps)
|
||||
if self.use_layernorm:
|
||||
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
|
||||
else:
|
||||
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
weight = self.scale * (self.weight - mean) / (std + self.eps)
|
||||
if self.gain is not None:
|
||||
weight = weight * self.gain
|
||||
return weight
|
||||
|
|
|
@ -18,7 +18,7 @@ import torch.nn.functional as F
|
|||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .registry import register_model
|
||||
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible
|
||||
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible, get_act_fn
|
||||
|
||||
|
||||
def _dcfg(url='', **kwargs):
|
||||
|
@ -40,17 +40,17 @@ default_cfgs = {
|
|||
'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320)),
|
||||
'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384)),
|
||||
|
||||
'nf_resnet26d': _dcfg(url='', first_conv='stem.conv1'),
|
||||
'nf_resnet50d': _dcfg(url='', first_conv='stem.conv1'),
|
||||
'nf_resnet101d': _dcfg(url='', first_conv='stem.conv1'),
|
||||
'nf_resnet26': _dcfg(url='', first_conv='stem.conv'),
|
||||
'nf_resnet50': _dcfg(url='', first_conv='stem.conv'),
|
||||
'nf_resnet101': _dcfg(url='', first_conv='stem.conv'),
|
||||
|
||||
'nf_seresnet26d': _dcfg(url='', first_conv='stem.conv1'),
|
||||
'nf_seresnet50d': _dcfg(url='', first_conv='stem.conv1'),
|
||||
'nf_seresnet101d': _dcfg(url='', first_conv='stem.conv1'),
|
||||
'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'),
|
||||
'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'),
|
||||
'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'),
|
||||
|
||||
'nf_ecaresnet26d': _dcfg(url='', first_conv='stem.conv1'),
|
||||
'nf_ecaresnet50d': _dcfg(url='', first_conv='stem.conv1'),
|
||||
'nf_ecaresnet101d': _dcfg(url='', first_conv='stem.conv1'),
|
||||
'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'),
|
||||
'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'),
|
||||
'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'),
|
||||
}
|
||||
|
||||
|
||||
|
@ -59,6 +59,7 @@ class NfCfg:
|
|||
depths: Tuple[int, int, int, int]
|
||||
channels: Tuple[int, int, int, int]
|
||||
alpha: float = 0.2
|
||||
gamma_in_act: bool = False
|
||||
stem_type: str = '3x3'
|
||||
stem_chs: Optional[int] = None
|
||||
group_size: Optional[int] = 8
|
||||
|
@ -84,68 +85,65 @@ model_cfgs = dict(
|
|||
nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048),
|
||||
|
||||
# ResNet (preact, D style deep stem/avg down) defs
|
||||
nf_resnet26d=NfCfg(
|
||||
nf_resnet26=NfCfg(
|
||||
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
|
||||
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
act_layer='relu', attn_layer=None,),
|
||||
nf_resnet50d=NfCfg(
|
||||
nf_resnet50=NfCfg(
|
||||
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
|
||||
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
act_layer='relu', attn_layer=None),
|
||||
nf_resnet101d=NfCfg(
|
||||
nf_resnet101=NfCfg(
|
||||
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
|
||||
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
act_layer='relu', attn_layer=None),
|
||||
|
||||
|
||||
nf_seresnet26d=NfCfg(
|
||||
nf_seresnet26=NfCfg(
|
||||
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
|
||||
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
|
||||
nf_seresnet50d=NfCfg(
|
||||
nf_seresnet50=NfCfg(
|
||||
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
|
||||
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
|
||||
nf_seresnet101d=NfCfg(
|
||||
nf_seresnet101=NfCfg(
|
||||
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
|
||||
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
|
||||
|
||||
|
||||
nf_ecaresnet26d=NfCfg(
|
||||
nf_ecaresnet26=NfCfg(
|
||||
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
|
||||
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
|
||||
nf_ecaresnet50d=NfCfg(
|
||||
nf_ecaresnet50=NfCfg(
|
||||
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
|
||||
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
|
||||
nf_ecaresnet101d=NfCfg(
|
||||
nf_ecaresnet101=NfCfg(
|
||||
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
|
||||
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
|
||||
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
|
||||
|
||||
)
|
||||
|
||||
# class NormFreeSiLU(nn.Module):
|
||||
# _K = 1. / 0.5595
|
||||
# def __init__(self, inplace=False):
|
||||
# super().__init__()
|
||||
# self.inplace = inplace
|
||||
#
|
||||
# def forward(self, x):
|
||||
# return F.silu(x, inplace=self.inplace) * self._K
|
||||
#
|
||||
#
|
||||
# class NormFreeReLU(nn.Module):
|
||||
# _K = (0.5 * (1. - 1. / math.pi)) ** -0.5
|
||||
#
|
||||
# def __init__(self, inplace=False):
|
||||
# super().__init__()
|
||||
# self.inplace = inplace
|
||||
#
|
||||
# def forward(self, x):
|
||||
# return F.relu(x, inplace=self.inplace) * self._K
|
||||
|
||||
class GammaAct(nn.Module):
|
||||
def __init__(self, act_type='relu', gamma: float = 1.0, inplace=False):
|
||||
super().__init__()
|
||||
self.act_fn = get_act_fn(act_type)
|
||||
self.gamma = gamma
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return self.gamma * self.act_fn(x, inplace=self.inplace)
|
||||
|
||||
|
||||
def act_with_gamma(act_type, gamma: float = 1.):
|
||||
def _create(inplace=False):
|
||||
return GammaAct(act_type, gamma=gamma, inplace=inplace)
|
||||
return _create
|
||||
|
||||
|
||||
class DownsampleAvg(nn.Module):
|
||||
|
@ -178,10 +176,9 @@ class NormalizationFreeBlock(nn.Module):
|
|||
out_chs = out_chs or in_chs
|
||||
# EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
|
||||
mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div)
|
||||
groups = 1
|
||||
if group_size is not None:
|
||||
# NOTE: not correcting the mid_chs % group_size, fix model def if broken. I want % ch_div == 0 to stand.
|
||||
groups = mid_chs // group_size
|
||||
groups = 1 if group_size is None else mid_chs // group_size
|
||||
if group_size and group_size % ch_div == 0:
|
||||
mid_chs = group_size * groups # correct mid_chs if group_size divisible by ch_div, otherwise error
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.attn_gain = attn_gain
|
||||
|
@ -229,10 +226,11 @@ class NormalizationFreeBlock(nn.Module):
|
|||
|
||||
|
||||
def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
|
||||
stem_stride = 2
|
||||
stem = OrderedDict()
|
||||
assert stem_type in ('', 'deep', '3x3', '7x7')
|
||||
assert stem_type in ('', 'deep', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
|
||||
if 'deep' in stem_type:
|
||||
# 3 deep 3x3 conv stack as in ResNet V1D models
|
||||
# 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
|
||||
mid_chs = out_chs // 2
|
||||
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
|
||||
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
|
||||
|
@ -244,12 +242,16 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
|
|||
# 7x7 stem conv as in ResNet
|
||||
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
|
||||
|
||||
return nn.Sequential(stem)
|
||||
if 'pool' in stem_type:
|
||||
stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
stem_stride = 4
|
||||
|
||||
return nn.Sequential(stem), stem_stride
|
||||
|
||||
|
||||
_nonlin_gamma = dict(
|
||||
silu=.5595,
|
||||
relu=(0.5 * (1. - 1. / math.pi)) ** 0.5,
|
||||
silu=1./.5595,
|
||||
relu=(0.5 * (1. - 1. / math.pi)) ** -0.5,
|
||||
identity=1.0
|
||||
)
|
||||
|
||||
|
@ -264,9 +266,12 @@ class NormalizerFreeNet(nn.Module):
|
|||
the (preact) ResNet models described earlier in the paper.
|
||||
|
||||
There are a few differences:
|
||||
* channels are rounded to be divisible by 8 by default (keep TC happy), this changes param counts
|
||||
* channels are rounded to be divisible by 8 by default (keep tensor core kernels happy),
|
||||
this changes channel dim and param counts slightly from the paper models
|
||||
* activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
|
||||
impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
|
||||
* a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
|
||||
apply it in each activation. This is slightly slower, and yields slightly different results.
|
||||
* skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
|
||||
for what it is/does. Approx 8-10% throughput loss.
|
||||
"""
|
||||
|
@ -275,29 +280,33 @@ class NormalizerFreeNet(nn.Module):
|
|||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
act_layer = get_act_layer(cfg.act_layer)
|
||||
assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
|
||||
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
|
||||
if cfg.gamma_in_act:
|
||||
act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
|
||||
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True)
|
||||
else:
|
||||
act_layer = get_act_layer(cfg.act_layer)
|
||||
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
|
||||
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
|
||||
|
||||
self.feature_info = [] # FIXME fill out feature info
|
||||
|
||||
stem_chs = cfg.stem_chs or cfg.channels[0]
|
||||
stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
|
||||
self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
|
||||
self.stem, stem_stride = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
|
||||
|
||||
prev_chs = stem_chs
|
||||
self.feature_info = [] # NOTE: there will be no stride == 2 feature if stem_stride == 4
|
||||
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
|
||||
net_stride = 2
|
||||
prev_chs = stem_chs
|
||||
net_stride = stem_stride
|
||||
dilation = 1
|
||||
expected_var = 1.0
|
||||
stages = []
|
||||
for stage_idx, stage_depth in enumerate(cfg.depths):
|
||||
if net_stride >= output_stride:
|
||||
dilation *= 2
|
||||
stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
|
||||
self.feature_info += [dict(
|
||||
num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1' if stride == 2 else '')]
|
||||
if net_stride >= output_stride and stride > 1:
|
||||
dilation *= stride
|
||||
stride = 1
|
||||
else:
|
||||
stride = 2
|
||||
net_stride *= stride
|
||||
first_dilation = 1 if dilation in (1, 2) else 2
|
||||
|
||||
|
@ -338,7 +347,10 @@ class NormalizerFreeNet(nn.Module):
|
|||
else:
|
||||
self.num_features = prev_chs
|
||||
self.final_conv = nn.Identity()
|
||||
# FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv
|
||||
self.final_act = act_layer()
|
||||
self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]
|
||||
|
||||
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
||||
|
||||
for n, m in self.named_modules():
|
||||
|
@ -373,11 +385,14 @@ class NormalizerFreeNet(nn.Module):
|
|||
|
||||
|
||||
def _create_normfreenet(variant, pretrained=False, **kwargs):
|
||||
model_cfg = model_cfgs[variant]
|
||||
feature_cfg = dict(flatten_sequential=True)
|
||||
feature_cfg['feature_cls'] = 'hook' # pre-act models need hooks to grab feat from act1 in bottleneck blocks
|
||||
if 'pool' in model_cfg.stem_type:
|
||||
feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet
|
||||
|
||||
return build_model_with_cfg(
|
||||
NormalizerFreeNet, variant, pretrained, model_cfg=model_cfgs[variant], default_cfg=default_cfgs[variant],
|
||||
NormalizerFreeNet, variant, pretrained, model_cfg=model_cfg, default_cfg=default_cfgs[variant],
|
||||
feature_cfg=feature_cfg, **kwargs)
|
||||
|
||||
|
||||
|
@ -412,30 +427,30 @@ def nf_regnet_b5(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
@register_model
|
||||
def nf_resnet26d(pretrained=False, **kwargs):
|
||||
return _create_normfreenet('nf_resnet26d', pretrained=pretrained, **kwargs)
|
||||
def nf_resnet26(pretrained=False, **kwargs):
|
||||
return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def nf_resnet50d(pretrained=False, **kwargs):
|
||||
return _create_normfreenet('nf_resnet50d', pretrained=pretrained, **kwargs)
|
||||
def nf_resnet50(pretrained=False, **kwargs):
|
||||
return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def nf_seresnet26d(pretrained=False, **kwargs):
|
||||
return _create_normfreenet('nf_seresnet26d', pretrained=pretrained, **kwargs)
|
||||
def nf_seresnet26(pretrained=False, **kwargs):
|
||||
return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def nf_seresnet50d(pretrained=False, **kwargs):
|
||||
return _create_normfreenet('nf_seresnet50d', pretrained=pretrained, **kwargs)
|
||||
def nf_seresnet50(pretrained=False, **kwargs):
|
||||
return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def nf_ecaresnet26d(pretrained=False, **kwargs):
|
||||
return _create_normfreenet('nf_ecaresnet26d', pretrained=pretrained, **kwargs)
|
||||
def nf_ecaresnet26(pretrained=False, **kwargs):
|
||||
return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def nf_ecaresnet50d(pretrained=False, **kwargs):
|
||||
return _create_normfreenet('nf_ecaresnet50d', pretrained=pretrained, **kwargs)
|
||||
def nf_ecaresnet50(pretrained=False, **kwargs):
|
||||
return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue