diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 63c9b57f..250ba2e4 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -1,16 +1,26 @@ -"""RegNet +"""RegNet X, Y, Z, and more Paper: `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678 Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py +Paper: `Fast and Accurate Model Scaling` - https://arxiv.org/abs/2103.06877 +Original Impl: None + Based on original PyTorch impl linked above, but re-wrote to use my own blocks (adapted from ResNet here) and cleaned up with more descriptive variable names. -Weights from original impl have been modified +Weights from original pycls impl have been modified: * first layer from BGR -> RGB as most PyTorch models are * removed training specific dict entries from checkpoints and keep model state_dict only * remap names to match the ones here +Supports weight loading from torchvision and classy-vision (incl VISSL SEER) + +A number of custom timm model definition additions, including: +* stochastic depth, gradient checkpointing, layer-decay, configurable dilation +* a pre-activation 'V' variant +* the only known RegNet-Z model definitions with pretrained weights + Hacked together by / Copyright 2020 Ross Wightman """ import math @@ -24,10 +34,10 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct -from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d +from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq, named_apply -from ._registry import register_model +from ._registry import generate_default_cfgs, register_model, register_model_deprecations __all__ = ['RegNet', 'RegNetCfg'] # model_registry will add each entrypoint fn to this @@ -41,6 +51,7 @@ class RegNetCfg: group_size: int = 24 bottle_ratio: float = 1. se_ratio: float = 0. + group_min_ratio: float = 0.
stem_width: int = 32 downsample: Optional[str] = 'conv1x1' linear_out: bool = False @@ -50,178 +61,79 @@ class RegNetCfg: norm_layer: Union[str, Callable] = 'batchnorm' -# Model FLOPS = three trailing digits * 10^8 -model_cfgs = dict( - # RegNet-X - regnetx_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13), - regnetx_004=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22), - regnetx_006=RegNetCfg(w0=48, wa=36.97, wm=2.24, group_size=24, depth=16), - regnetx_008=RegNetCfg(w0=56, wa=35.73, wm=2.28, group_size=16, depth=16), - regnetx_016=RegNetCfg(w0=80, wa=34.01, wm=2.25, group_size=24, depth=18), - regnetx_032=RegNetCfg(w0=88, wa=26.31, wm=2.25, group_size=48, depth=25), - regnetx_040=RegNetCfg(w0=96, wa=38.65, wm=2.43, group_size=40, depth=23), - regnetx_064=RegNetCfg(w0=184, wa=60.83, wm=2.07, group_size=56, depth=17), - regnetx_080=RegNetCfg(w0=80, wa=49.56, wm=2.88, group_size=120, depth=23), - regnetx_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19), - regnetx_160=RegNetCfg(w0=216, wa=55.59, wm=2.1, group_size=128, depth=22), - regnetx_320=RegNetCfg(w0=320, wa=69.86, wm=2.0, group_size=168, depth=23), - - # RegNet-Y - regnety_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13, se_ratio=0.25), - regnety_004=RegNetCfg(w0=48, wa=27.89, wm=2.09, group_size=8, depth=16, se_ratio=0.25), - regnety_006=RegNetCfg(w0=48, wa=32.54, wm=2.32, group_size=16, depth=15, se_ratio=0.25), - regnety_008=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25), - regnety_016=RegNetCfg(w0=48, wa=20.71, wm=2.65, group_size=24, depth=27, se_ratio=0.25), - regnety_032=RegNetCfg(w0=80, wa=42.63, wm=2.66, group_size=24, depth=21, se_ratio=0.25), - regnety_040=RegNetCfg(w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25), - regnety_064=RegNetCfg(w0=112, wa=33.22, wm=2.27, group_size=72, depth=25, se_ratio=0.25), - regnety_080=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25), - regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25), - regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25), - regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25), - regnety_640=RegNetCfg(w0=352, wa=147.48, wm=2.4, group_size=328, depth=20, se_ratio=0.25), - regnety_1280=RegNetCfg(w0=456, wa=160.83, wm=2.52, group_size=264, depth=27, se_ratio=0.25), - regnety_2560=RegNetCfg(w0=640, wa=124.47, wm=2.04, group_size=848, depth=27, se_ratio=0.25), - - # Experimental - regnety_040s_gn=RegNetCfg( - w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25, - act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)), - - # regnetv = 'preact regnet y' - regnetv_040=RegNetCfg( - depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, se_ratio=0.25, preact=True, act_layer='silu'), - regnetv_064=RegNetCfg( - depth=25, w0=112, wa=33.22, wm=2.27, group_size=72, se_ratio=0.25, preact=True, act_layer='silu', - downsample='avg'), - - # RegNet-Z (unverified) - regnetz_005=RegNetCfg( - depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25, - downsample=None, linear_out=True, num_features=1024, act_layer='silu', - ), - regnetz_040=RegNetCfg( - depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25, - downsample=None, linear_out=True, num_features=0, act_layer='silu', - ), - regnetz_040h=RegNetCfg( - depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, 
se_ratio=0.25, - downsample=None, linear_out=True, num_features=1536, act_layer='silu', - ), -) - - -def _cfg(url='', **kwargs): - return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.conv', 'classifier': 'head.fc', - **kwargs - } - - -default_cfgs = dict( - regnetx_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth'), - regnetx_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth'), - regnetx_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth'), - regnetx_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth'), - regnetx_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth'), - regnetx_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth'), - regnetx_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth'), - regnetx_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth'), - regnetx_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth'), - regnetx_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'), - regnetx_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'), - regnetx_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'), - - regnety_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'), - regnety_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'), - regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'), - regnety_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth'), - regnety_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth'), - regnety_032=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth', - crop_pct=1.0, test_input_size=(3, 288, 288)), - regnety_040=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_040_ra3-670e1166.pth', - crop_pct=1.0, test_input_size=(3, 288, 288)), - regnety_064=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_064_ra3-aa26dc7d.pth', - crop_pct=1.0, test_input_size=(3, 288, 288)), - regnety_080=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_080_ra3-1fdc4344.pth', - crop_pct=1.0, test_input_size=(3, 288, 288)), - 
regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'), - regnety_160=_cfg( - url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository - crop_pct=1.0, test_input_size=(3, 288, 288)), - regnety_320=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth' - ), - regnety_640=_cfg(url=''), - regnety_1280=_cfg(url=''), - regnety_2560=_cfg(url=''), - - regnety_040s_gn=_cfg(url=''), - regnetv_040=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pth', - first_conv='stem', crop_pct=1.0, test_input_size=(3, 288, 288)), - regnetv_064=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pth', - first_conv='stem', crop_pct=1.0, test_input_size=(3, 288, 288)), - - regnetz_005=_cfg(url=''), - regnetz_040=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040_ra3-9007edf5.pth', - input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)), - regnetz_040h=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040h_ra3-f594343b.pth', - input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)), -) - - def quantize_float(f, q): - """Converts a float to closest non-zero int divisible by q.""" + """Converts a float to the closest non-zero int divisible by q.""" return int(round(f / q) * q) -def adjust_widths_groups_comp(widths, bottle_ratios, groups): +def adjust_widths_groups_comp(widths, bottle_ratios, groups, min_ratio=0.): """Adjusts the compatibility of widths and groups.""" bottleneck_widths = [int(w * b) for w, b in zip(widths, bottle_ratios)] groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_widths)] - bottleneck_widths = [quantize_float(w_bot, g) for w_bot, g in zip(bottleneck_widths, groups)] + if min_ratio: + # torchvision uses a different rounding scheme for ensuring bottleneck widths divisible by group widths + bottleneck_widths = [make_divisible(w_bot, g, min_ratio) for w_bot, g in zip(bottleneck_widths, groups)] + else: + bottleneck_widths = [quantize_float(w_bot, g) for w_bot, g in zip(bottleneck_widths, groups)] widths = [int(w_bot / b) for w_bot, b in zip(bottleneck_widths, bottle_ratios)] return widths, groups -def generate_regnet(width_slope, width_initial, width_mult, depth, group_size, q=8): +def generate_regnet(width_slope, width_initial, width_mult, depth, group_size, quant=8): """Generates per block widths from RegNet parameters.""" - assert width_slope >= 0 and width_initial > 0 and width_mult > 1 and width_initial % q == 0 + assert width_slope >= 0 and width_initial > 0 and width_mult > 1 and width_initial % quant == 0 # TODO dWr scaling? 
# depth = int(depth * (scale ** 0.1)) # width_scale = scale ** 0.4 # dWr scale, exp 0.8 / 2, applied to both group and layer widths widths_cont = np.arange(depth) * width_slope + width_initial width_exps = np.round(np.log(widths_cont / width_initial) / np.log(width_mult)) - widths = width_initial * np.power(width_mult, width_exps) - widths = np.round(np.divide(widths, q)) * q + widths = np.round(np.divide(width_initial * np.power(width_mult, width_exps), quant)) * quant num_stages, max_stage = len(np.unique(widths)), width_exps.max() + 1 groups = np.array([group_size for _ in range(num_stages)]) return widths.astype(int).tolist(), num_stages, groups.astype(int).tolist() -def downsample_conv(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None, preact=False): +def downsample_conv( + in_chs, + out_chs, + kernel_size=1, + stride=1, + dilation=1, + norm_layer=None, + preact=False, +): norm_layer = norm_layer or nn.BatchNorm2d kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size dilation = dilation if kernel_size > 1 else 1 if preact: - return create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation) + return create_conv2d( + in_chs, + out_chs, + kernel_size, + stride=stride, + dilation=dilation, + ) else: return ConvNormAct( - in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, apply_act=False) + in_chs, + out_chs, + kernel_size, + stride=stride, + dilation=dilation, + norm_layer=norm_layer, + apply_act=False, + ) -def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None, preact=False): +def downsample_avg( + in_chs, + out_chs, + kernel_size=1, + stride=1, + dilation=1, + norm_layer=None, + preact=False, +): """ AvgPool Downsampling as in 'D' ResNet variants. 
This is not in RegNet space but I might experiment.""" norm_layer = norm_layer or nn.BatchNorm2d avg_stride = stride if dilation == 1 else 1 @@ -290,8 +202,15 @@ class Bottleneck(nn.Module): cargs = dict(act_layer=act_layer, norm_layer=norm_layer) self.conv1 = ConvNormAct(in_chs, bottleneck_chs, kernel_size=1, **cargs) self.conv2 = ConvNormAct( - bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation[0], - groups=groups, drop_layer=drop_block, **cargs) + bottleneck_chs, + bottleneck_chs, + kernel_size=3, + stride=stride, + dilation=dilation[0], + groups=groups, + drop_layer=drop_block, + **cargs, + ) if se_ratio: se_channels = int(round(in_chs * se_ratio)) self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer) @@ -299,7 +218,15 @@ class Bottleneck(nn.Module): self.se = nn.Identity() self.conv3 = ConvNormAct(bottleneck_chs, out_chs, kernel_size=1, apply_act=False, **cargs) self.act3 = nn.Identity() if linear_out else act_layer() - self.downsample = create_shortcut(downsample, in_chs, out_chs, 1, stride, dilation, norm_layer=norm_layer) + self.downsample = create_shortcut( + downsample, + in_chs, + out_chs, + kernel_size=1, + stride=stride, + dilation=dilation, + norm_layer=norm_layer, + ) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() def zero_init_last(self): @@ -351,7 +278,13 @@ class PreBottleneck(nn.Module): self.conv1 = create_conv2d(in_chs, bottleneck_chs, kernel_size=1) self.norm2 = norm_act_layer(bottleneck_chs) self.conv2 = create_conv2d( - bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation[0], groups=groups) + bottleneck_chs, + bottleneck_chs, + kernel_size=3, + stride=stride, + dilation=dilation[0], + groups=groups, + ) if se_ratio: se_channels = int(round(in_chs * se_ratio)) self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer) @@ -359,7 +292,15 @@ class PreBottleneck(nn.Module): self.se = nn.Identity() self.norm3 = norm_act_layer(bottleneck_chs) self.conv3 = create_conv2d(bottleneck_chs, out_chs, kernel_size=1) - self.downsample = create_shortcut(downsample, in_chs, out_chs, 1, stride, dilation, preact=True) + self.downsample = create_shortcut( + downsample, + in_chs, + out_chs, + kernel_size=1, + stride=stride, + dilation=dilation, + preact=True, + ) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() def zero_init_last(self): @@ -406,7 +347,8 @@ class RegStage(nn.Module): dpr = drop_path_rates[i] if drop_path_rates is not None else 0. 
name = "b{}".format(i + 1) self.add_module( - name, block_fn( + name, + block_fn( block_in_chs, out_chs, stride=block_stride, @@ -477,12 +419,23 @@ class RegNet(nn.Module): prev_width = stem_width curr_stride = 2 per_stage_args, common_args = self._get_stage_args( - cfg, output_stride=output_stride, drop_path_rate=drop_path_rate) + cfg, + output_stride=output_stride, + drop_path_rate=drop_path_rate, + ) assert len(per_stage_args) == 4 block_fn = PreBottleneck if cfg.preact else Bottleneck for i, stage_args in enumerate(per_stage_args): stage_name = "s{}".format(i + 1) - self.add_module(stage_name, RegStage(in_chs=prev_width, block_fn=block_fn, **stage_args, **common_args)) + self.add_module( + stage_name, + RegStage( + in_chs=prev_width, + block_fn=block_fn, + **stage_args, + **common_args, + ) + ) prev_width = stage_args['out_chs'] curr_stride *= stage_args['stride'] self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)] @@ -496,7 +449,11 @@ class RegNet(nn.Module): self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity() self.num_features = prev_width self.head = ClassifierHead( - in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + in_features=self.num_features, + num_classes=num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + ) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) @@ -523,11 +480,13 @@ class RegNet(nn.Module): stage_dpr = np.split(np.linspace(0, drop_path_rate, sum(stage_depths)), np.cumsum(stage_depths[:-1])) # Adjust the compatibility of ws and gws - stage_widths, stage_gs = adjust_widths_groups_comp(stage_widths, stage_br, stage_gs) + stage_widths, stage_gs = adjust_widths_groups_comp( + stage_widths, stage_br, stage_gs, min_ratio=cfg.group_min_ratio) arg_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates'] per_stage_args = [ dict(zip(arg_names, params)) for params in - zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_br, stage_gs, stage_dpr)] + zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_br, stage_gs, stage_dpr) + ] common_args = dict( downsample=cfg.downsample, se_ratio=cfg.se_ratio, @@ -554,7 +513,7 @@ class RegNet(nn.Module): return self.head.fc def reset_classifier(self, num_classes, global_pool='avg'): - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + self.head.reset(num_classes, pool_type=global_pool) def forward_features(self, x): x = self.stem(x) @@ -590,7 +549,25 @@ def _init_weights(module, name='', zero_init_last=False): def _filter_fn(state_dict): + state_dict = state_dict.get('model', state_dict) + replaces = [ + ('f.a.0', 'conv1.conv'), + ('f.a.1', 'conv1.bn'), + ('f.b.0', 'conv2.conv'), + ('f.b.1', 'conv2.bn'), + ('f.final_bn', 'conv3.bn'), + ('f.se.excitation.0', 'se.fc1'), + ('f.se.excitation.2', 'se.fc2'), + ('f.se', 'se'), + ('f.c.0', 'conv3.conv'), + ('f.c.1', 'conv3.bn'), + ('f.c', 'conv3.conv'), + ('proj.0', 'downsample.conv'), + ('proj.1', 'downsample.bn'), + ('proj', 'downsample.conv'), + ] if 'classy_state_dict' in state_dict: + # classy-vision & vissl (SEER) weights import re state_dict = state_dict['classy_state_dict']['base_model']['model'] out = {} @@ -601,15 +578,8 @@ def _filter_fn(state_dict): r'^_feature_blocks.res\d.block(\d)-(\d+)', lambda x: f's{int(x.group(1))}.b{int(x.group(2)) + 1}', k) k = re.sub(r's(\d)\.b(\d+)\.bn', 
r's\1.b\2.downsample.bn', k) - k = k.replace('proj', 'downsample.conv') - k = k.replace('f.a.0', 'conv1.conv') - k = k.replace('f.a.1', 'conv1.bn') - k = k.replace('f.b.0', 'conv2.conv') - k = k.replace('f.b.1', 'conv2.bn') - k = k.replace('f.c', 'conv3.conv') - k = k.replace('f.final_bn', 'conv3.bn') - k = k.replace('f.se.excitation.0', 'se.fc1') - k = k.replace('f.se.excitation.2', 'se.fc2') + for s, r in replaces: + k = k.replace(s, r) out[k] = v for k, v in state_dict['heads'].items(): if 'projection_head' in k or 'prototypes' in k: @@ -617,13 +587,89 @@ def _filter_fn(state_dict): k = k.replace('0.clf.0', 'head.fc') out[k] = v return out - - if 'model' in state_dict: - # For DeiT trained regnety_160 pretraiend model - state_dict = state_dict['model'] + if 'stem.0.weight' in state_dict: + # torchvision weights + import re + out = {} + for k, v in state_dict.items(): + k = k.replace('stem.0', 'stem.conv') + k = k.replace('stem.1', 'stem.bn') + k = re.sub( + r'trunk_output.block(\d)\.block(\d+)\-(\d+)', + lambda x: f's{int(x.group(1))}.b{int(x.group(3)) + 1}', k) + for s, r in replaces: + k = k.replace(s, r) + k = k.replace('fc.', 'head.fc.') + out[k] = v + return out return state_dict +# Model FLOPS = three trailing digits * 10^8 +model_cfgs = dict( + # RegNet-X + regnetx_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13), + regnetx_004=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22), + regnetx_004_tv=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22, group_min_ratio=0.9), + regnetx_006=RegNetCfg(w0=48, wa=36.97, wm=2.24, group_size=24, depth=16), + regnetx_008=RegNetCfg(w0=56, wa=35.73, wm=2.28, group_size=16, depth=16), + regnetx_016=RegNetCfg(w0=80, wa=34.01, wm=2.25, group_size=24, depth=18), + regnetx_032=RegNetCfg(w0=88, wa=26.31, wm=2.25, group_size=48, depth=25), + regnetx_040=RegNetCfg(w0=96, wa=38.65, wm=2.43, group_size=40, depth=23), + regnetx_064=RegNetCfg(w0=184, wa=60.83, wm=2.07, group_size=56, depth=17), + regnetx_080=RegNetCfg(w0=80, wa=49.56, wm=2.88, group_size=120, depth=23), + regnetx_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19), + regnetx_160=RegNetCfg(w0=216, wa=55.59, wm=2.1, group_size=128, depth=22), + regnetx_320=RegNetCfg(w0=320, wa=69.86, wm=2.0, group_size=168, depth=23), + + # RegNet-Y + regnety_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13, se_ratio=0.25), + regnety_004=RegNetCfg(w0=48, wa=27.89, wm=2.09, group_size=8, depth=16, se_ratio=0.25), + regnety_006=RegNetCfg(w0=48, wa=32.54, wm=2.32, group_size=16, depth=15, se_ratio=0.25), + regnety_008=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25), + regnety_008_tv=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25, group_min_ratio=0.9), + regnety_016=RegNetCfg(w0=48, wa=20.71, wm=2.65, group_size=24, depth=27, se_ratio=0.25), + regnety_032=RegNetCfg(w0=80, wa=42.63, wm=2.66, group_size=24, depth=21, se_ratio=0.25), + regnety_040=RegNetCfg(w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25), + regnety_064=RegNetCfg(w0=112, wa=33.22, wm=2.27, group_size=72, depth=25, se_ratio=0.25), + regnety_080=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25), + regnety_080_tv=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25, group_min_ratio=0.9), + regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25), + regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25), + 
regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25), + regnety_640=RegNetCfg(w0=352, wa=147.48, wm=2.4, group_size=328, depth=20, se_ratio=0.25), + regnety_1280=RegNetCfg(w0=456, wa=160.83, wm=2.52, group_size=264, depth=27, se_ratio=0.25), + regnety_2560=RegNetCfg(w0=640, wa=230.83, wm=2.53, group_size=373, depth=27, se_ratio=0.25), + #regnety_2560=RegNetCfg(w0=640, wa=124.47, wm=2.04, group_size=848, depth=27, se_ratio=0.25), + + # Experimental + regnety_040_sgn=RegNetCfg( + w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25, + act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)), + + # regnetv = 'preact regnet y' + regnetv_040=RegNetCfg( + depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, se_ratio=0.25, preact=True, act_layer='silu'), + regnetv_064=RegNetCfg( + depth=25, w0=112, wa=33.22, wm=2.27, group_size=72, se_ratio=0.25, preact=True, act_layer='silu', + downsample='avg'), + + # RegNet-Z (unverified) + regnetz_005=RegNetCfg( + depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25, + downsample=None, linear_out=True, num_features=1024, act_layer='silu', + ), + regnetz_040=RegNetCfg( + depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25, + downsample=None, linear_out=True, num_features=0, act_layer='silu', + ), + regnetz_040_h=RegNetCfg( + depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25, + downsample=None, linear_out=True, num_features=1536, act_layer='silu', + ), +) + + def _create_regnet(variant, pretrained, **kwargs): return build_model_with_cfg( RegNet, variant, pretrained, @@ -632,6 +678,220 @@ def _create_regnet(variant, pretrained, **kwargs): **kwargs) +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'test_input_size': (3, 288, 288), 'crop_pct': 0.95, 'test_crop_pct': 1.0, + 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + + +def _cfgpyc(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + 'license': 'mit', 'origin_url': 'https://github.com/facebookresearch/pycls', **kwargs + } + + +def _cfgtv2(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.965, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + 'license': 'bsd-3-clause', 'origin_url': 'https://github.com/pytorch/vision', **kwargs + } + + +default_cfgs = generate_default_cfgs({ + # timm trained models + 'regnety_032.ra_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'), + 'regnety_040.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_040_ra3-670e1166.pth'), + 'regnety_064.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_064_ra3-aa26dc7d.pth'), + 'regnety_080.ra3_in1k': _cfg( + hf_hub_id='timm/', + 
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_080_ra3-1fdc4344.pth'), + 'regnety_120.sw_in12k_ft_in1k': _cfg(hf_hub_id='timm/'), + 'regnety_160.sw_in12k_ft_in1k': _cfg(hf_hub_id='timm/'), + 'regnety_160.lion_in12k_ft_in1k': _cfg(hf_hub_id='timm/'), + + # timm in12k pretrain + 'regnety_120.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + 'regnety_160.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + + # timm custom arch (v and z guess) + trained models + 'regnety_040_sgn.untrained': _cfg(url=''), + 'regnetv_040.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pth', + first_conv='stem'), + 'regnetv_064.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pth', + first_conv='stem'), + + 'regnetz_005.untrained': _cfg(url=''), + 'regnetz_040.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040_ra3-9007edf5.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)), + 'regnetz_040_h.ra3_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040h_ra3-f594343b.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)), + + # used in DeiT for distillation (from Facebook DeiT GitHub repository) + 'regnety_160.deit_in1k': _cfg( + hf_hub_id='timm/', url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'), + + 'regnetx_004_tv.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth'), + 'regnetx_008.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth'), + 'regnetx_016.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth'), + 'regnetx_032.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth'), + 'regnetx_080.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth'), + 'regnetx_160.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth'), + 'regnetx_320.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth'), + + 'regnety_004.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth'), + 'regnety_008_tv.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth'), + 'regnety_016.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth'), + 'regnety_032.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth'), + 'regnety_080_tv.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth'), + 'regnety_160.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth'), + 'regnety_320.tv2_in1k': _cfgtv2( + hf_hub_id='timm/', + 
url='https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth'), + + 'regnety_160.swag_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth', license='cc-by-nc-4.0', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'regnety_320.swag_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth', license='cc-by-nc-4.0', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'regnety_1280.swag_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth', license='cc-by-nc-4.0', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + + 'regnety_160.swag_lc_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth', license='cc-by-nc-4.0'), + 'regnety_320.swag_lc_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth', license='cc-by-nc-4.0'), + 'regnety_1280.swag_lc_in1k': _cfgtv2( + hf_hub_id='timm/', + url='https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth', license='cc-by-nc-4.0'), + + 'regnety_320.seer_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + license='other', origin_url='https://github.com/facebookresearch/vissl', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet32_finetuned_in1k_model_final_checkpoint_phase78.torch', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'regnety_640.seer_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + license='other', origin_url='https://github.com/facebookresearch/vissl', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet64_finetuned_in1k_model_final_checkpoint_phase78.torch', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'regnety_1280.seer_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + license='other', origin_url='https://github.com/facebookresearch/vissl', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet128_finetuned_in1k_model_final_checkpoint_phase78.torch', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'regnety_2560.seer_ft_in1k': _cfgtv2( + hf_hub_id='timm/', + license='other', origin_url='https://github.com/facebookresearch/vissl', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet256_finetuned_in1k_model_final_checkpoint_phase38.torch', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + + 'regnety_320.seer': _cfgtv2( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet32d/seer_regnet32gf_model_iteration244000.torch', + num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'), + 'regnety_640.seer': _cfgtv2( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet64/seer_regnet64gf_model_final_checkpoint_phase0.torch', + num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'), + 'regnety_1280.seer': _cfgtv2( + hf_hub_id='timm/', + url='https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_regnet128Gf_cnstant_bs32_node16_sinkhorn10_proto16k_syncBN64_warmup8k/model_final_checkpoint_phase0.torch', + num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'), + # FIXME invalid weight <-> model match, mistake on their end + #'regnety_2560.seer': _cfgtv2( + # 
url='https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_cosine_rg256gf_noBNhead_wd1e5_fairstore_bs16_node64_sinkhorn10_proto16k_apex_syncBN64_warmup8k/model_final_checkpoint_phase0.torch', + # num_classes=0, license='other', origin_url='https://github.com/facebookresearch/vissl'), + + 'regnetx_002.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_004.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_006.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_008.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_016.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_032.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_040.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_064.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_080.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_120.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_160.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnetx_320.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + + 'regnety_002.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_004.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_006.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_008.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_016.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_032.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_040.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_064.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_080.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_120.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_160.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), + 'regnety_320.pycls_in1k': _cfgpyc(hf_hub_id='timm/'), +}) + + @register_model def regnetx_002(pretrained=False, **kwargs): """RegNetX-200MF""" @@ -644,6 +904,12 @@ def regnetx_004(pretrained=False, **kwargs): return _create_regnet('regnetx_004', pretrained, **kwargs) +@register_model +def regnetx_004_tv(pretrained=False, **kwargs): + """RegNetX-400MF w/ torchvision group rounding""" + return _create_regnet('regnetx_004_tv', pretrained, **kwargs) + + @register_model def regnetx_006(pretrained=False, **kwargs): """RegNetX-600MF""" @@ -728,6 +994,12 @@ def regnety_008(pretrained=False, **kwargs): return _create_regnet('regnety_008', pretrained, **kwargs) +@register_model +def regnety_008_tv(pretrained=False, **kwargs): + """RegNetY-800MF w/ torchvision group rounding""" + return _create_regnet('regnety_008_tv', pretrained, **kwargs) + + @register_model def regnety_016(pretrained=False, **kwargs): """RegNetY-1.6GF""" @@ -758,6 +1030,12 @@ def regnety_080(pretrained=False, **kwargs): return _create_regnet('regnety_080', pretrained, **kwargs) +@register_model +def regnety_080_tv(pretrained=False, **kwargs): + """RegNetY-8.0GF w/ torchvision group rounding""" + return _create_regnet('regnety_080_tv', pretrained, **kwargs) + + @register_model def regnety_120(pretrained=False, **kwargs): """RegNetY-12GF""" @@ -795,20 +1073,20 @@ def regnety_2560(pretrained=False, **kwargs): @register_model -def regnety_040s_gn(pretrained=False, **kwargs): +def regnety_040_sgn(pretrained=False, **kwargs): """RegNetY-4.0GF w/ GroupNorm """ - return _create_regnet('regnety_040s_gn', pretrained, **kwargs) + return _create_regnet('regnety_040_sgn', pretrained, **kwargs) @register_model def regnetv_040(pretrained=False, **kwargs): - """""" + """RegNetV-4.0GF (pre-activation)""" return _create_regnet('regnetv_040', pretrained, **kwargs) @register_model def regnetv_064(pretrained=False, **kwargs): - """""" + """RegNetV-6.4GF (pre-activation)""" return 
_create_regnet('regnetv_064', pretrained, **kwargs) @@ -831,9 +1109,14 @@ def regnetz_040(pretrained=False, **kwargs): @register_model -def regnetz_040h(pretrained=False, **kwargs): +def regnetz_040_h(pretrained=False, **kwargs): """RegNetZ-4.0GF NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py but it's not clear it is equivalent to paper model as not detailed in the paper. """ - return _create_regnet('regnetz_040h', pretrained, zero_init_last=False, **kwargs) + return _create_regnet('regnetz_040_h', pretrained, zero_init_last=False, **kwargs) + + +register_model_deprecations(__name__, { + 'regnetz_040h': 'regnetz_040_h', +}) \ No newline at end of file
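For context on the new group_min_ratio config value and the *_tv model variants, here is a minimal standalone sketch (not part of the patch) of the two bottleneck-width rounding schemes that adjust_widths_groups_comp now switches between. The _make_divisible helper below is a simplified restatement of the MobileNet-style rounding torchvision uses (the patch itself routes through timm.layers.make_divisible); it is illustrative only.

def quantize_float(f, q):
    # pycls-style rounding: nearest non-zero int divisible by q
    return int(round(f / q) * q)


def _make_divisible(v, divisor, round_limit=0.9):
    # torchvision-style rounding: nearest multiple of divisor, but bump up one
    # step if rounding dropped more than (1 - round_limit) of the requested width
    new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
    if new_v < round_limit * v:
        new_v += divisor
    return new_v


if __name__ == '__main__':
    # hypothetical bottleneck width / group size pair chosen to show a divergence
    w_bot, group = 90, 64
    print(quantize_float(w_bot, group))   # 64  (nearest multiple of 64)
    print(_make_divisible(w_bot, group))  # 128 (64 < 0.9 * 90, so bump up one group)

This rounding difference changes per-stage channel counts, which is presumably why regnetx_004_tv, regnety_008_tv, and regnety_080_tv exist as separate configs with group_min_ratio=0.9 rather than reusing the original pycls-parameterized definitions.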
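The new entrypoints, pretrained tags, and deprecation mapping are exercised through the usual timm API; a short usage sketch, assuming an environment with this patch installed and network access for the pretrained weights:

import torch
import timm

# torchvision-weight variant (group_min_ratio=0.9) with its tv2 ImageNet-1k tag
model = timm.create_model('regnety_008_tv.tv2_in1k', pretrained=True)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])

# renamed models keep working under their old names via register_model_deprecations,
# e.g. timm.create_model('regnetz_040h') resolves to 'regnetz_040_h' with a
# deprecation warning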