mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add NFNet-F model weights ported from DeepMind Haiku impl and new set of models w/ compatible config.
This commit is contained in:
parent
4ea5931964
commit
678ba4e0a2
14
README.md
14
README.md
@ -2,6 +2,20 @@
|
||||
|
||||
## What's New
|
||||
|
||||
### Feb 18, 2021
|
||||
* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
|
||||
* Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
|
||||
* These models are big, expect to run out of GPU memory. With the GELU activiation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants.
|
||||
* Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated).
|
||||
* Matching the original pre-processing as closely as possible I get these results:
|
||||
* `dm_nfnet_f6` - 86.352
|
||||
* `dm_nfnet_f5` - 86.100
|
||||
* `dm_nfnet_f4` - 85.834
|
||||
* `dm_nfnet_f3` - 85.676
|
||||
* `dm_nfnet_f2` - 85.178
|
||||
* `dm_nfnet_f1` - 84.696
|
||||
* `dm_nfnet_f0` - 83.464
|
||||
|
||||
### Feb 16, 2021
|
||||
* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py.
|
||||
* AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
|
||||
|
@ -29,6 +29,6 @@ from .separable_conv import SeparableConv2d, SeparableConvBnAct
|
||||
from .space_to_depth import SpaceToDepthModule
|
||||
from .split_attn import SplitAttnConv2d
|
||||
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
||||
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d
|
||||
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
|
||||
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||
from .weight_init import trunc_normal_
|
||||
|
@ -2,8 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .padding import get_padding
|
||||
from .conv2d_same import conv2d_same
|
||||
from .padding import get_padding, get_padding_value, pad_same
|
||||
|
||||
|
||||
def get_weight(module):
|
||||
@ -19,8 +18,8 @@ class StdConv2d(nn.Conv2d):
|
||||
https://arxiv.org/abs/1903.10520v2
|
||||
"""
|
||||
def __init__(
|
||||
self, in_channel, out_channels, kernel_size, stride=1,
|
||||
padding=None, dilation=1, groups=1, bias=False, eps=1e-5):
|
||||
self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1,
|
||||
groups=1, bias=False, eps=1e-5):
|
||||
if padding is None:
|
||||
padding = get_padding(kernel_size, stride, dilation)
|
||||
super().__init__(
|
||||
@ -45,10 +44,13 @@ class StdConv2dSame(nn.Conv2d):
|
||||
https://arxiv.org/abs/1903.10520v2
|
||||
"""
|
||||
def __init__(
|
||||
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5):
|
||||
self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1,
|
||||
groups=1, bias=False, eps=1e-5):
|
||||
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
|
||||
super().__init__(
|
||||
in_channel, out_channels, kernel_size, stride=stride,
|
||||
padding=0, dilation=dilation, groups=groups, bias=bias)
|
||||
in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
|
||||
groups=groups, bias=bias)
|
||||
self.same_pad = is_dynamic
|
||||
self.eps = eps
|
||||
|
||||
def get_weight(self):
|
||||
@ -57,7 +59,9 @@ class StdConv2dSame(nn.Conv2d):
|
||||
return weight
|
||||
|
||||
def forward(self, x):
|
||||
x = conv2d_same(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
if self.same_pad:
|
||||
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
||||
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return x
|
||||
|
||||
|
||||
@ -68,17 +72,18 @@ class ScaledStdConv2d(nn.Conv2d):
|
||||
https://arxiv.org/abs/2101.08692
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
|
||||
bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
|
||||
bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
|
||||
if padding is None:
|
||||
padding = get_padding(kernel_size, stride, dilation)
|
||||
super().__init__(
|
||||
in_channels, out_channels, kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, groups=groups, bias=bias)
|
||||
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
|
||||
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
|
||||
groups=groups, bias=bias)
|
||||
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
|
||||
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
|
||||
self.eps = eps ** 2 if use_layernorm else eps
|
||||
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory use
|
||||
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
|
||||
|
||||
def get_weight(self):
|
||||
if self.use_layernorm:
|
||||
@ -86,9 +91,52 @@ class ScaledStdConv2d(nn.Conv2d):
|
||||
else:
|
||||
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
weight = self.scale * (self.weight - mean) / (std + self.eps)
|
||||
if self.gain is not None:
|
||||
weight = weight * self.gain
|
||||
return weight
|
||||
return self.gain * weight
|
||||
|
||||
def forward(self, x):
|
||||
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class ScaledStdConv2dSame(nn.Conv2d):
|
||||
"""Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support
|
||||
|
||||
NOTE: operations and default eps slightly changed from non-SAME impl to closer match Deepmind Haiku impl.
|
||||
Fore the sake of completeness, numeric differences are minor with arprox .005 top-1 difference.
|
||||
|
||||
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
|
||||
https://arxiv.org/abs/2101.08692
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
|
||||
bias=True, gamma=1.0, eps=1e-5, use_layernorm=False):
|
||||
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
|
||||
super().__init__(
|
||||
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
|
||||
groups=groups, bias=bias)
|
||||
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
|
||||
self.scale = gamma * self.weight[0].numel() ** -0.5
|
||||
self.same_pad = is_dynamic
|
||||
self.eps = eps ** 2 if use_layernorm else eps
|
||||
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
|
||||
|
||||
# NOTE an alternate formulation to consider, closer to DeepMind Haiku impl but doesn't seem
|
||||
# to make much numerical difference (+/- .002 to .004) in top-1 during eval.
|
||||
# def get_weight(self):
|
||||
# var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
# scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain
|
||||
# weight = (self.weight - mean) * scale
|
||||
# return self.gain * weight
|
||||
|
||||
def get_weight(self):
|
||||
if self.use_layernorm:
|
||||
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
|
||||
else:
|
||||
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
weight = self.scale * (self.weight - mean) / (std + self.eps)
|
||||
return self.gain * weight
|
||||
|
||||
def forward(self, x):
|
||||
if self.same_pad:
|
||||
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
||||
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
@ -24,12 +24,12 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .registry import register_model
|
||||
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible, get_act_fn
|
||||
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
|
||||
get_act_layer, get_act_fn, get_attn, make_divisible
|
||||
|
||||
|
||||
def _dcfg(url='', **kwargs):
|
||||
@ -38,75 +38,102 @@ def _dcfg(url='', **kwargs):
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.9, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'stem.conv', 'classifier': 'head.fc',
|
||||
'first_conv': 'stem.conv1', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = dict(
|
||||
dm_nfnet_f0=_dcfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f0-604f9c3a.pth',
|
||||
pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), crop_pct=.9),
|
||||
dm_nfnet_f1=_dcfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f1-fc540f82.pth',
|
||||
pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), crop_pct=0.91),
|
||||
dm_nfnet_f2=_dcfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f2-89875923.pth',
|
||||
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), crop_pct=0.92),
|
||||
dm_nfnet_f3=_dcfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f3-d74ab3aa.pth',
|
||||
pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), crop_pct=0.94),
|
||||
dm_nfnet_f4=_dcfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f4-0ac5b10b.pth',
|
||||
pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), crop_pct=0.951),
|
||||
dm_nfnet_f5=_dcfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f5-ecb20ab1.pth',
|
||||
pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), crop_pct=0.954),
|
||||
dm_nfnet_f6=_dcfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f6-e0f12116.pth',
|
||||
pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), crop_pct=0.956),
|
||||
|
||||
nfnet_f0=_dcfg(
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
||||
nfnet_f1=_dcfg(
|
||||
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
|
||||
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)),
|
||||
nfnet_f2=_dcfg(
|
||||
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
|
||||
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)),
|
||||
nfnet_f3=_dcfg(
|
||||
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
|
||||
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)),
|
||||
nfnet_f4=_dcfg(
|
||||
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
|
||||
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)),
|
||||
nfnet_f5=_dcfg(
|
||||
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
|
||||
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)),
|
||||
nfnet_f6=_dcfg(
|
||||
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
|
||||
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)),
|
||||
nfnet_f7=_dcfg(
|
||||
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
|
||||
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
|
||||
|
||||
nfnet_f0s=_dcfg(
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
||||
nfnet_f1s=_dcfg(
|
||||
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), first_conv='stem.conv1'),
|
||||
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)),
|
||||
nfnet_f2s=_dcfg(
|
||||
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), first_conv='stem.conv1'),
|
||||
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)),
|
||||
nfnet_f3s=_dcfg(
|
||||
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), first_conv='stem.conv1'),
|
||||
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)),
|
||||
nfnet_f4s=_dcfg(
|
||||
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), first_conv='stem.conv1'),
|
||||
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)),
|
||||
nfnet_f5s=_dcfg(
|
||||
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), first_conv='stem.conv1'),
|
||||
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)),
|
||||
nfnet_f6s=_dcfg(
|
||||
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), first_conv='stem.conv1'),
|
||||
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)),
|
||||
nfnet_f7s=_dcfg(
|
||||
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608), first_conv='stem.conv1'),
|
||||
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
|
||||
|
||||
nfnet_l0a=_dcfg(
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
||||
nfnet_l0b=_dcfg(
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
||||
nfnet_l0c=_dcfg(
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv1'),
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
||||
|
||||
nf_regnet_b0=_dcfg(url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
||||
nf_regnet_b0=_dcfg(
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
|
||||
nf_regnet_b1=_dcfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth',
|
||||
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288)), # NOT to paper spec
|
||||
nf_regnet_b2=_dcfg(url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272)),
|
||||
nf_regnet_b3=_dcfg(url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320)),
|
||||
nf_regnet_b4=_dcfg(url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384)),
|
||||
nf_regnet_b5=_dcfg(url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456)),
|
||||
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), first_conv='stem.conv'), # NOT to paper spec
|
||||
nf_regnet_b2=_dcfg(
|
||||
url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272), first_conv='stem.conv'),
|
||||
nf_regnet_b3=_dcfg(
|
||||
url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320), first_conv='stem.conv'),
|
||||
nf_regnet_b4=_dcfg(
|
||||
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), first_conv='stem.conv'),
|
||||
nf_regnet_b5=_dcfg(
|
||||
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456), first_conv='stem.conv'),
|
||||
|
||||
nf_resnet26=_dcfg(url=''),
|
||||
nf_resnet26=_dcfg(url='', first_conv='stem.conv'),
|
||||
nf_resnet50=_dcfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth',
|
||||
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94),
|
||||
nf_resnet101=_dcfg(url=''),
|
||||
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94, first_conv='stem.conv'),
|
||||
nf_resnet101=_dcfg(url='', first_conv='stem.conv'),
|
||||
|
||||
nf_seresnet26=_dcfg(url=''),
|
||||
nf_seresnet50=_dcfg(url=''),
|
||||
nf_seresnet101=_dcfg(url=''),
|
||||
nf_seresnet26=_dcfg(url='', first_conv='stem.conv'),
|
||||
nf_seresnet50=_dcfg(url='', first_conv='stem.conv'),
|
||||
nf_seresnet101=_dcfg(url='', first_conv='stem.conv'),
|
||||
|
||||
nf_ecaresnet26=_dcfg(url=''),
|
||||
nf_ecaresnet50=_dcfg(url=''),
|
||||
nf_ecaresnet101=_dcfg(url=''),
|
||||
nf_ecaresnet26=_dcfg(url='', first_conv='stem.conv'),
|
||||
nf_ecaresnet50=_dcfg(url='', first_conv='stem.conv'),
|
||||
nf_ecaresnet101=_dcfg(url='', first_conv='stem.conv'),
|
||||
)
|
||||
|
||||
|
||||
@ -115,7 +142,6 @@ class NfCfg:
|
||||
depths: Tuple[int, int, int, int]
|
||||
channels: Tuple[int, int, int, int]
|
||||
alpha: float = 0.2
|
||||
gamma_in_act: bool = False
|
||||
stem_type: str = '3x3'
|
||||
stem_chs: Optional[int] = None
|
||||
group_size: Optional[int] = None
|
||||
@ -128,6 +154,8 @@ class NfCfg:
|
||||
ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal
|
||||
reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle
|
||||
extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models
|
||||
gamma_in_act: bool = False
|
||||
same_padding: bool = False
|
||||
skipinit: bool = False # disabled by default, non-trivial performance impact
|
||||
zero_init_fc: bool = False
|
||||
act_layer: str = 'silu'
|
||||
@ -163,8 +191,26 @@ def _nfnet_cfg(
|
||||
return cfg
|
||||
|
||||
|
||||
def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
|
||||
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
|
||||
cfg = NfCfg(
|
||||
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
|
||||
bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
|
||||
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=attn_kwargs)
|
||||
return cfg
|
||||
|
||||
|
||||
model_cfgs = dict(
|
||||
# NFNet-F models w/ GeLU
|
||||
# NFNet-F models w/ GELU compatible with DeepMind weights
|
||||
dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
|
||||
dm_nfnet_f1=_dm_nfnet_cfg(depths=(2, 4, 12, 6)),
|
||||
dm_nfnet_f2=_dm_nfnet_cfg(depths=(3, 6, 18, 9)),
|
||||
dm_nfnet_f3=_dm_nfnet_cfg(depths=(4, 8, 24, 12)),
|
||||
dm_nfnet_f4=_dm_nfnet_cfg(depths=(5, 10, 30, 15)),
|
||||
dm_nfnet_f5=_dm_nfnet_cfg(depths=(6, 12, 36, 18)),
|
||||
dm_nfnet_f6=_dm_nfnet_cfg(depths=(7, 14, 42, 21)),
|
||||
|
||||
# NFNet-F models w/ GELU (I will likely deprecate/remove these models and just keep dm_ ver for GELU)
|
||||
nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)),
|
||||
nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)),
|
||||
nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)),
|
||||
@ -229,7 +275,7 @@ class GammaAct(nn.Module):
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return self.gamma * self.act_fn(x, inplace=self.inplace)
|
||||
return self.act_fn(x, inplace=self.inplace).mul_(self.gamma)
|
||||
|
||||
|
||||
def act_with_gamma(act_type, gamma: float = 1.):
|
||||
@ -325,8 +371,7 @@ class NormFreeBlock(nn.Module):
|
||||
out = self.drop_path(out)
|
||||
|
||||
if self.skipinit_gain is not None:
|
||||
# this really slows things down for some reason, TBD
|
||||
out = out * self.skipinit_gain
|
||||
out.mul_(self.skipinit_gain) # this slows things down more than expected, TBD
|
||||
out = out * self.alpha + shortcut
|
||||
return out
|
||||
|
||||
@ -419,12 +464,13 @@ class NormFreeNet(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
|
||||
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
|
||||
if cfg.gamma_in_act:
|
||||
act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
|
||||
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True)
|
||||
conv_layer = partial(conv_layer, eps=1e-4) # DM weights better with higher eps
|
||||
else:
|
||||
act_layer = get_act_layer(cfg.act_layer)
|
||||
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
|
||||
conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer])
|
||||
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
|
||||
|
||||
stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
|
||||
@ -538,6 +584,69 @@ def _create_normfreenet(variant, pretrained=False, **kwargs):
|
||||
**kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def dm_nfnet_f0(pretrained=False, **kwargs):
|
||||
""" NFNet-F0 (DeepMind weight compatible)
|
||||
`High-Performance Large-Scale Image Recognition Without Normalization`
|
||||
- https://arxiv.org/abs/2102.06171
|
||||
"""
|
||||
return _create_normfreenet('dm_nfnet_f0', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def dm_nfnet_f1(pretrained=False, **kwargs):
|
||||
""" NFNet-F1 (DeepMind weight compatible)
|
||||
`High-Performance Large-Scale Image Recognition Without Normalization`
|
||||
- https://arxiv.org/abs/2102.06171
|
||||
"""
|
||||
return _create_normfreenet('dm_nfnet_f1', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def dm_nfnet_f2(pretrained=False, **kwargs):
|
||||
""" NFNet-F2 (DeepMind weight compatible)
|
||||
`High-Performance Large-Scale Image Recognition Without Normalization`
|
||||
- https://arxiv.org/abs/2102.06171
|
||||
"""
|
||||
return _create_normfreenet('dm_nfnet_f2', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def dm_nfnet_f3(pretrained=False, **kwargs):
|
||||
""" NFNet-F3 (DeepMind weight compatible)
|
||||
`High-Performance Large-Scale Image Recognition Without Normalization`
|
||||
- https://arxiv.org/abs/2102.06171
|
||||
"""
|
||||
return _create_normfreenet('dm_nfnet_f3', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def dm_nfnet_f4(pretrained=False, **kwargs):
|
||||
""" NFNet-F4 (DeepMind weight compatible)
|
||||
`High-Performance Large-Scale Image Recognition Without Normalization`
|
||||
- https://arxiv.org/abs/2102.06171
|
||||
"""
|
||||
return _create_normfreenet('dm_nfnet_f4', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def dm_nfnet_f5(pretrained=False, **kwargs):
|
||||
""" NFNet-F5 (DeepMind weight compatible)
|
||||
`High-Performance Large-Scale Image Recognition Without Normalization`
|
||||
- https://arxiv.org/abs/2102.06171
|
||||
"""
|
||||
return _create_normfreenet('dm_nfnet_f5', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def dm_nfnet_f6(pretrained=False, **kwargs):
|
||||
""" NFNet-F6 (DeepMind weight compatible)
|
||||
`High-Performance Large-Scale Image Recognition Without Normalization`
|
||||
- https://arxiv.org/abs/2102.06171
|
||||
"""
|
||||
return _create_normfreenet('dm_nfnet_f6', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def nfnet_f0(pretrained=False, **kwargs):
|
||||
""" NFNet-F0
|
||||
|
Loading…
x
Reference in New Issue
Block a user