Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00.
Fix up a few details in NFResNet models, managed stable training. Add support for gamma gain to be applied in activation or ScaledStdConv. Some tweaks to ScaledStdConv.
This commit is contained in: parent 5a8e1e643e, commit 90980de4a9.
@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np

 from .padding import get_padding
 from .conv2d_same import conv2d_same
@@ -69,20 +68,24 @@ class ScaledStdConv2d(nn.Conv2d):
     https://arxiv.org/abs/2101.08692
     """

-    def __init__(self, in_channels, out_channels, kernel_size,
-                 stride=1, padding=None, dilation=1, groups=1, bias=True, gain=True, gamma=1.0, eps=1e-5):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
+                 bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
         if padding is None:
             padding = get_padding(kernel_size, stride, dilation)
         super().__init__(
             in_channels, out_channels, kernel_size, stride=stride,
             padding=padding, dilation=dilation, groups=groups, bias=bias)
         self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
-        self.gamma = gamma * self.weight[0].numel() ** 0.5  # gamma * sqrt(fan-in)
-        self.eps = eps
+        self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
+        self.eps = eps ** 2 if use_layernorm else eps
+        self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory use

     def get_weight(self):
-        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-        weight = (self.weight - mean) / (self.gamma * std + self.eps)
+        if self.use_layernorm:
+            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+        else:
+            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+            weight = self.scale * (self.weight - mean) / (std + self.eps)
         if self.gain is not None:
             weight = weight * self.gain
         return weight
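For reference, the two get_weight() paths above agree up to where epsilon enters: the explicit path adds it to the std, while F.layer_norm adds it to the variance, which is why __init__ squares eps for the layernorm path. A minimal standalone check (a sketch using only torch, not part of the commit):

import torch
import torch.nn.functional as F

w = torch.randn(32, 16, 3, 3)   # (out_ch, in_ch, kh, kw) conv weight
scale = w[0].numel() ** -0.5    # 1 / sqrt(fan-in), with gamma == 1.0

# explicit path: standardize each filter over its (in_ch, kh, kw) dims
std, mean = torch.std_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w_explicit = scale * (w - mean) / (std + 1e-5)

# layer_norm path: divides by sqrt(var + eps), so pass eps squared
w_ln = scale * F.layer_norm(w, w.shape[1:], eps=1e-5 ** 2)

print(torch.allclose(w_explicit, w_ln, atol=1e-5))  # True (close, not bit-identical)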
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
 from .registry import register_model
-from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible
+from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible, get_act_fn


 def _dcfg(url='', **kwargs):
@@ -40,17 +40,17 @@ default_cfgs = {
     'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320)),
     'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384)),

-    'nf_resnet26d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_resnet50d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_resnet101d': _dcfg(url='', first_conv='stem.conv1'),
+    'nf_resnet26': _dcfg(url='', first_conv='stem.conv'),
+    'nf_resnet50': _dcfg(url='', first_conv='stem.conv'),
+    'nf_resnet101': _dcfg(url='', first_conv='stem.conv'),

-    'nf_seresnet26d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_seresnet50d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_seresnet101d': _dcfg(url='', first_conv='stem.conv1'),
+    'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'),
+    'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'),
+    'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'),

-    'nf_ecaresnet26d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_ecaresnet50d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_ecaresnet101d': _dcfg(url='', first_conv='stem.conv1'),
+    'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'),
+    'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'),
+    'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'),
 }

@@ -59,6 +59,7 @@ class NfCfg:
     depths: Tuple[int, int, int, int]
     channels: Tuple[int, int, int, int]
     alpha: float = 0.2
+    gamma_in_act: bool = False
     stem_type: str = '3x3'
     stem_chs: Optional[int] = None
     group_size: Optional[int] = 8
@@ -84,68 +85,65 @@ model_cfgs = dict(
     nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048),

     # ResNet (preact, D style deep stem/avg down) defs
-    nf_resnet26d=NfCfg(
+    nf_resnet26=NfCfg(
         depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer=None,),
-    nf_resnet50d=NfCfg(
+    nf_resnet50=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer=None),
-    nf_resnet101d=NfCfg(
+    nf_resnet101=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer=None),

-    nf_seresnet26d=NfCfg(
+    nf_seresnet26=NfCfg(
         depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet50d=NfCfg(
+    nf_seresnet50=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet101d=NfCfg(
+    nf_seresnet101=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),

-    nf_ecaresnet26d=NfCfg(
+    nf_ecaresnet26=NfCfg(
         depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
-    nf_ecaresnet50d=NfCfg(
+    nf_ecaresnet50=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
-    nf_ecaresnet101d=NfCfg(
+    nf_ecaresnet101=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='eca', attn_kwargs=dict()),

 )

-# class NormFreeSiLU(nn.Module):
-#     _K = 1. / 0.5595
-#
-#     def __init__(self, inplace=False):
-#         super().__init__()
-#         self.inplace = inplace
-#
-#     def forward(self, x):
-#         return F.silu(x, inplace=self.inplace) * self._K
-#
-#
-# class NormFreeReLU(nn.Module):
-#     _K = (0.5 * (1. - 1. / math.pi)) ** -0.5
-#
-#     def __init__(self, inplace=False):
-#         super().__init__()
-#         self.inplace = inplace
-#
-#     def forward(self, x):
-#         return F.relu(x, inplace=self.inplace) * self._K
+class GammaAct(nn.Module):
+    def __init__(self, act_type='relu', gamma: float = 1.0, inplace=False):
+        super().__init__()
+        self.act_fn = get_act_fn(act_type)
+        self.gamma = gamma
+        self.inplace = inplace
+
+    def forward(self, x):
+        return self.gamma * self.act_fn(x, inplace=self.inplace)
+
+
+def act_with_gamma(act_type, gamma: float = 1.):
+    def _create(inplace=False):
+        return GammaAct(act_type, gamma=gamma, inplace=inplace)
+    return _create


 class DownsampleAvg(nn.Module):
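act_with_gamma() is a factory returning a constructor, so call sites can keep instantiating the activation exactly as they would a plain class like nn.ReLU (e.g. act_layer(inplace=True)). A short usage sketch, with an illustrative gamma value taken from _nonlin_gamma further down:

import math
import torch

gamma_relu = (0.5 * (1. - 1. / math.pi)) ** -0.5   # relu entry of _nonlin_gamma

act_layer = act_with_gamma('relu', gamma=gamma_relu)  # factory, drop-in for a class
act = act_layer(inplace=False)                        # builds a GammaAct instance

x = torch.randn(8, 64, 14, 14)
y = act(x)          # equals gamma_relu * relu(x)
print(y.std())      # roughly 1.0 for unit-variance input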
@@ -178,10 +176,9 @@ class NormalizationFreeBlock(nn.Module):
         out_chs = out_chs or in_chs
         # EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
         mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div)
-        groups = 1
-        if group_size is not None:
-            # NOTE: not correcting the mid_chs % group_size, fix model def if broken. I want % ch_div == 0 to stand.
-            groups = mid_chs // group_size
+        groups = 1 if group_size is None else mid_chs // group_size
+        if group_size and group_size % ch_div == 0:
+            mid_chs = group_size * groups  # correct mid_chs if group_size divisible by ch_div, otherwise error
         self.alpha = alpha
         self.beta = beta
         self.attn_gain = attn_gain
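Concretely, with the default ch_div=8: mid_chs=512 with group_size=8 yields groups=64 and mid_chs stays 512, while a scaled width such as mid_chs=500 floors to groups=62 and is corrected down to 496. A tiny sketch of the same logic (hypothetical helper, for illustration only):

def _grouping(mid_chs, group_size, ch_div=8):
    # mirrors the block logic in the hunk above
    groups = 1 if group_size is None else mid_chs // group_size
    if group_size and group_size % ch_div == 0:
        mid_chs = group_size * groups  # snap mid_chs to a clean multiple
    return mid_chs, groups

print(_grouping(512, 8))     # (512, 64)  already aligned, unchanged
print(_grouping(500, 8))     # (496, 62)  corrected down
print(_grouping(512, None))  # (512, 1)   plain convolution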
@@ -229,10 +226,11 @@ class NormalizationFreeBlock(nn.Module):


 def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
+    stem_stride = 2
     stem = OrderedDict()
-    assert stem_type in ('', 'deep', '3x3', '7x7')
+    assert stem_type in ('', 'deep', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
     if 'deep' in stem_type:
-        # 3 deep 3x3 conv stack as in ResNet V1D models
+        # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
         mid_chs = out_chs // 2
         stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
         stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
@@ -244,12 +242,16 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
         # 7x7 stem conv as in ResNet
         stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)

-    return nn.Sequential(stem)
+    if 'pool' in stem_type:
+        stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
+        stem_stride = 4
+
+    return nn.Sequential(stem), stem_stride


 _nonlin_gamma = dict(
-    silu=.5595,
-    relu=(0.5 * (1. - 1. / math.pi)) ** 0.5,
+    silu=1./.5595,
+    relu=(0.5 * (1. - 1. / math.pi)) ** -0.5,
     identity=1.0
 )

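The exponent sign flips above matter: the entries are now the gains that restore unit variance after each non-linearity, rather than their reciprocals. For a unit Gaussian input, Var[relu(x)] = 0.5 * (1 - 1/pi), and the std of silu(x) is roughly 0.5595, so the inverses are the correcting gains. A quick Monte Carlo sanity check (a sketch using only torch, not part of the commit):

import math
import torch

torch.manual_seed(0)
x = torch.randn(1_000_000)  # unit-variance pre-activations

g_relu = (0.5 * (1. - 1. / math.pi)) ** -0.5
print((g_relu * torch.relu(x)).std())                 # ~1.0

g_silu = 1. / .5595
print((g_silu * torch.nn.functional.silu(x)).std())   # ~1.0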
@@ -264,9 +266,12 @@ class NormalizerFreeNet(nn.Module):
     the (preact) ResNet models described earlier in the paper.

     There are a few differences:
-        * channels are rounded to be divisible by 8 by default (keep TC happy), this changes param counts
+        * channels are rounded to be divisible by 8 by default (keep tensor core kernels happy),
+            this changes channel dim and param counts slightly from the paper models
         * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
             impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
+        * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
+            apply it in each activation. This is slightly slower, and yields slightly different results.
         * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
             for what it is/does. Approx 8-10% throughput loss.
     """
@@ -275,29 +280,33 @@ class NormalizerFreeNet(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        act_layer = get_act_layer(cfg.act_layer)
         assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
-        conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
+        if cfg.gamma_in_act:
+            act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
+            conv_layer = partial(ScaledStdConv2d, bias=True, gain=True)
+        else:
+            act_layer = get_act_layer(cfg.act_layer)
+            conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
         attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None

-        self.feature_info = []  # FIXME fill out feature info
-
         stem_chs = cfg.stem_chs or cfg.channels[0]
         stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
-        self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
+        self.stem, stem_stride = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)

-        prev_chs = stem_chs
+        self.feature_info = []  # NOTE: there will be no stride == 2 feature if stem_stride == 4
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
-        net_stride = 2
+        prev_chs = stem_chs
+        net_stride = stem_stride
         dilation = 1
         expected_var = 1.0
         stages = []
         for stage_idx, stage_depth in enumerate(cfg.depths):
-            if net_stride >= output_stride:
-                dilation *= 2
+            stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
+            self.feature_info += [dict(
+                num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1' if stride == 2 else '')]
+            if net_stride >= output_stride and stride > 1:
+                dilation *= stride
                 stride = 1
-            else:
-                stride = 2
             net_stride *= stride
             first_dilation = 1 if dilation in (1, 2) else 2
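The rewritten loop folds the maxpool stem into the stride plan: stage 0 keeps stride 1 when the stem already reduced the input by 4, and once net_stride reaches output_stride further downsampling is traded for dilation. Tracing the values for stem_stride=4 and output_stride=8 (a sketch, not part of the commit):

stem_stride, output_stride = 4, 8
net_stride, dilation = stem_stride, 1
for stage_idx in range(4):
    stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
    if net_stride >= output_stride and stride > 1:
        dilation *= stride  # swap downsampling for dilation
        stride = 1
    net_stride *= stride
    print(stage_idx, stride, net_stride, dilation)
# prints: (0, 1, 4, 1), (1, 2, 8, 1), (2, 1, 8, 2), (3, 1, 8, 4)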
@@ -338,7 +347,10 @@ class NormalizerFreeNet(nn.Module):
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
+        # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv
         self.final_act = act_layer()
+        self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]

         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

         for n, m in self.named_modules():
@@ -373,11 +385,14 @@ class NormalizerFreeNet(nn.Module):


 def _create_normfreenet(variant, pretrained=False, **kwargs):
+    model_cfg = model_cfgs[variant]
     feature_cfg = dict(flatten_sequential=True)
     feature_cfg['feature_cls'] = 'hook'  # pre-act models need hooks to grab feat from act1 in bottleneck blocks
+    if 'pool' in model_cfg.stem_type:
+        feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet

     return build_model_with_cfg(
-        NormalizerFreeNet, variant, pretrained, model_cfg=model_cfgs[variant], default_cfg=default_cfgs[variant],
+        NormalizerFreeNet, variant, pretrained, model_cfg=model_cfg, default_cfg=default_cfgs[variant],
         feature_cfg=feature_cfg, **kwargs)
@@ -412,30 +427,30 @@ def nf_regnet_b5(pretrained=False, **kwargs):


 @register_model
-def nf_resnet26d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_resnet26d', pretrained=pretrained, **kwargs)
+def nf_resnet26(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs)


 @register_model
-def nf_resnet50d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_resnet50d', pretrained=pretrained, **kwargs)
+def nf_resnet50(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs)


 @register_model
-def nf_seresnet26d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_seresnet26d', pretrained=pretrained, **kwargs)
+def nf_seresnet26(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs)


 @register_model
-def nf_seresnet50d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_seresnet50d', pretrained=pretrained, **kwargs)
+def nf_seresnet50(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs)


 @register_model
-def nf_ecaresnet26d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_ecaresnet26d', pretrained=pretrained, **kwargs)
+def nf_ecaresnet26(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs)


 @register_model
-def nf_ecaresnet50d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_ecaresnet50d', pretrained=pretrained, **kwargs)
+def nf_ecaresnet50(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)
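With the renames in place, the models are reachable through the registry as usual. A usage sketch, assuming a timm checkout that includes this commit (no pretrained weights are published for these variants at this point, hence pretrained=False):

import torch
import timm

model = timm.create_model('nf_resnet50', pretrained=False)
print(model(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 1000])

# feature extraction goes through hooks (feature_cls='hook' above); pool
# stems expose out_indices (1, 2, 3, 4), i.e. no stride-2 feature level
features = timm.create_model('nf_resnet50', features_only=True)
for f in features(torch.randn(2, 3, 224, 224)):
    print(f.shape)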