regnet.py multi-weight conversion; new ImageNet-12k pretrain/fine-tune weights from timm for y_120 and y_160; also new torchvision v2, SWAG, and SEER weights for push to the HF hub.

pull/1736/head
Ross Wightman 2023-03-21 15:51:49 -07:00
parent c78319adce
commit e7ef8335bf
1 changed file with 460 additions and 177 deletions


@ -1,16 +1,26 @@
"""RegNet
"""RegNet X, Y, Z, and more
Paper: `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678
Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
Paper: `Fast and Accurate Model Scaling` - https://arxiv.org/abs/2103.06877
Original Impl: None
Based on original PyTorch impl linked above, but re-wrote to use my own blocks (adapted from ResNet here)
and cleaned up with more descriptive variable names.
Weights from original impl have been modified
Weights from original pycls impl have been modified:
* first layer from BGR -> RGB as most PyTorch models are
* removed training specific dict entries from checkpoints and keep model state_dict only
* remap names to match the ones here
Supports weight loading from torchvision and classy-vision (incl VISSL SEER)
A number of custom timm model definitions additions including:
* stochastic depth, gradient checkpointing, layer-decay, configurable dilation
* a pre-activation 'V' variant
* only known RegNet-Z model definitions with pretrained weights
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
@ -24,10 +34,10 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct
from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq, named_apply
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['RegNet', 'RegNetCfg'] # model_registry will add each entrypoint fn to this
@ -41,6 +51,7 @@ class RegNetCfg:
group_size: int = 24
bottle_ratio: float = 1.
se_ratio: float = 0.
group_min_ratio: float = 0.
stem_width: int = 32
downsample: Optional[str] = 'conv1x1'
linear_out: bool = False
@ -50,178 +61,79 @@ class RegNetCfg:
norm_layer: Union[str, Callable] = 'batchnorm'
# Model FLOPS = three trailing digits * 10^8
model_cfgs = dict(
# RegNet-X
regnetx_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13),
regnetx_004=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22),
regnetx_006=RegNetCfg(w0=48, wa=36.97, wm=2.24, group_size=24, depth=16),
regnetx_008=RegNetCfg(w0=56, wa=35.73, wm=2.28, group_size=16, depth=16),
regnetx_016=RegNetCfg(w0=80, wa=34.01, wm=2.25, group_size=24, depth=18),
regnetx_032=RegNetCfg(w0=88, wa=26.31, wm=2.25, group_size=48, depth=25),
regnetx_040=RegNetCfg(w0=96, wa=38.65, wm=2.43, group_size=40, depth=23),
regnetx_064=RegNetCfg(w0=184, wa=60.83, wm=2.07, group_size=56, depth=17),
regnetx_080=RegNetCfg(w0=80, wa=49.56, wm=2.88, group_size=120, depth=23),
regnetx_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19),
regnetx_160=RegNetCfg(w0=216, wa=55.59, wm=2.1, group_size=128, depth=22),
regnetx_320=RegNetCfg(w0=320, wa=69.86, wm=2.0, group_size=168, depth=23),
# RegNet-Y
regnety_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13, se_ratio=0.25),
regnety_004=RegNetCfg(w0=48, wa=27.89, wm=2.09, group_size=8, depth=16, se_ratio=0.25),
regnety_006=RegNetCfg(w0=48, wa=32.54, wm=2.32, group_size=16, depth=15, se_ratio=0.25),
regnety_008=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25),
regnety_016=RegNetCfg(w0=48, wa=20.71, wm=2.65, group_size=24, depth=27, se_ratio=0.25),
regnety_032=RegNetCfg(w0=80, wa=42.63, wm=2.66, group_size=24, depth=21, se_ratio=0.25),
regnety_040=RegNetCfg(w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25),
regnety_064=RegNetCfg(w0=112, wa=33.22, wm=2.27, group_size=72, depth=25, se_ratio=0.25),
regnety_080=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25),
regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25),
regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25),
regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25),
regnety_640=RegNetCfg(w0=352, wa=147.48, wm=2.4, group_size=328, depth=20, se_ratio=0.25),
regnety_1280=RegNetCfg(w0=456, wa=160.83, wm=2.52, group_size=264, depth=27, se_ratio=0.25),
regnety_2560=RegNetCfg(w0=640, wa=124.47, wm=2.04, group_size=848, depth=27, se_ratio=0.25),
# Experimental
regnety_040s_gn=RegNetCfg(
w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25,
act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)),
# regnetv = 'preact regnet y'
regnetv_040=RegNetCfg(
depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, se_ratio=0.25, preact=True, act_layer='silu'),
regnetv_064=RegNetCfg(
depth=25, w0=112, wa=33.22, wm=2.27, group_size=72, se_ratio=0.25, preact=True, act_layer='silu',
downsample='avg'),
# RegNet-Z (unverified)
regnetz_005=RegNetCfg(
depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=1024, act_layer='silu',
),
regnetz_040=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=0, act_layer='silu',
),
regnetz_040h=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=1536, act_layer='silu',
),
)
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = dict(
regnetx_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth'),
regnetx_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth'),
regnetx_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth'),
regnetx_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth'),
regnetx_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth'),
regnetx_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth'),
regnetx_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth'),
regnetx_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth'),
regnetx_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth'),
regnetx_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'),
regnetx_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'),
regnetx_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'),
regnety_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'),
regnety_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'),
regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'),
regnety_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth'),
regnety_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth'),
regnety_032=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_040_ra3-670e1166.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_064=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_064_ra3-aa26dc7d.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_080=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_080_ra3-1fdc4344.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'),
regnety_160=_cfg(
url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_320=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'
),
regnety_640=_cfg(url=''),
regnety_1280=_cfg(url=''),
regnety_2560=_cfg(url=''),
regnety_040s_gn=_cfg(url=''),
regnetv_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pth',
first_conv='stem', crop_pct=1.0, test_input_size=(3, 288, 288)),
regnetv_064=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pth',
first_conv='stem', crop_pct=1.0, test_input_size=(3, 288, 288)),
regnetz_005=_cfg(url=''),
regnetz_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040_ra3-9007edf5.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)),
regnetz_040h=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040h_ra3-f594343b.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)),
)
def quantize_float(f, q):
"""Converts a float to closest non-zero int divisible by q."""
"""Converts a float to the closest non-zero int divisible by q."""
return int(round(f / q) * q)
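# Worked examples (illustrative): quantize_float(26, 8) = int(round(3.25) * 8) = 24;
# quantize_float(30, 8) = int(round(3.75) * 8) = 32.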
def adjust_widths_groups_comp(widths, bottle_ratios, groups, min_ratio=0.):
"""Adjusts the compatibility of widths and groups."""
bottleneck_widths = [int(w * b) for w, b in zip(widths, bottle_ratios)]
groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_widths)]
bottleneck_widths = [quantize_float(w_bot, g) for w_bot, g in zip(bottleneck_widths, groups)]
if min_ratio:
# torchvision uses a different rounding scheme to ensure bottleneck widths are divisible by group widths
bottleneck_widths = [make_divisible(w_bot, g, min_ratio) for w_bot, g in zip(bottleneck_widths, groups)]
else:
bottleneck_widths = [quantize_float(w_bot, g) for w_bot, g in zip(bottleneck_widths, groups)]
widths = [int(w_bot / b) for w_bot, b in zip(bottleneck_widths, bottle_ratios)]
return widths, groups
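# Worked example (illustrative): widths=[64], bottle_ratios=[1.0], groups=[24] ->
# bottleneck width 64, group capped at min(24, 64) = 24, 64 quantized to the nearest
# multiple of 24 gives 72, so the adjusted result is widths=[72], groups=[24].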
def generate_regnet(width_slope, width_initial, width_mult, depth, group_size, quant=8):
"""Generates per block widths from RegNet parameters."""
assert width_slope >= 0 and width_initial > 0 and width_mult > 1 and width_initial % quant == 0
# TODO dWr scaling?
# depth = int(depth * (scale ** 0.1))
# width_scale = scale ** 0.4 # dWr scale, exp 0.8 / 2, applied to both group and layer widths
widths_cont = np.arange(depth) * width_slope + width_initial
width_exps = np.round(np.log(widths_cont / width_initial) / np.log(width_mult))
widths = np.round(np.divide(width_initial * np.power(width_mult, width_exps), quant)) * quant
num_stages, max_stage = len(np.unique(widths)), width_exps.max() + 1
groups = np.array([group_size for _ in range(num_stages)])
return widths.astype(int).tolist(), num_stages, groups.astype(int).tolist()
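# Toy example (illustrative, not a real model config): generate_regnet(8, 24, 2.0, 4, 8)
# gives widths_cont = [24, 32, 40, 48], which quantize to per-block widths [24, 24, 48, 48],
# i.e. 2 unique stage widths, so it returns ([24, 24, 48, 48], 2, [8, 8]).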
def downsample_conv(
in_chs,
out_chs,
kernel_size=1,
stride=1,
dilation=1,
norm_layer=None,
preact=False,
):
norm_layer = norm_layer or nn.BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
dilation = dilation if kernel_size > 1 else 1
if preact:
return create_conv2d(
in_chs,
out_chs,
kernel_size,
stride=stride,
dilation=dilation,
)
else:
return ConvNormAct(
in_chs,
out_chs,
kernel_size,
stride=stride,
dilation=dilation,
norm_layer=norm_layer,
apply_act=False,
)
def downsample_avg(
in_chs,
out_chs,
kernel_size=1,
stride=1,
dilation=1,
norm_layer=None,
preact=False,
):
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
norm_layer = norm_layer or nn.BatchNorm2d
avg_stride = stride if dilation == 1 else 1
@ -290,8 +202,15 @@ class Bottleneck(nn.Module):
cargs = dict(act_layer=act_layer, norm_layer=norm_layer)
self.conv1 = ConvNormAct(in_chs, bottleneck_chs, kernel_size=1, **cargs)
self.conv2 = ConvNormAct(
bottleneck_chs,
bottleneck_chs,
kernel_size=3,
stride=stride,
dilation=dilation[0],
groups=groups,
drop_layer=drop_block,
**cargs,
)
if se_ratio:
se_channels = int(round(in_chs * se_ratio))
self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer)
@ -299,7 +218,15 @@ class Bottleneck(nn.Module):
self.se = nn.Identity()
self.conv3 = ConvNormAct(bottleneck_chs, out_chs, kernel_size=1, apply_act=False, **cargs)
self.act3 = nn.Identity() if linear_out else act_layer()
self.downsample = create_shortcut(
downsample,
in_chs,
out_chs,
kernel_size=1,
stride=stride,
dilation=dilation,
norm_layer=norm_layer,
)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
def zero_init_last(self):
@ -351,7 +278,13 @@ class PreBottleneck(nn.Module):
self.conv1 = create_conv2d(in_chs, bottleneck_chs, kernel_size=1)
self.norm2 = norm_act_layer(bottleneck_chs)
self.conv2 = create_conv2d(
bottleneck_chs,
bottleneck_chs,
kernel_size=3,
stride=stride,
dilation=dilation[0],
groups=groups,
)
if se_ratio:
se_channels = int(round(in_chs * se_ratio))
self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer)
@ -359,7 +292,15 @@ class PreBottleneck(nn.Module):
self.se = nn.Identity()
self.norm3 = norm_act_layer(bottleneck_chs)
self.conv3 = create_conv2d(bottleneck_chs, out_chs, kernel_size=1)
self.downsample = create_shortcut(
downsample,
in_chs,
out_chs,
kernel_size=1,
stride=stride,
dilation=dilation,
preact=True,
)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
def zero_init_last(self):
@ -406,7 +347,8 @@ class RegStage(nn.Module):
dpr = drop_path_rates[i] if drop_path_rates is not None else 0.
name = "b{}".format(i + 1)
self.add_module(
name,
block_fn(
block_in_chs,
out_chs,
stride=block_stride,
@ -477,12 +419,23 @@ class RegNet(nn.Module):
prev_width = stem_width
curr_stride = 2
per_stage_args, common_args = self._get_stage_args(
cfg,
output_stride=output_stride,
drop_path_rate=drop_path_rate,
)
assert len(per_stage_args) == 4
block_fn = PreBottleneck if cfg.preact else Bottleneck
for i, stage_args in enumerate(per_stage_args):
stage_name = "s{}".format(i + 1)
self.add_module(
stage_name,
RegStage(
in_chs=prev_width,
block_fn=block_fn,
**stage_args,
**common_args,
)
)
prev_width = stage_args['out_chs']
curr_stride *= stage_args['stride']
self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)]
@ -496,7 +449,11 @@ class RegNet(nn.Module):
self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity()
self.num_features = prev_width
self.head = ClassifierHead(
in_features=self.num_features,
num_classes=num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
)
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@ -523,11 +480,13 @@ class RegNet(nn.Module):
stage_dpr = np.split(np.linspace(0, drop_path_rate, sum(stage_depths)), np.cumsum(stage_depths[:-1]))
# Adjust the compatibility of ws and gws
stage_widths, stage_gs = adjust_widths_groups_comp(
stage_widths, stage_br, stage_gs, min_ratio=cfg.group_min_ratio)
arg_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates']
per_stage_args = [
dict(zip(arg_names, params)) for params in
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_br, stage_gs, stage_dpr)
]
common_args = dict(
downsample=cfg.downsample,
se_ratio=cfg.se_ratio,
@ -554,7 +513,7 @@ class RegNet(nn.Module):
return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.stem(x)
@ -590,7 +549,25 @@ def _init_weights(module, name='', zero_init_last=False):
def _filter_fn(state_dict):
state_dict = state_dict.get('model', state_dict)
replaces = [
('f.a.0', 'conv1.conv'),
('f.a.1', 'conv1.bn'),
('f.b.0', 'conv2.conv'),
('f.b.1', 'conv2.bn'),
('f.final_bn', 'conv3.bn'),
('f.se.excitation.0', 'se.fc1'),
('f.se.excitation.2', 'se.fc2'),
('f.se', 'se'),
('f.c.0', 'conv3.conv'),
('f.c.1', 'conv3.bn'),
('f.c', 'conv3.conv'),
('proj.0', 'downsample.conv'),
('proj.1', 'downsample.bn'),
('proj', 'downsample.conv'),
]
if 'classy_state_dict' in state_dict:
# classy-vision & vissl (SEER) weights
import re
state_dict = state_dict['classy_state_dict']['base_model']['model']
out = {}
@ -601,15 +578,8 @@ def _filter_fn(state_dict):
r'^_feature_blocks.res\d.block(\d)-(\d+)',
lambda x: f's{int(x.group(1))}.b{int(x.group(2)) + 1}', k)
k = re.sub(r's(\d)\.b(\d+)\.bn', r's\1.b\2.downsample.bn', k)
for s, r in replaces:
k = k.replace(s, r)
out[k] = v
for k, v in state_dict['heads'].items():
if 'projection_head' in k or 'prototypes' in k:
@ -617,13 +587,89 @@ def _filter_fn(state_dict):
k = k.replace('0.clf.0', 'head.fc')
out[k] = v
return out
if 'model' in state_dict:
# For the DeiT-trained regnety_160 pretrained model
state_dict = state_dict['model']
if 'stem.0.weight' in state_dict:
# torchvision weights
import re
out = {}
for k, v in state_dict.items():
k = k.replace('stem.0', 'stem.conv')
k = k.replace('stem.1', 'stem.bn')
k = re.sub(
r'trunk_output.block(\d)\.block(\d+)\-(\d+)',
lambda x: f's{int(x.group(1))}.b{int(x.group(3)) + 1}', k)
for s, r in replaces:
k = k.replace(s, r)
k = k.replace('fc.', 'head.fc.')
out[k] = v
return out
return state_dict
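# Examples of the torchvision key remapping above (illustrative):
#   'stem.0.weight'                             -> 'stem.conv.weight'
#   'trunk_output.block1.block1-0.f.a.0.weight' -> 's1.b1.conv1.conv.weight'
#   'fc.weight'                                 -> 'head.fc.weight'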
# Model FLOPS = three trailing digits * 10^8
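# e.g. regnetx_002 ~ 0.2 GFLOPs (200MF), regnety_032 ~ 3.2 GFLOPs, regnety_320 ~ 32 GFLOPs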
model_cfgs = dict(
# RegNet-X
regnetx_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13),
regnetx_004=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22),
regnetx_004_tv=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22, group_min_ratio=0.9),
regnetx_006=RegNetCfg(w0=48, wa=36.97, wm=2.24, group_size=24, depth=16),
regnetx_008=RegNetCfg(w0=56, wa=35.73, wm=2.28, group_size=16, depth=16),
regnetx_016=RegNetCfg(w0=80, wa=34.01, wm=2.25, group_size=24, depth=18),
regnetx_032=RegNetCfg(w0=88, wa=26.31, wm=2.25, group_size=48, depth=25),
regnetx_040=RegNetCfg(w0=96, wa=38.65, wm=2.43, group_size=40, depth=23),
regnetx_064=RegNetCfg(w0=184, wa=60.83, wm=2.07, group_size=56, depth=17),
regnetx_080=RegNetCfg(w0=80, wa=49.56, wm=2.88, group_size=120, depth=23),
regnetx_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19),
regnetx_160=RegNetCfg(w0=216, wa=55.59, wm=2.1, group_size=128, depth=22),
regnetx_320=RegNetCfg(w0=320, wa=69.86, wm=2.0, group_size=168, depth=23),
# RegNet-Y
regnety_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13, se_ratio=0.25),
regnety_004=RegNetCfg(w0=48, wa=27.89, wm=2.09, group_size=8, depth=16, se_ratio=0.25),
regnety_006=RegNetCfg(w0=48, wa=32.54, wm=2.32, group_size=16, depth=15, se_ratio=0.25),
regnety_008=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25),
regnety_008_tv=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25, group_min_ratio=0.9),
regnety_016=RegNetCfg(w0=48, wa=20.71, wm=2.65, group_size=24, depth=27, se_ratio=0.25),
regnety_032=RegNetCfg(w0=80, wa=42.63, wm=2.66, group_size=24, depth=21, se_ratio=0.25),
regnety_040=RegNetCfg(w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25),
regnety_064=RegNetCfg(w0=112, wa=33.22, wm=2.27, group_size=72, depth=25, se_ratio=0.25),
regnety_080=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25),
regnety_080_tv=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25, group_min_ratio=0.9),
regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25),
regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25),
regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25),
regnety_640=RegNetCfg(w0=352, wa=147.48, wm=2.4, group_size=328, depth=20, se_ratio=0.25),
regnety_1280=RegNetCfg(w0=456, wa=160.83, wm=2.52, group_size=264, depth=27, se_ratio=0.25),
regnety_2560=RegNetCfg(w0=640, wa=230.83, wm=2.53, group_size=373, depth=27, se_ratio=0.25),
#regnety_2560=RegNetCfg(w0=640, wa=124.47, wm=2.04, group_size=848, depth=27, se_ratio=0.25),
# Experimental
regnety_040_sgn=RegNetCfg(
w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25,
act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)),
# regnetv = 'preact regnet y'
regnetv_040=RegNetCfg(
depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, se_ratio=0.25, preact=True, act_layer='silu'),
regnetv_064=RegNetCfg(
depth=25, w0=112, wa=33.22, wm=2.27, group_size=72, se_ratio=0.25, preact=True, act_layer='silu',
downsample='avg'),
# RegNet-Z (unverified)
regnetz_005=RegNetCfg(
depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=1024, act_layer='silu',
),
regnetz_040=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=0, act_layer='silu',
),
regnetz_040_h=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=1536, act_layer='silu',
),
)
def _create_regnet(variant, pretrained, **kwargs):
return build_model_with_cfg(
RegNet, variant, pretrained,
@ -632,6 +678,220 @@ def _create_regnet(variant, pretrained, **kwargs):
**kwargs)
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'test_input_size': (3, 288, 288), 'crop_pct': 0.95, 'test_crop_pct': 1.0,
'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
**kwargs
}
def _cfgpyc(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
'license': 'mit', 'origin_url': 'https://github.com/facebookresearch/pycls', **kwargs
}
def _cfgtv2(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.965, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
'license': 'bsd-3-clause', 'origin_url': 'https://github.com/pytorch/vision', **kwargs
}
default_cfgs = generate_default_cfgs({
# timm trained models
'regnety_032.ra_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'),
'regnety_040.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_040_ra3-670e1166.pth'),
'regnety_064.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_064_ra3-aa26dc7d.pth'),
'regnety_080.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_080_ra3-1fdc4344.pth'),
'regnety_120.sw_in12k_ft_in1k': _cfg(hf_hub_id='timm/'),
'regnety_160.sw_in12k_ft_in1k': _cfg(hf_hub_id='timm/'),
'regnety_160.lion_in12k_ft_in1k': _cfg(hf_hub_id='timm/'),
# timm in12k pretrain
'regnety_120.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
'regnety_160.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
# timm custom arch (v and z guess) + trained models
'regnety_040_sgn.untrained': _cfg(url=''),
'regnetv_040.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pth',
first_conv='stem'),
'regnetv_064.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pth',
first_conv='stem'),
'regnetz_005.untrained': _cfg(url=''),
'regnetz_040.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040_ra3-9007edf5.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)),
'regnetz_040_h.ra3_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040h_ra3-f594343b.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)),
# used in DeiT for distillation (from Facebook DeiT GitHub repository)
'regnety_160.deit_in1k': _cfg(
hf_hub_id='timm/', url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'),
'regnetx_004_tv.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth'),
'regnetx_008.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth'),
'regnetx_016.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth'),
'regnetx_032.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth'),
'regnetx_080.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth'),
'regnetx_160.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth'),
'regnetx_320.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth'),
'regnety_004.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth'),
'regnety_008_tv.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth'),
'regnety_016.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth'),
'regnety_032.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth'),
'regnety_080_tv.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth'),
'regnety_160.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth'),
'regnety_320.tv2_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth'),
'regnety_160.swag_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth', license='cc-by-nc-4.0',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_320.swag_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth', license='cc-by-nc-4.0',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_1280.swag_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth', license='cc-by-nc-4.0',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_160.swag_lc_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth', license='cc-by-nc-4.0'),
'regnety_320.swag_lc_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth', license='cc-by-nc-4.0'),
'regnety_1280.swag_lc_in1k': _cfgtv2(
hf_hub_id='timm/',
url='https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth', license='cc-by-nc-4.0'),
'regnety_320.seer_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
license='other', origin_url='https://github.com/facebookresearch/vissl',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet32_finetuned_in1k_model_final_checkpoint_phase78.torch',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_640.seer_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
license='other', origin_url='https://github.com/facebookresearch/vissl',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet64_finetuned_in1k_model_final_checkpoint_phase78.torch',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_1280.seer_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
license='other', origin_url='https://github.com/facebookresearch/vissl',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet128_finetuned_in1k_model_final_checkpoint_phase78.torch',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_2560.seer_ft_in1k': _cfgtv2(
hf_hub_id='timm/',
license='other', origin_url='https://github.com/facebookresearch/vissl',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet256_finetuned_in1k_model_final_checkpoint_phase38.torch',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'regnety_320.seer': _cfgtv2(
hf_hub_id='timm/',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet32d/seer_regnet32gf_model_iteration244000.torch',
num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'),
'regnety_640.seer': _cfgtv2(
hf_hub_id='timm/',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet64/seer_regnet64gf_model_final_checkpoint_phase0.torch',
num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'),
'regnety_1280.seer': _cfgtv2(
hf_hub_id='timm/',
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_regnet128Gf_cnstant_bs32_node16_sinkhorn10_proto16k_syncBN64_warmup8k/model_final_checkpoint_phase0.torch',
num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'),
# FIXME invalid weight <-> model match, mistake on their end
#'regnety_2560.seer': _cfgtv2(
# url='https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_cosine_rg256gf_noBNhead_wd1e5_fairstore_bs16_node64_sinkhorn10_proto16k_apex_syncBN64_warmup8k/model_final_checkpoint_phase0.torch',
# num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'),
'regnetx_002.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_004.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_006.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_008.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_016.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_032.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_040.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_064.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_080.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_120.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_160.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnetx_320.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_002.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_004.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_006.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_008.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_016.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_032.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_040.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_064.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_080.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_120.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_160.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
'regnety_320.pycls_in1k': _cfgpyc(hf_hub_id='timm/'),
})
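# Illustrative usage (not part of the original file): a specific pretrained tag from the
# cfgs above can be selected at create time, e.g.
#   >>> import timm
#   >>> model = timm.create_model('regnety_160.sw_in12k_ft_in1k', pretrained=True)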
@register_model
def regnetx_002(pretrained=False, **kwargs):
"""RegNetX-200MF"""
@ -644,6 +904,12 @@ def regnetx_004(pretrained=False, **kwargs):
return _create_regnet('regnetx_004', pretrained, **kwargs)
@register_model
def regnetx_004_tv(pretrained=False, **kwargs):
"""RegNetX-400MF w/ torchvision group rounding"""
return _create_regnet('regnetx_004_tv', pretrained, **kwargs)
@register_model
def regnetx_006(pretrained=False, **kwargs):
"""RegNetX-600MF"""
@ -728,6 +994,12 @@ def regnety_008(pretrained=False, **kwargs):
return _create_regnet('regnety_008', pretrained, **kwargs)
@register_model
def regnety_008_tv(pretrained=False, **kwargs):
"""RegNetY-800MF w/ torchvision group rounding"""
return _create_regnet('regnety_008_tv', pretrained, **kwargs)
@register_model
def regnety_016(pretrained=False, **kwargs):
"""RegNetY-1.6GF"""
@ -758,6 +1030,12 @@ def regnety_080(pretrained=False, **kwargs):
return _create_regnet('regnety_080', pretrained, **kwargs)
@register_model
def regnety_080_tv(pretrained=False, **kwargs):
"""RegNetY-8.0GF w/ torchvision group rounding"""
return _create_regnet('regnety_080_tv', pretrained, **kwargs)
@register_model
def regnety_120(pretrained=False, **kwargs):
"""RegNetY-12GF"""
@ -795,20 +1073,20 @@ def regnety_2560(pretrained=False, **kwargs):
@register_model
def regnety_040_sgn(pretrained=False, **kwargs):
"""RegNetY-4.0GF w/ GroupNorm"""
return _create_regnet('regnety_040_sgn', pretrained, **kwargs)
@register_model
def regnetv_040(pretrained=False, **kwargs):
""""""
"""RegNetV-4.0GF (pre-activation)"""
return _create_regnet('regnetv_040', pretrained, **kwargs)
@register_model
def regnetv_064(pretrained=False, **kwargs):
""""""
"""RegNetV-6.4GF (pre-activation)"""
return _create_regnet('regnetv_064', pretrained, **kwargs)
@ -831,9 +1109,14 @@ def regnetz_040(pretrained=False, **kwargs):
@register_model
def regnetz_040_h(pretrained=False, **kwargs):
"""RegNetZ-4.0GF
NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
but it's not clear it is equivalent to the paper's model, as it is not detailed in the paper.
"""
return _create_regnet('regnetz_040_h', pretrained, zero_init_last=False, **kwargs)
register_model_deprecations(__name__, {
'regnetz_040h': 'regnetz_040_h',
})
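# Illustrative note: with the deprecation mapping above, the old entrypoint name should
# still resolve to the new one (typically emitting a deprecation warning), e.g.
#   >>> timm.create_model('regnetz_040h')  # -> regnetz_040_h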