regnet.py multi-weight conversion, new ImageNet-12k pretrain/ft from timm for y_120 and y_160, also new tv v2, swag, & seer weights for push to HF hub.

pull/1736/head
Ross Wightman 2023-03-21 15:51:49 -07:00
parent c78319adce
commit e7ef8335bf
1 changed file with 460 additions and 177 deletions
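As a quick orientation (a hedged sketch, not part of the commit): once this change lands in a timm release, the new hub-hosted weights named in default_cfgs below would be loaded through the standard create_model API, e.g.:

import timm
import torch

# one of the new ImageNet-12k -> ImageNet-1k fine-tuned RegNetY tags added in this commit
model = timm.create_model('regnety_120.sw_in12k_ft_in1k', pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # train size is 224x224; the test cfg uses 288x288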


@@ -1,16 +1,26 @@
"""RegNet X, Y, Z, and more

Paper: `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678
Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py

Paper: `Fast and Accurate Model Scaling` - https://arxiv.org/abs/2103.06877
Original Impl: None

Based on original PyTorch impl linked above, but rewrote to use my own blocks (adapted from ResNet here)
and cleaned up with more descriptive variable names.

Weights from original pycls impl have been modified:
* first layer from BGR -> RGB as most PyTorch models are
* removed training specific dict entries from checkpoints and kept model state_dict only
* remapped names to match the ones here

Supports weight loading from torchvision and classy-vision (incl VISSL SEER)

A number of custom timm model definition additions, including:
* stochastic depth, gradient checkpointing, layer-decay, configurable dilation
* a pre-activation 'V' variant
* only known RegNet-Z model definitions with pretrained weights

Hacked together by / Copyright 2020 Ross Wightman
"""
import math
@@ -24,10 +34,10 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct
from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq, named_apply
from ._registry import generate_default_cfgs, register_model, register_model_deprecations

__all__ = ['RegNet', 'RegNetCfg']  # model_registry will add each entrypoint fn to this
@@ -41,6 +51,7 @@ class RegNetCfg:
    group_size: int = 24
    bottle_ratio: float = 1.
    se_ratio: float = 0.
    group_min_ratio: float = 0.
    stem_width: int = 32
    downsample: Optional[str] = 'conv1x1'
    linear_out: bool = False
@@ -50,178 +61,79 @@ class RegNetCfg:
    norm_layer: Union[str, Callable] = 'batchnorm'
# Model FLOPS = three trailing digits * 10^8
model_cfgs = dict(
# RegNet-X
regnetx_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13),
regnetx_004=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22),
regnetx_006=RegNetCfg(w0=48, wa=36.97, wm=2.24, group_size=24, depth=16),
regnetx_008=RegNetCfg(w0=56, wa=35.73, wm=2.28, group_size=16, depth=16),
regnetx_016=RegNetCfg(w0=80, wa=34.01, wm=2.25, group_size=24, depth=18),
regnetx_032=RegNetCfg(w0=88, wa=26.31, wm=2.25, group_size=48, depth=25),
regnetx_040=RegNetCfg(w0=96, wa=38.65, wm=2.43, group_size=40, depth=23),
regnetx_064=RegNetCfg(w0=184, wa=60.83, wm=2.07, group_size=56, depth=17),
regnetx_080=RegNetCfg(w0=80, wa=49.56, wm=2.88, group_size=120, depth=23),
regnetx_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19),
regnetx_160=RegNetCfg(w0=216, wa=55.59, wm=2.1, group_size=128, depth=22),
regnetx_320=RegNetCfg(w0=320, wa=69.86, wm=2.0, group_size=168, depth=23),
# RegNet-Y
regnety_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13, se_ratio=0.25),
regnety_004=RegNetCfg(w0=48, wa=27.89, wm=2.09, group_size=8, depth=16, se_ratio=0.25),
regnety_006=RegNetCfg(w0=48, wa=32.54, wm=2.32, group_size=16, depth=15, se_ratio=0.25),
regnety_008=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25),
regnety_016=RegNetCfg(w0=48, wa=20.71, wm=2.65, group_size=24, depth=27, se_ratio=0.25),
regnety_032=RegNetCfg(w0=80, wa=42.63, wm=2.66, group_size=24, depth=21, se_ratio=0.25),
regnety_040=RegNetCfg(w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25),
regnety_064=RegNetCfg(w0=112, wa=33.22, wm=2.27, group_size=72, depth=25, se_ratio=0.25),
regnety_080=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25),
regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25),
regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25),
regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25),
regnety_640=RegNetCfg(w0=352, wa=147.48, wm=2.4, group_size=328, depth=20, se_ratio=0.25),
regnety_1280=RegNetCfg(w0=456, wa=160.83, wm=2.52, group_size=264, depth=27, se_ratio=0.25),
regnety_2560=RegNetCfg(w0=640, wa=124.47, wm=2.04, group_size=848, depth=27, se_ratio=0.25),
# Experimental
regnety_040s_gn=RegNetCfg(
w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25,
act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)),
# regnetv = 'preact regnet y'
regnetv_040=RegNetCfg(
depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, se_ratio=0.25, preact=True, act_layer='silu'),
regnetv_064=RegNetCfg(
depth=25, w0=112, wa=33.22, wm=2.27, group_size=72, se_ratio=0.25, preact=True, act_layer='silu',
downsample='avg'),
# RegNet-Z (unverified)
regnetz_005=RegNetCfg(
depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=1024, act_layer='silu',
),
regnetz_040=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=0, act_layer='silu',
),
regnetz_040h=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=1536, act_layer='silu',
),
)
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = dict(
regnetx_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth'),
regnetx_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth'),
regnetx_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth'),
regnetx_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth'),
regnetx_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth'),
regnetx_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth'),
regnetx_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth'),
regnetx_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth'),
regnetx_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth'),
regnetx_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'),
regnetx_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'),
regnetx_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'),
regnety_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'),
regnety_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'),
regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'),
regnety_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth'),
regnety_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth'),
regnety_032=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_040_ra3-670e1166.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_064=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_064_ra3-aa26dc7d.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_080=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_080_ra3-1fdc4344.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'),
regnety_160=_cfg(
url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_320=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'
),
regnety_640=_cfg(url=''),
regnety_1280=_cfg(url=''),
regnety_2560=_cfg(url=''),
regnety_040s_gn=_cfg(url=''),
regnetv_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pth',
first_conv='stem', crop_pct=1.0, test_input_size=(3, 288, 288)),
regnetv_064=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pth',
first_conv='stem', crop_pct=1.0, test_input_size=(3, 288, 288)),
regnetz_005=_cfg(url=''),
regnetz_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040_ra3-9007edf5.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)),
regnetz_040h=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040h_ra3-f594343b.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)),
)
def quantize_float(f, q):
    """Converts a float to the closest non-zero int divisible by q."""
    return int(round(f / q) * q)


def adjust_widths_groups_comp(widths, bottle_ratios, groups, min_ratio=0.):
    """Adjusts the compatibility of widths and groups."""
    bottleneck_widths = [int(w * b) for w, b in zip(widths, bottle_ratios)]
    groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_widths)]
    if min_ratio:
        # torchvision uses a different rounding scheme for ensuring bottleneck widths divisible by group widths
        bottleneck_widths = [make_divisible(w_bot, g, min_ratio) for w_bot, g in zip(bottleneck_widths, groups)]
    else:
        bottleneck_widths = [quantize_float(w_bot, g) for w_bot, g in zip(bottleneck_widths, groups)]
    widths = [int(w_bot / b) for w_bot, b in zip(bottleneck_widths, bottle_ratios)]
    return widths, groups


def generate_regnet(width_slope, width_initial, width_mult, depth, group_size, quant=8):
    """Generates per block widths from RegNet parameters."""
    assert width_slope >= 0 and width_initial > 0 and width_mult > 1 and width_initial % quant == 0
    # TODO dWr scaling?
    # depth = int(depth * (scale ** 0.1))
    # width_scale = scale ** 0.4  # dWr scale, exp 0.8 / 2, applied to both group and layer widths
    widths_cont = np.arange(depth) * width_slope + width_initial
    width_exps = np.round(np.log(widths_cont / width_initial) / np.log(width_mult))
    widths = np.round(np.divide(width_initial * np.power(width_mult, width_exps), quant)) * quant
    num_stages, max_stage = len(np.unique(widths)), width_exps.max() + 1
    groups = np.array([group_size for _ in range(num_stages)])
    return widths.astype(int).tolist(), num_stages, groups.astype(int).tolist()
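A hedged illustration (not part of the file) of the two rounding paths selected by min_ratio above, using quantize_float defined here and make_divisible imported from timm.layers at the top of the file; the values are arbitrary:

w_bot, group = 104, 48  # a bottleneck width and its group width

# default pycls-style rounding: snap to the nearest non-zero multiple of the group width
pycls_rounded = quantize_float(w_bot, group)

# torchvision-style rounding, taken when cfg.group_min_ratio is non-zero (the new *_tv configs use 0.9);
# this path is what keeps the converted torchvision tv2 checkpoints shape-compatible with the timm models
tv_rounded = make_divisible(w_bot, group, 0.9)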
def downsample_conv(
        in_chs,
        out_chs,
        kernel_size=1,
        stride=1,
        dilation=1,
        norm_layer=None,
        preact=False,
):
    norm_layer = norm_layer or nn.BatchNorm2d
    kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
    dilation = dilation if kernel_size > 1 else 1
    if preact:
        return create_conv2d(
            in_chs,
            out_chs,
            kernel_size,
            stride=stride,
            dilation=dilation,
        )
    else:
        return ConvNormAct(
            in_chs,
            out_chs,
            kernel_size,
            stride=stride,
            dilation=dilation,
            norm_layer=norm_layer,
            apply_act=False,
        )


def downsample_avg(
        in_chs,
        out_chs,
        kernel_size=1,
        stride=1,
        dilation=1,
        norm_layer=None,
        preact=False,
):
    """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
    norm_layer = norm_layer or nn.BatchNorm2d
    avg_stride = stride if dilation == 1 else 1
@@ -290,8 +202,15 @@ class Bottleneck(nn.Module):
        cargs = dict(act_layer=act_layer, norm_layer=norm_layer)
        self.conv1 = ConvNormAct(in_chs, bottleneck_chs, kernel_size=1, **cargs)
        self.conv2 = ConvNormAct(
            bottleneck_chs,
            bottleneck_chs,
            kernel_size=3,
            stride=stride,
            dilation=dilation[0],
            groups=groups,
            drop_layer=drop_block,
            **cargs,
        )
        if se_ratio:
            se_channels = int(round(in_chs * se_ratio))
            self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer)
@@ -299,7 +218,15 @@ class Bottleneck(nn.Module):
            self.se = nn.Identity()
        self.conv3 = ConvNormAct(bottleneck_chs, out_chs, kernel_size=1, apply_act=False, **cargs)
        self.act3 = nn.Identity() if linear_out else act_layer()
        self.downsample = create_shortcut(
            downsample,
            in_chs,
            out_chs,
            kernel_size=1,
            stride=stride,
            dilation=dilation,
            norm_layer=norm_layer,
        )
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def zero_init_last(self):
@@ -351,7 +278,13 @@ class PreBottleneck(nn.Module):
        self.conv1 = create_conv2d(in_chs, bottleneck_chs, kernel_size=1)
        self.norm2 = norm_act_layer(bottleneck_chs)
        self.conv2 = create_conv2d(
            bottleneck_chs,
            bottleneck_chs,
            kernel_size=3,
            stride=stride,
            dilation=dilation[0],
            groups=groups,
        )
        if se_ratio:
            se_channels = int(round(in_chs * se_ratio))
            self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer)
@@ -359,7 +292,15 @@ class PreBottleneck(nn.Module):
            self.se = nn.Identity()
        self.norm3 = norm_act_layer(bottleneck_chs)
        self.conv3 = create_conv2d(bottleneck_chs, out_chs, kernel_size=1)
        self.downsample = create_shortcut(
            downsample,
            in_chs,
            out_chs,
            kernel_size=1,
            stride=stride,
            dilation=dilation,
            preact=True,
        )
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def zero_init_last(self):
@@ -406,7 +347,8 @@ class RegStage(nn.Module):
            dpr = drop_path_rates[i] if drop_path_rates is not None else 0.
            name = "b{}".format(i + 1)
            self.add_module(
                name,
                block_fn(
                    block_in_chs,
                    out_chs,
                    stride=block_stride,
@@ -477,12 +419,23 @@ class RegNet(nn.Module):
        prev_width = stem_width
        curr_stride = 2
        per_stage_args, common_args = self._get_stage_args(
            cfg,
            output_stride=output_stride,
            drop_path_rate=drop_path_rate,
        )
        assert len(per_stage_args) == 4
        block_fn = PreBottleneck if cfg.preact else Bottleneck
        for i, stage_args in enumerate(per_stage_args):
            stage_name = "s{}".format(i + 1)
            self.add_module(
                stage_name,
                RegStage(
                    in_chs=prev_width,
                    block_fn=block_fn,
                    **stage_args,
                    **common_args,
                )
            )
            prev_width = stage_args['out_chs']
            curr_stride *= stage_args['stride']
            self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)]
@@ -496,7 +449,11 @@ class RegNet(nn.Module):
        self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity()
        self.num_features = prev_width
        self.head = ClassifierHead(
            in_features=self.num_features,
            num_classes=num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
        )
        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@@ -523,11 +480,13 @@ class RegNet(nn.Module):
        stage_dpr = np.split(np.linspace(0, drop_path_rate, sum(stage_depths)), np.cumsum(stage_depths[:-1]))
        # Adjust the compatibility of ws and gws
        stage_widths, stage_gs = adjust_widths_groups_comp(
            stage_widths, stage_br, stage_gs, min_ratio=cfg.group_min_ratio)
        arg_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates']
        per_stage_args = [
            dict(zip(arg_names, params)) for params in
            zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_br, stage_gs, stage_dpr)
        ]
        common_args = dict(
            downsample=cfg.downsample,
            se_ratio=cfg.se_ratio,
@@ -554,7 +513,7 @@ class RegNet(nn.Module):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_features(self, x):
        x = self.stem(x)
@@ -590,7 +549,25 @@ def _init_weights(module, name='', zero_init_last=False):
def _filter_fn(state_dict):
    state_dict = state_dict.get('model', state_dict)
    replaces = [
        ('f.a.0', 'conv1.conv'),
        ('f.a.1', 'conv1.bn'),
        ('f.b.0', 'conv2.conv'),
        ('f.b.1', 'conv2.bn'),
        ('f.final_bn', 'conv3.bn'),
        ('f.se.excitation.0', 'se.fc1'),
        ('f.se.excitation.2', 'se.fc2'),
        ('f.se', 'se'),
        ('f.c.0', 'conv3.conv'),
        ('f.c.1', 'conv3.bn'),
        ('f.c', 'conv3.conv'),
        ('proj.0', 'downsample.conv'),
        ('proj.1', 'downsample.bn'),
        ('proj', 'downsample.conv'),
    ]
    if 'classy_state_dict' in state_dict:
        # classy-vision & vissl (SEER) weights
        import re
        state_dict = state_dict['classy_state_dict']['base_model']['model']
        out = {}
@@ -601,15 +578,8 @@ def _filter_fn(state_dict):
                r'^_feature_blocks.res\d.block(\d)-(\d+)',
                lambda x: f's{int(x.group(1))}.b{int(x.group(2)) + 1}', k)
            k = re.sub(r's(\d)\.b(\d+)\.bn', r's\1.b\2.downsample.bn', k)
            for s, r in replaces:
                k = k.replace(s, r)
            out[k] = v
        for k, v in state_dict['heads'].items():
            if 'projection_head' in k or 'prototypes' in k:
@@ -617,13 +587,89 @@ def _filter_fn(state_dict):
            k = k.replace('0.clf.0', 'head.fc')
            out[k] = v
        return out
    if 'stem.0.weight' in state_dict:
        # torchvision weights
        import re
        out = {}
        for k, v in state_dict.items():
            k = k.replace('stem.0', 'stem.conv')
            k = k.replace('stem.1', 'stem.bn')
            k = re.sub(
                r'trunk_output.block(\d)\.block(\d+)\-(\d+)',
                lambda x: f's{int(x.group(1))}.b{int(x.group(3)) + 1}', k)
            for s, r in replaces:
                k = k.replace(s, r)
            k = k.replace('fc.', 'head.fc.')
            out[k] = v
        return out
    return state_dict
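A hedged usage sketch for the torchvision branch above (assumes torchvision >= 0.13 so the IMAGENET1K_V2 weights are available; within timm this filter is presumably wired into _create_regnet via build_model_with_cfg when a tv2 checkpoint is loaded):

import torchvision

tv_sd = torchvision.models.regnet_y_8gf(weights='IMAGENET1K_V2').state_dict()
timm_sd = _filter_fn(tv_sd)
# e.g. 'stem.0.weight'                              -> 'stem.conv.weight'
#      'trunk_output.block1.block1-0.f.b.0.weight'  -> 's1.b1.conv2.conv.weight'
#      'fc.weight'                                  -> 'head.fc.weight'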
# Model FLOPS = three trailing digits * 10^8
model_cfgs = dict(
# RegNet-X
regnetx_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13),
regnetx_004=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22),
regnetx_004_tv=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22, group_min_ratio=0.9),
regnetx_006=RegNetCfg(w0=48, wa=36.97, wm=2.24, group_size=24, depth=16),
regnetx_008=RegNetCfg(w0=56, wa=35.73, wm=2.28, group_size=16, depth=16),
regnetx_016=RegNetCfg(w0=80, wa=34.01, wm=2.25, group_size=24, depth=18),
regnetx_032=RegNetCfg(w0=88, wa=26.31, wm=2.25, group_size=48, depth=25),
regnetx_040=RegNetCfg(w0=96, wa=38.65, wm=2.43, group_size=40, depth=23),
regnetx_064=RegNetCfg(w0=184, wa=60.83, wm=2.07, group_size=56, depth=17),
regnetx_080=RegNetCfg(w0=80, wa=49.56, wm=2.88, group_size=120, depth=23),
regnetx_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19),
regnetx_160=RegNetCfg(w0=216, wa=55.59, wm=2.1, group_size=128, depth=22),
regnetx_320=RegNetCfg(w0=320, wa=69.86, wm=2.0, group_size=168, depth=23),
# RegNet-Y
regnety_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13, se_ratio=0.25),
regnety_004=RegNetCfg(w0=48, wa=27.89, wm=2.09, group_size=8, depth=16, se_ratio=0.25),
regnety_006=RegNetCfg(w0=48, wa=32.54, wm=2.32, group_size=16, depth=15, se_ratio=0.25),
regnety_008=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25),
regnety_008_tv=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25, group_min_ratio=0.9),
regnety_016=RegNetCfg(w0=48, wa=20.71, wm=2.65, group_size=24, depth=27, se_ratio=0.25),
regnety_032=RegNetCfg(w0=80, wa=42.63, wm=2.66, group_size=24, depth=21, se_ratio=0.25),
regnety_040=RegNetCfg(w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25),
regnety_064=RegNetCfg(w0=112, wa=33.22, wm=2.27, group_size=72, depth=25, se_ratio=0.25),
regnety_080=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25),
regnety_080_tv=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25, group_min_ratio=0.9),
regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25),
regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25),
regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25),
regnety_640=RegNetCfg(w0=352, wa=147.48, wm=2.4, group_size=328, depth=20, se_ratio=0.25),
regnety_1280=RegNetCfg(w0=456, wa=160.83, wm=2.52, group_size=264, depth=27, se_ratio=0.25),
regnety_2560=RegNetCfg(w0=640, wa=230.83, wm=2.53, group_size=373, depth=27, se_ratio=0.25),
#regnety_2560=RegNetCfg(w0=640, wa=124.47, wm=2.04, group_size=848, depth=27, se_ratio=0.25),
# Experimental
regnety_040_sgn=RegNetCfg(
w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25,
act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)),
# regnetv = 'preact regnet y'
regnetv_040=RegNetCfg(
depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, se_ratio=0.25, preact=True, act_layer='silu'),
regnetv_064=RegNetCfg(
depth=25, w0=112, wa=33.22, wm=2.27, group_size=72, se_ratio=0.25, preact=True, act_layer='silu',
downsample='avg'),
# RegNet-Z (unverified)
regnetz_005=RegNetCfg(
depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=1024, act_layer='silu',
),
regnetz_040=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=0, act_layer='silu',
),
regnetz_040_h=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=1536, act_layer='silu',
),
)
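To make the FLOPS naming comment above concrete: regnety_032 encodes 032 * 10^8 ≈ 3.2 GFLOPs, i.e. RegNetY-3.2GF (torchvision's regnet_y_3_2gf), while regnety_002 is the 200MF model.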
def _create_regnet(variant, pretrained, **kwargs):
    return build_model_with_cfg(
        RegNet, variant, pretrained,
@@ -632,6 +678,220 @@ def _create_regnet(variant, pretrained, **kwargs):
        **kwargs)
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'test_input_size': (3, 288, 288), 'crop_pct': 0.95, 'test_crop_pct': 1.0,
'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
**kwargs
}
def _cfgpyc(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
'license': 'mit', 'origin_url': 'https://github.com/facebookresearch/pycls', **kwargs
}
def _cfgtv2(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.965, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
'license': 'bsd-3-clause', 'origin_url': 'https://github.com/pytorch/vision', **kwargs
}
default_cfgs = generate_default_cfgs({
# timm trained models
'regnety_032.ra_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'),
'regnety_040.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_040_ra3-670e1166.pth'),
'regnety_064.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_064_ra3-aa26dc7d.pth'),
'regnety_080.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_080_ra3-1fdc4344.pth'),
'regnety_120.sw_in12k_ft_in1k': _cfg(hf_hub_id='timm/'),
'regnety_160.sw_in12k_ft_in1k': _cfg(hf_hub_id='timm/'),
'regnety_160.lion_in12k_ft_in1k': _cfg(hf_hub_id='timm/'),
# timm in12k pretrain
'regnety_120.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
'regnety_160.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
# timm custom arch (v and z guess) + trained models
'regnety_040_sgn.untrained': _cfg(url=''),
'regnetv_040.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pth',
first_conv='stem'),
'regnetv_064.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pth',
first_conv='stem'),
'regnetz_005.untrained': _cfg(url=''),
'regnetz_040.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040_ra3-9007edf5.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)),
'regnetz_040_h.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040h_ra3-f594343b.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)),
# used in DeiT for distillation (from Facebook DeiT GitHub repository)
'regnety_160.deit_in1k': _cfg(
hf_hub_id='timm/', url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'),
'regnetx_004_tv.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth'),
'regnetx_008.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth'),
'regnetx_016.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth'),
'regnetx_032.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth'),
'regnetx_080.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth'),
'regnetx_160.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth'),
'regnetx_320.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth'),
'regnety_004.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth'),
'regnety_008_tv.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth'),
'regnety_016.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth'),
'regnety_032.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth'),
'regnety_080_tv.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth'),
'regnety_160.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth'),
'regnety_320.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth'),
'regnety_160.swag_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth', license='cc-by-nc-4.0',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_320.swag_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth', license='cc-by-nc-4.0',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_1280.swag_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth', license='cc-by-nc-4.0',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_160.swag_lc_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth', license='cc-by-nc-4.0'),
'regnety_320.swag_lc_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth', license='cc-by-nc-4.0'),
'regnety_1280.swag_lc_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth', license='cc-by-nc-4.0'),
'regnety_320.seer_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
license='other', origin_url='https://github.com/facebookresearch/vissl',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet32_finetuned_in1k_model_final_checkpoint_phase78.torch',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_640.seer_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
license='other', origin_url='https://github.com/facebookresearch/vissl',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet64_finetuned_in1k_model_final_checkpoint_phase78.torch',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_1280.seer_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
license='other', origin_url='https://github.com/facebookresearch/vissl',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet128_finetuned_in1k_model_final_checkpoint_phase78.torch',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_2560.seer_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
license='other', origin_url='https://github.com/facebookresearch/vissl',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet256_finetuned_in1k_model_final_checkpoint_phase38.torch',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_320.seer': _cfgtv2(
hf_hub_id='timm/',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet32d/seer_regnet32gf_model_iteration244000.torch',
num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'),
'regnety_640.seer': _cfgtv2(
hf_hub_id='timm/',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet64/seer_regnet64gf_model_final_checkpoint_phase0.torch',
num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'),
'regnety_1280.seer': _cfgtv2(
hf_hub_id='timm/',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_regnet128Gf_cnstant_bs32_node16_sinkhorn10_proto16k_syncBN64_warmup8k/model_final_checkpoint_phase0.torch',
num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'),
# FIXME invalid weight <-> model match, mistake on their end
#'regnety_2560.seer': _cfgtv2(
# url='https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_cosine_rg256gf_noBNhead_wd1e5_fairstore_bs16_node64_sinkhorn10_proto16k_apex_syncBN64_warmup8k/model_final_checkpoint_phase0.torch',
# num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'),
'regnetx_002.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_004.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_006.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_008.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_016.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_032.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_040.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_064.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_080.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_120.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_160.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_320.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_002.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_004.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_006.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_008.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_016.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_032.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_040.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_064.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_080.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_120.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_160.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_320.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
})
@register_model
def regnetx_002(pretrained=False, **kwargs):
    """RegNetX-200MF"""
@@ -644,6 +904,12 @@ def regnetx_004(pretrained=False, **kwargs):
    return _create_regnet('regnetx_004', pretrained, **kwargs)
@register_model
def regnetx_004_tv(pretrained=False, **kwargs):
"""RegNetX-400MF w/ torchvision group rounding"""
return _create_regnet('regnetx_004_tv', pretrained, **kwargs)
@register_model
def regnetx_006(pretrained=False, **kwargs):
    """RegNetX-600MF"""
@@ -728,6 +994,12 @@ def regnety_008(pretrained=False, **kwargs):
    return _create_regnet('regnety_008', pretrained, **kwargs)
@register_model
def regnety_008_tv(pretrained=False, **kwargs):
"""RegNetY-800MF w/ torchvision group rounding"""
return _create_regnet('regnety_008_tv', pretrained, **kwargs)
@register_model
def regnety_016(pretrained=False, **kwargs):
    """RegNetY-1.6GF"""
@@ -758,6 +1030,12 @@ def regnety_080(pretrained=False, **kwargs):
    return _create_regnet('regnety_080', pretrained, **kwargs)
@register_model
def regnety_080_tv(pretrained=False, **kwargs):
"""RegNetY-8.0GF w/ torchvision group rounding"""
return _create_regnet('regnety_080_tv', pretrained, **kwargs)
@register_model
def regnety_120(pretrained=False, **kwargs):
    """RegNetY-12GF"""
@@ -795,20 +1073,20 @@ def regnety_2560(pretrained=False, **kwargs):
@register_model
def regnety_040_sgn(pretrained=False, **kwargs):
    """RegNetY-4.0GF w/ GroupNorm"""
    return _create_regnet('regnety_040_sgn', pretrained, **kwargs)


@register_model
def regnetv_040(pretrained=False, **kwargs):
    """RegNetV-4.0GF (pre-activation)"""
    return _create_regnet('regnetv_040', pretrained, **kwargs)


@register_model
def regnetv_064(pretrained=False, **kwargs):
    """RegNetV-6.4GF (pre-activation)"""
    return _create_regnet('regnetv_064', pretrained, **kwargs)
@@ -831,9 +1109,14 @@ def regnetz_040(pretrained=False, **kwargs):
@register_model
def regnetz_040_h(pretrained=False, **kwargs):
    """RegNetZ-4.0GF
    NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
    but it's not clear it is equivalent to paper model as not detailed in the paper.
    """
    return _create_regnet('regnetz_040_h', pretrained, zero_init_last=False, **kwargs)
register_model_deprecations(__name__, {
'regnetz_040h': 'regnetz_040_h',
})
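Since register_model_deprecations keeps the old entrypoint names alive, a hedged sketch of the expected behavior (an assumption based on how timm handles deprecated names, not shown in this diff):

import timm

model = timm.create_model('regnetz_040h', pretrained=False)  # resolves to 'regnetz_040_h' with a deprecation warning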