From c78319adce5fe9b43e09d8d36abf0133effb9f33 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 20 Mar 2023 13:48:17 -0700
Subject: [PATCH] Add ImageNet-12k ReXNet-R 200 & 300 weights, and push
 existing ReXNet models to HF hub. Dilation support added to rexnet

---
 timm/models/rexnet.py | 181 +++++++++++++++++++++++++++++++-----------
 1 file changed, 135 insertions(+), 46 deletions(-)

diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py
index 51e8cdc2..f5c3ef67 100644
--- a/timm/models/rexnet.py
+++ b/timm/models/rexnet.py
@@ -21,48 +21,30 @@ from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath,
 from ._builder import build_model_with_cfg
 from ._efficientnet_builder import efficientnet_init_weights
 from ._manipulate import checkpoint_seq
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['ReXNetV1']  # model_registry will add each entrypoint fn to this
 
 
-def _cfg(url=''):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv', 'classifier': 'head.fc',
-    }
-
-
-default_cfgs = dict(
-    rexnet_100=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_100-1b4dddf4.pth'),
-    rexnet_130=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_130-590d768e.pth'),
-    rexnet_150=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_150-bd1a6aa8.pth'),
-    rexnet_200=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_200-8c0b7f2d.pth'),
-    rexnetr_100=_cfg(
-        url=''),
-    rexnetr_130=_cfg(
-        url=''),
-    rexnetr_150=_cfg(
-        url=''),
-    rexnetr_200=_cfg(
-        url=''),
-)
-
 SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d)
 
 
 class LinearBottleneck(nn.Module):
     def __init__(
-            self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1,
-            act_layer='swish', dw_act_layer='relu6', drop_path=None):
+            self,
+            in_chs,
+            out_chs,
+            stride,
+            dilation=(1, 1),
+            exp_ratio=1.0,
+            se_ratio=0.,
+            ch_div=1,
+            act_layer='swish',
+            dw_act_layer='relu6',
+            drop_path=None,
+    ):
         super(LinearBottleneck, self).__init__()
-        self.use_shortcut = stride == 1 and in_chs <= out_chs
+        self.use_shortcut = stride == 1 and dilation[0] == dilation[1] and in_chs <= out_chs
         self.in_channels = in_chs
         self.out_channels = out_chs
@@ -73,7 +55,15 @@ class LinearBottleneck(nn.Module):
             dw_chs = in_chs
             self.conv_exp = None
 
-        self.conv_dw = ConvNormAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
+        self.conv_dw = ConvNormAct(
+            dw_chs,
+            dw_chs,
+            kernel_size=3,
+            stride=stride,
+            dilation=dilation[0],
+            groups=dw_chs,
+            apply_act=False,
+        )
         if se_ratio > 0:
             self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div))
         else:
@@ -102,7 +92,14 @@
         return x
 
 
-def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se_ratio=0., ch_div=1):
+def _block_cfg(
+        width_mult=1.0,
+        depth_mult=1.0,
+        initial_chs=16,
+        final_chs=180,
+        se_ratio=0.,
+        ch_div=1,
+):
     layers = [1, 2, 2, 3, 3, 5]
     strides = [1, 2, 2, 2, 1, 2]
     layers = [ceil(element * depth_mult) for element in layers]
@@ -123,22 +120,45 @@
 
 
 def _build_blocks(
-        block_cfg, prev_chs, width_mult, ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_path_rate=0.):
+        block_cfg,
+        prev_chs,
+        width_mult,
+        ch_div=1,
+        output_stride=32,
+        act_layer='swish',
+        dw_act_layer='relu6',
+        drop_path_rate=0.,
+):
     feat_chs = [prev_chs]
     feature_info = []
     curr_stride = 2
+    dilation = 1
     features = []
     num_blocks = len(block_cfg)
     for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg):
+        next_dilation = dilation
         if stride > 1:
             fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
             feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
-            curr_stride *= stride
+            if curr_stride >= output_stride:
+                next_dilation = dilation * stride
+                stride = 1
         block_dpr = drop_path_rate * block_idx / (num_blocks - 1)  # stochastic depth linear decay rule
         drop_path = DropPath(block_dpr) if block_dpr > 0. else None
         features.append(LinearBottleneck(
-            in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, se_ratio=se_ratio,
-            ch_div=ch_div, act_layer=act_layer, dw_act_layer=dw_act_layer, drop_path=drop_path))
+            in_chs=prev_chs,
+            out_chs=chs,
+            exp_ratio=exp_ratio,
+            stride=stride,
+            dilation=(dilation, next_dilation),
+            se_ratio=se_ratio,
+            ch_div=ch_div,
+            act_layer=act_layer,
+            dw_act_layer=dw_act_layer,
+            drop_path=drop_path,
+        ))
+        curr_stride *= stride
+        dilation = next_dilation
         prev_chs = chs
         feat_chs += [features[-1].feat_channels()]
     pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
@@ -149,23 +169,43 @@
 
 
 class ReXNetV1(nn.Module):
     def __init__(
-            self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
-            initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
-            ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0.
+            self,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='avg',
+            output_stride=32,
+            initial_chs=16,
+            final_chs=180,
+            width_mult=1.0,
+            depth_mult=1.0,
+            se_ratio=1/12.,
+            ch_div=1,
+            act_layer='swish',
+            dw_act_layer='relu6',
+            drop_rate=0.2,
+            drop_path_rate=0.,
     ):
         super(ReXNetV1, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
 
-        assert output_stride == 32  # FIXME support dilation
+        assert output_stride in (32, 16, 8)
         stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
         stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
         self.stem = ConvNormAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer)
 
         block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
         features, self.feature_info = _build_blocks(
-            block_cfg, stem_chs, width_mult, ch_div, act_layer, dw_act_layer, drop_path_rate)
+            block_cfg,
+            stem_chs,
+            width_mult,
+            ch_div,
+            output_stride,
+            act_layer,
+            dw_act_layer,
+            drop_path_rate,
+        )
         self.num_features = features[-1].out_channels
         self.features = nn.Sequential(*features)
@@ -201,7 +241,7 @@ class ReXNetV1(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        return self.head(x, pre_logits=pre_logits)
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)
@@ -212,9 +252,46 @@ class ReXNetV1(nn.Module):
 def _create_rexnet(variant, pretrained, **kwargs):
     feature_cfg = dict(flatten_sequential=True)
     return build_model_with_cfg(
-        ReXNetV1, variant, pretrained,
+        ReXNetV1,
+        variant,
+        pretrained,
         feature_cfg=feature_cfg,
-        **kwargs)
+        **kwargs,
+    )
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv', 'classifier': 'head.fc',
+        'license': 'mit', **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'rexnet_100.nav_in1k': _cfg(hf_hub_id='timm/'),
+    'rexnet_130.nav_in1k': _cfg(hf_hub_id='timm/'),
+    'rexnet_150.nav_in1k': _cfg(hf_hub_id='timm/'),
+    'rexnet_200.nav_in1k': _cfg(hf_hub_id='timm/'),
+    'rexnet_300.nav_in1k': _cfg(hf_hub_id='timm/'),
+    'rexnetr_100.untrained': _cfg(),
+    'rexnetr_130.untrained': _cfg(),
+    'rexnetr_150.untrained': _cfg(),
+    'rexnetr_200.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=1.0, test_input_size=(3, 288, 288), license='apache-2.0'),
+    'rexnetr_300.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=1.0, test_input_size=(3, 288, 288), license='apache-2.0'),
+    'rexnetr_200.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821, crop_pct=1.0, license='apache-2.0'),
+    'rexnetr_300.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821, crop_pct=1.0, license='apache-2.0'),
+})
 
 
 @register_model
@@ -241,6 +318,12 @@ def rexnet_200(pretrained=False, **kwargs):
     return _create_rexnet('rexnet_200', pretrained, width_mult=2.0, **kwargs)
 
 
+@register_model
+def rexnet_300(pretrained=False, **kwargs):
+    """ReXNet V1 3.0x"""
+    return _create_rexnet('rexnet_300', pretrained, width_mult=3.0, **kwargs)
+
+
 @register_model
 def rexnetr_100(pretrained=False, **kwargs):
     """ReXNet V1 1.0x w/ rounded (mod 8) channels"""
@@ -263,3 +346,9 @@ def rexnetr_150(pretrained=False, **kwargs):
 def rexnetr_200(pretrained=False, **kwargs):
     """ReXNet V1 2.0x w/ rounded (mod 8) channels"""
     return _create_rexnet('rexnetr_200', pretrained, width_mult=2.0, ch_div=8, **kwargs)
+
+
+@register_model
+def rexnetr_300(pretrained=False, **kwargs):
+    """ReXNet V1 3.0x w/ rounded (mod 16) channels"""
+    return _create_rexnet('rexnetr_300', pretrained, width_mult=3.0, ch_div=16, **kwargs)
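
Not part of the patch: a minimal usage sketch of what the change enables, assuming a timm build that includes it. The model names and pretrained tags are the ones registered in the diff above; the printed values are illustrative.

    import torch
    import timm

    # One of the newly pushed ImageNet-12k fine-tuned ReXNet-R weights (cfg tag added above).
    model = timm.create_model('rexnetr_300.sw_in12k_ft_in1k', pretrained=True).eval()

    # Dilation support: output_stride values other than 32 no longer hit the old assert.
    backbone = timm.create_model('rexnet_100', features_only=True, output_stride=8, pretrained=False)
    feats = backbone(torch.randn(1, 3, 224, 224))
    print(backbone.feature_info.reduction())  # [2, 4, 8, 8, 8] once strides are converted to dilation
    print([f.shape[-1] for f in feats])       # spatial extent bottoms out at 224 / 8 = 28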