Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit 5a8e1e643e: Initial Normalizer-Free Reg/ResNet impl. A bit of related layer refactoring.
Parent: 9a38416fbd

timm/models/__init__.py
@@ -11,6 +11,7 @@ from .inception_v3 import *
 from .inception_v4 import *
 from .mobilenetv3 import *
 from .nasnet import *
+from .nfnet import *
 from .pnasnet import *
 from .regnet import *
 from .res2net import *

timm/models/layers/__init__.py
@@ -10,13 +10,13 @@ from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set
 from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
-from .create_attn import create_attn
+from .create_attn import get_attn, create_attn
 from .create_conv2d import create_conv2d
 from .create_norm_act import create_norm_act, get_norm_act_layer
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule
 from .evo_norm import EvoNormBatch2d, EvoNormSample2d
-from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
+from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
 from .inplace_abn import InplaceAbn
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d

@@ -29,5 +29,6 @@ from .separable_conv import SeparableConv2d, SeparableConvBnAct
 from .space_to_depth import SpaceToDepthModule
 from .split_attn import SplitAttnConv2d
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
+from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .weight_init import trunc_normal_

timm/models/layers/create_attn.py
@@ -8,7 +8,7 @@ from .eca import EcaModule, CecaModule
 from .cbam import CbamModule, LightCbamModule


-def create_attn(attn_type, channels, **kwargs):
+def get_attn(attn_type):
     module_cls = None
     if attn_type is not None:
         if isinstance(attn_type, str):

@@ -32,6 +32,12 @@ def create_attn(attn_type, channels, **kwargs):
                 module_cls = SEModule
         else:
             module_cls = attn_type
+    return module_cls
+
+
+def create_attn(attn_type, channels, **kwargs):
+    module_cls = get_attn(attn_type)
     if module_cls is not None:
+        # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
         return module_cls(channels, **kwargs)
     return None
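
Note (illustrative, not part of this commit): the refactor above splits attention-module lookup from instantiation, so get_attn() returns the layer class/callable (handy for binding kwargs via functools.partial, as nfnet.py below does) while create_attn() keeps the old one-step behavior. A minimal usage sketch, assuming the exports added to timm.models.layers in this commit:

    from functools import partial
    from timm.models.layers import get_attn, create_attn

    # Resolve the attention class without instantiating it; bind kwargs for later use.
    se_cls = get_attn('se')                              # -> SEModule class
    se_layer = partial(se_cls, reduction_ratio=0.5, divisor=8)
    attn = se_layer(64)                                  # SE block for 64 channels

    # One-step creation is unchanged.
    attn2 = create_attn('eca', 128)                      # EcaModule for 128 channels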

timm/models/layers/helpers.py
@@ -22,6 +22,10 @@ to_4tuple = _ntuple(4)
 to_ntuple = _ntuple


+def make_divisible(v, divisor=8, min_value=None):
+    min_value = min_value or divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
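
Worked example (illustrative, not part of this commit): make_divisible rounds a channel count to the nearest multiple of divisor, then bumps the result up one step if rounding down lost more than 10% of the requested value:

    from timm.models.layers import make_divisible

    assert make_divisible(42) == 40                  # nearest multiple of 8
    assert make_divisible(30) == 32                  # rounds up past the midpoint
    assert make_divisible(10) == 16                  # 8 would be >10% below 10, so bump by one divisor
    assert make_divisible(100, divisor=16) == 96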

timm/models/layers/se.py
@@ -1,13 +1,27 @@
 from torch import nn as nn
+import torch.nn.functional as F

 from .create_act import create_act_layer
+from .helpers import make_divisible


 class SEModule(nn.Module):
-    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
-                 gate_layer='sigmoid'):
+    """ SE Module as defined in original SE-Nets with a few additions
+    Additions include:
+        * min_channels can be specified to keep reduced channel count at a minimum (default: 8)
+        * divisor can be specified to keep channels rounded to specified values (default: 1)
+        * reduction channels can be specified directly by arg (if reduction_channels is set)
+        * reduction channels can be specified by float ratio (if reduction_ratio is set)
+    """
+    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid',
+                 reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1):
         super(SEModule, self).__init__()
-        reduction_channels = reduction_channels or max(channels // reduction, min_channels)
+        if reduction_channels is not None:
+            reduction_channels = reduction_channels  # direct specification highest priority, no rounding/min done
+        elif reduction_ratio is not None:
+            reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels)
+        else:
+            reduction_channels = make_divisible(channels // reduction, divisor, min_channels)
         self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.act = act_layer(inplace=True)
         self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
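
Illustrative usage (not part of this commit): the reworked SEModule now supports three ways of sizing the squeeze (reduction) layer; a quick sketch assuming the defaults shown above:

    import torch
    from timm.models.layers.se import SEModule

    x = torch.randn(2, 64, 8, 8)
    se_a = SEModule(64)                                   # classic: 64 // 16 = 4, clamped up to min_channels=8
    se_b = SEModule(64, reduction_ratio=0.5, divisor=8)   # ratio-based: make_divisible(32.0, 8) -> 32
    se_c = SEModule(64, reduction_channels=10)            # direct: exactly 10, no rounding or minimum applied

    for se in (se_a, se_b, se_c):
        assert se(x).shape == x.shape                     # SE blocks preserve the input shape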

timm/models/layers/std_conv.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .padding import get_padding
from .conv2d_same import conv2d_same


def get_weight(module):
    std, mean = torch.std_mean(module.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
    weight = (module.weight - mean) / (std + module.eps)
    return weight


class StdConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2
    """
    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1,
            padding=None, dilation=1, groups=1, bias=False, eps=1e-5):
        if padding is None:
            padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.eps = eps

    def get_weight(self):
        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
        weight = (self.weight - mean) / (std + self.eps)
        return weight

    def forward(self, x):
        x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x


class StdConv2dSame(nn.Conv2d):
    """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2
    """
    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5):
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride,
            padding=0, dilation=dilation, groups=groups, bias=bias)
        self.eps = eps

    def get_weight(self):
        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
        weight = (self.weight - mean) / (std + self.eps)
        return weight

    def forward(self, x):
        x = conv2d_same(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x


class ScaledStdConv2d(nn.Conv2d):
    """Conv2d layer with Scaled Weight Standardization.

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
        https://arxiv.org/abs/2101.08692
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=None, dilation=1, groups=1, bias=True, gain=True, gamma=1.0, eps=1e-5):
        if padding is None:
            padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
        self.gamma = gamma * self.weight[0].numel() ** 0.5  # gamma * sqrt(fan-in)
        self.eps = eps

    def get_weight(self):
        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
        weight = (self.weight - mean) / (self.gamma * std + self.eps)
        if self.gain is not None:
            weight = weight * self.gain
        return weight

    def forward(self, x):
        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
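
Sanity-check sketch (illustrative, not part of this commit): after weight standardization each output filter has roughly zero mean and unit standard deviation, and ScaledStdConv2d drops in wherever an nn.Conv2d is expected:

    import torch
    from timm.models.layers import StdConv2d, ScaledStdConv2d

    conv = StdConv2d(16, 32, kernel_size=3)
    w = conv.get_weight()
    print(w.mean(dim=[1, 2, 3]))   # ~0 per output channel
    print(w.std(dim=[1, 2, 3]))    # ~1 per output channel

    sconv = ScaledStdConv2d(16, 32, kernel_size=3, gamma=1.0)
    y = sconv(torch.randn(2, 16, 8, 8))
    print(y.shape)                 # torch.Size([2, 32, 8, 8])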

timm/models/nfnet.py (new file, 441 lines)
@@ -0,0 +1,441 @@
""" Normalizer Free RegNet / ResNet (pre-activation) Models

Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
    - https://arxiv.org/abs/2101.08692

Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import Tuple, Optional
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible


def _dcfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        **kwargs
    }


# FIXME finish
default_cfgs = {
    'nf_regnet_b0': _dcfg(url=''),
    'nf_regnet_b1': _dcfg(url='', input_size=(3, 240, 240)),
    'nf_regnet_b2': _dcfg(url='', input_size=(3, 256, 256)),
    'nf_regnet_b3': _dcfg(url='', input_size=(3, 272, 272)),
    'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320)),
    'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384)),

    'nf_resnet26d': _dcfg(url='', first_conv='stem.conv1'),
    'nf_resnet50d': _dcfg(url='', first_conv='stem.conv1'),
    'nf_resnet101d': _dcfg(url='', first_conv='stem.conv1'),

    'nf_seresnet26d': _dcfg(url='', first_conv='stem.conv1'),
    'nf_seresnet50d': _dcfg(url='', first_conv='stem.conv1'),
    'nf_seresnet101d': _dcfg(url='', first_conv='stem.conv1'),

    'nf_ecaresnet26d': _dcfg(url='', first_conv='stem.conv1'),
    'nf_ecaresnet50d': _dcfg(url='', first_conv='stem.conv1'),
    'nf_ecaresnet101d': _dcfg(url='', first_conv='stem.conv1'),
}

@dataclass
class NfCfg:
    depths: Tuple[int, int, int, int]
    channels: Tuple[int, int, int, int]
    alpha: float = 0.2
    stem_type: str = '3x3'
    stem_chs: Optional[int] = None
    group_size: Optional[int] = 8
    attn_layer: Optional[str] = 'se'
    attn_kwargs: dict = field(default_factory=lambda: dict(reduction_ratio=0.5, divisor=8))
    attn_gain: float = 2.0  # NF correction gain to apply if attn layer is used
    width_factor: float = 0.75
    bottle_ratio: float = 2.25
    efficient: bool = True  # enables EfficientNet-like options that are used in paper 'nf_regnet_b*' models
    num_features: int = 1280  # num out_channels for final conv (when enabled in efficient mode)
    ch_div: int = 8  # round channels % 8 == 0 to keep tensor-core use optimal
    skipinit: bool = False
    act_layer: str = 'silu'


model_cfgs = dict(
    # EffNet influenced RegNet defs
    nf_regnet_b0=NfCfg(depths=(1, 3, 6, 6), channels=(48, 104, 208, 440), num_features=1280),
    nf_regnet_b1=NfCfg(depths=(2, 4, 7, 7), channels=(48, 104, 208, 440), num_features=1280),
    nf_regnet_b2=NfCfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488), num_features=1416),
    nf_regnet_b3=NfCfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528), num_features=1536),
    nf_regnet_b4=NfCfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616), num_features=1792),
    nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048),

    # ResNet (preact, D style deep stem/avg down) defs
    nf_resnet26d=NfCfg(
        depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
        act_layer='relu', attn_layer=None,),
    nf_resnet50d=NfCfg(
        depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
        act_layer='relu', attn_layer=None),
    nf_resnet101d=NfCfg(
        depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
        act_layer='relu', attn_layer=None),

    nf_seresnet26d=NfCfg(
        depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
        act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
    nf_seresnet50d=NfCfg(
        depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
        act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
    nf_seresnet101d=NfCfg(
        depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
        act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),

    nf_ecaresnet26d=NfCfg(
        depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
        act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
    nf_ecaresnet50d=NfCfg(
        depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
        act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
    nf_ecaresnet101d=NfCfg(
        depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
        act_layer='relu', attn_layer='eca', attn_kwargs=dict()),

)

# class NormFreeSiLU(nn.Module):
#     _K = 1. / 0.5595
#     def __init__(self, inplace=False):
#         super().__init__()
#         self.inplace = inplace
#
#     def forward(self, x):
#         return F.silu(x, inplace=self.inplace) * self._K
#
#
# class NormFreeReLU(nn.Module):
#     _K = (0.5 * (1. - 1. / math.pi)) ** -0.5
#
#     def __init__(self, inplace=False):
#         super().__init__()
#         self.inplace = inplace
#
#     def forward(self, x):
#         return F.relu(x, inplace=self.inplace) * self._K


class DownsampleAvg(nn.Module):
    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, conv_layer=ScaledStdConv2d):
        """ AvgPool Downsampling as in 'D' ResNet variants. Support for dilation."""
        super(DownsampleAvg, self).__init__()
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
        else:
            self.pool = nn.Identity()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=1)

    def forward(self, x):
        return self.conv(self.pool(x))


class NormalizationFreeBlock(nn.Module):
    """Normalization-free pre-activation block.
    """

    def __init__(
            self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
            alpha=1.0, beta=1.0, bottle_ratio=0.25, efficient=True, ch_div=1, group_size=None,
            attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0., skipinit=False):
        super().__init__()
        first_dilation = first_dilation or dilation
        out_chs = out_chs or in_chs
        # EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
        mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div)
        groups = 1
        if group_size is not None:
            # NOTE: not correcting the mid_chs % group_size, fix model def if broken. I want % ch_div == 0 to stand.
            groups = mid_chs // group_size
        self.alpha = alpha
        self.beta = beta
        self.attn_gain = attn_gain

        if in_chs != out_chs or stride != 1 or dilation != first_dilation:
            self.downsample = DownsampleAvg(
                in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer)
        else:
            self.downsample = None

        self.act1 = act_layer()
        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.act2 = act_layer(inplace=True)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
        if attn_layer is not None:
            self.attn = attn_layer(mid_chs)
        else:
            self.attn = None
        self.act3 = act_layer()
        self.conv3 = conv_layer(mid_chs, out_chs, 1)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None

    def forward(self, x):
        out = self.act1(x) * self.beta

        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(out)

        # residual branch
        out = self.conv1(out)
        out = self.conv2(self.act2(out))
        if self.attn is not None:
            out = self.attn_gain * self.attn(out)
        out = self.conv3(self.act3(out))
        out = self.drop_path(out)
        if self.skipinit_gain is None:
            out = out * self.alpha + shortcut
        else:
            # this really slows things down for some reason, TBD
            out = out * self.alpha * self.skipinit_gain + shortcut
        return out


def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
    stem = OrderedDict()
    assert stem_type in ('', 'deep', '3x3', '7x7')
    if 'deep' in stem_type:
        # 3 deep 3x3 conv stack as in ResNet V1D models
        mid_chs = out_chs // 2
        stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
        stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
        stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
    elif '3x3' in stem_type:
        # 3x3 stem conv as in RegNet
        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2)
    else:
        # 7x7 stem conv as in ResNet
        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)

    return nn.Sequential(stem)


_nonlin_gamma = dict(
    silu=.5595,
    relu=(0.5 * (1. - 1. / math.pi)) ** 0.5,
    identity=1.0
)
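
Side note (illustration, not part of nfnet.py or this commit): the _nonlin_gamma values are the approximate standard deviation of each activation's output for a unit Gaussian input; NormalizerFreeNet passes them into ScaledStdConv2d so post-activation signals stay near unit variance. A quick empirical check:

    import math
    import torch
    import torch.nn.functional as F

    x = torch.randn(1_000_000)
    print(F.silu(x).std())                     # ~0.5595, the 'silu' entry above
    print(F.relu(x).std())                     # ~0.5838
    print((0.5 * (1. - 1. / math.pi)) ** 0.5)  # 0.5838..., the 'relu' entry above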

class NormalizerFreeNet(nn.Module):
    """ Normalizer-free ResNets and RegNets

    As described in `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
        - https://arxiv.org/abs/2101.08692

    This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and
    the (preact) ResNet models described earlier in the paper.

    There are a few differences:
        * channels are rounded to be divisible by 8 by default (keep TC happy), this changes param counts
        * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
          impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
        * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
          for what it is/does. Approx 8-10% throughput loss.
    """
    def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
                 drop_rate=0., drop_path_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        act_layer = get_act_layer(cfg.act_layer)
        assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
        conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
        attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None

        self.feature_info = []  # FIXME fill out feature info

        stem_chs = cfg.stem_chs or cfg.channels[0]
        stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
        self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)

        prev_chs = stem_chs
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
        net_stride = 2
        dilation = 1
        expected_var = 1.0
        stages = []
        for stage_idx, stage_depth in enumerate(cfg.depths):
            if net_stride >= output_stride:
                dilation *= 2
                stride = 1
            else:
                stride = 2
            net_stride *= stride
            first_dilation = 1 if dilation in (1, 2) else 2

            blocks = []
            for block_idx in range(cfg.depths[stage_idx]):
                first_block = block_idx == 0 and stage_idx == 0
                out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div)
                blocks += [NormalizationFreeBlock(
                    in_chs=prev_chs, out_chs=out_chs,
                    alpha=cfg.alpha,
                    beta=1. / expected_var ** 0.5,  # NOTE: beta used as multiplier in block
                    stride=stride if block_idx == 0 else 1,
                    dilation=dilation,
                    first_dilation=first_dilation,
                    group_size=cfg.group_size,
                    bottle_ratio=1. if cfg.efficient and first_block else cfg.bottle_ratio,
                    efficient=cfg.efficient,
                    ch_div=cfg.ch_div,
                    attn_layer=attn_layer,
                    attn_gain=cfg.attn_gain,
                    act_layer=act_layer,
                    conv_layer=conv_layer,
                    drop_path_rate=dpr[stage_idx][block_idx],
                    skipinit=cfg.skipinit,
                )]
                if block_idx == 0:
                    expected_var = 1.  # expected var is reset after first block of each stage
                expected_var += cfg.alpha ** 2  # Even if reset occurs, increment expected variance
                first_dilation = dilation
                prev_chs = out_chs
            stages += [nn.Sequential(*blocks)]
        self.stages = nn.Sequential(*stages)

        if cfg.efficient and cfg.num_features:
            # The paper NFRegNet models have an EfficientNet-like final head convolution.
            self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
            self.final_conv = conv_layer(prev_chs, self.num_features, 1)
        else:
            self.num_features = prev_chs
            self.final_conv = nn.Identity()
        self.final_act = act_layer()
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

        for n, m in self.named_modules():
            if 'fc' in n and isinstance(m, nn.Linear):
                nn.init.zeros_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                # as per discussion with paper authors, original in haiku is
                # hk.initializers.VarianceScaling(1.0, 'fan_in', 'normal')' w/ zero'd bias
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        x = self.final_conv(x)
        x = self.final_act(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
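
Illustration (not part of this commit) of the variance bookkeeping in the stage loop above: each block scales its pre-activation input by beta = 1 / sqrt(expected_var), the residual add grows expected_var by alpha ** 2, and the counter resets to 1.0 right after the first (transition) block of a stage. Starting from expected_var = 1.0 as in the first stage, with alpha = 0.2 and a 3-block stage:

    alpha = 0.2
    expected_var = 1.0
    betas = []
    for block_idx in range(3):
        betas.append(1. / expected_var ** 0.5)
        if block_idx == 0:
            expected_var = 1.          # reset after the stage's first block
        expected_var += alpha ** 2
    print(betas)  # [1.0, 0.9805..., 0.9622...]  i.e. 1/sqrt(1.0), 1/sqrt(1.04), 1/sqrt(1.08)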

def _create_normfreenet(variant, pretrained=False, **kwargs):
    feature_cfg = dict(flatten_sequential=True)
    feature_cfg['feature_cls'] = 'hook'  # pre-act models need hooks to grab feat from act1 in bottleneck blocks

    return build_model_with_cfg(
        NormalizerFreeNet, variant, pretrained, model_cfg=model_cfgs[variant], default_cfg=default_cfgs[variant],
        feature_cfg=feature_cfg, **kwargs)


@register_model
def nf_regnet_b0(pretrained=False, **kwargs):
    return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs)


@register_model
def nf_regnet_b1(pretrained=False, **kwargs):
    return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs)


@register_model
def nf_regnet_b2(pretrained=False, **kwargs):
    return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs)


@register_model
def nf_regnet_b3(pretrained=False, **kwargs):
    return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs)


@register_model
def nf_regnet_b4(pretrained=False, **kwargs):
    return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs)


@register_model
def nf_regnet_b5(pretrained=False, **kwargs):
    return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs)


@register_model
def nf_resnet26d(pretrained=False, **kwargs):
    return _create_normfreenet('nf_resnet26d', pretrained=pretrained, **kwargs)


@register_model
def nf_resnet50d(pretrained=False, **kwargs):
    return _create_normfreenet('nf_resnet50d', pretrained=pretrained, **kwargs)


@register_model
def nf_seresnet26d(pretrained=False, **kwargs):
    return _create_normfreenet('nf_seresnet26d', pretrained=pretrained, **kwargs)


@register_model
def nf_seresnet50d(pretrained=False, **kwargs):
    return _create_normfreenet('nf_seresnet50d', pretrained=pretrained, **kwargs)


@register_model
def nf_ecaresnet26d(pretrained=False, **kwargs):
    return _create_normfreenet('nf_ecaresnet26d', pretrained=pretrained, **kwargs)


@register_model
def nf_ecaresnet50d(pretrained=False, **kwargs):
    return _create_normfreenet('nf_ecaresnet50d', pretrained=pretrained, **kwargs)
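
Usage sketch (not part of this commit): the registered variants build through the usual timm factory; no pretrained weights exist yet (the default_cfgs urls are empty), so pretrained stays False:

    import torch
    import timm

    model = timm.create_model('nf_regnet_b0', pretrained=False)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])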

timm/models/resnetv2.py
@@ -32,13 +32,12 @@ from collections import OrderedDict # pylint: disable=g-importing-member

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from functools import partial

 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg
 from .registry import register_model
-from .layers import get_padding, GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, conv2d_same
+from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d


 def _cfg(url='', **kwargs):

@@ -112,43 +111,6 @@ def make_div(v, divisor=8):
     return new_v


-class StdConv2d(nn.Conv2d):
-
-    def __init__(
-            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
-        padding = get_padding(kernel_size, stride, dilation)
-        super().__init__(
-            in_channel, out_channels, kernel_size, stride=stride,
-            padding=padding, dilation=dilation, bias=bias, groups=groups)
-        self.eps = eps
-
-    def forward(self, x):
-        w = self.weight
-        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
-        w = (w - m) / (torch.sqrt(v) + self.eps)
-        x = F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
-        return x
-
-
-class StdConv2dSame(nn.Conv2d):
-    """StdConv2d w/ TF compatible SAME padding. Used for ViT Hybrid model.
-    """
-    def __init__(
-            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
-        padding = get_padding(kernel_size, stride, dilation)
-        super().__init__(
-            in_channel, out_channels, kernel_size, stride=stride,
-            padding=padding, dilation=dilation, bias=bias, groups=groups)
-        self.eps = eps
-
-    def forward(self, x):
-        w = self.weight
-        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
-        w = (w - m) / (torch.sqrt(v) + self.eps)
-        x = conv2d_same(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
-        return x
-
-
 def tf2th(conv_weights):
     """Possibly convert HWIO to OIHW."""
     if conv_weights.ndim == 4:

timm/models/rexnet.py
@@ -15,7 +15,7 @@ from math import ceil

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath
+from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible
 from .registry import register_model
 from .efficientnet_builder import efficientnet_init_weights

@@ -49,12 +49,6 @@ default_cfgs = dict(
 )


-def make_divisible(v, divisor=8, min_value=None):
-    min_value = min_value or divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    return new_v
-
-
 class SEWithNorm(nn.Module):

     def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None,

timm/models/vision_transformer.py
@@ -28,9 +28,9 @@ import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
-from .layers import DropPath, to_2tuple, trunc_normal_
+from .layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_
 from .resnet import resnet26d, resnet50d
-from .resnetv2 import ResNetV2, StdConv2dSame
+from .resnetv2 import ResNetV2
 from .registry import register_model

 _logger = logging.getLogger(__name__)