Add ImageNet-12k ReXNet-R 200 & 300 weights, and push existing ReXNet models to HF hub. Dilation support added to rexnet

This commit is contained in:
Ross Wightman 2023-03-20 13:48:17 -07:00
parent 8db20dc240
commit c78319adce

View File

@ -21,48 +21,30 @@ from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath,
from ._builder import build_model_with_cfg
from ._efficientnet_builder import efficientnet_init_weights
from ._manipulate import checkpoint_seq
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['ReXNetV1'] # model_registry will add each entrypoint fn to this
def _cfg(url=''):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
}
default_cfgs = dict(
rexnet_100=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_100-1b4dddf4.pth'),
rexnet_130=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_130-590d768e.pth'),
rexnet_150=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_150-bd1a6aa8.pth'),
rexnet_200=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_200-8c0b7f2d.pth'),
rexnetr_100=_cfg(
url=''),
rexnetr_130=_cfg(
url=''),
rexnetr_150=_cfg(
url=''),
rexnetr_200=_cfg(
url=''),
)
SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d)
class LinearBottleneck(nn.Module):
def __init__(
self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1,
act_layer='swish', dw_act_layer='relu6', drop_path=None):
self,
in_chs,
out_chs,
stride,
dilation=(1, 1),
exp_ratio=1.0,
se_ratio=0.,
ch_div=1,
act_layer='swish',
dw_act_layer='relu6',
drop_path=None,
):
super(LinearBottleneck, self).__init__()
self.use_shortcut = stride == 1 and in_chs <= out_chs
self.use_shortcut = stride == 1 and dilation[0] == dilation[1] and in_chs <= out_chs
self.in_channels = in_chs
self.out_channels = out_chs
@ -73,7 +55,15 @@ class LinearBottleneck(nn.Module):
dw_chs = in_chs
self.conv_exp = None
self.conv_dw = ConvNormAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
self.conv_dw = ConvNormAct(
dw_chs,
dw_chs,
kernel_size=3,
stride=stride,
dilation=dilation[0],
groups=dw_chs,
apply_act=False,
)
if se_ratio > 0:
self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div))
else:
@ -102,7 +92,14 @@ class LinearBottleneck(nn.Module):
return x
def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se_ratio=0., ch_div=1):
def _block_cfg(
width_mult=1.0,
depth_mult=1.0,
initial_chs=16,
final_chs=180,
se_ratio=0.,
ch_div=1,
):
layers = [1, 2, 2, 3, 3, 5]
strides = [1, 2, 2, 2, 1, 2]
layers = [ceil(element * depth_mult) for element in layers]
@ -123,22 +120,45 @@ def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se
def _build_blocks(
block_cfg, prev_chs, width_mult, ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_path_rate=0.):
block_cfg,
prev_chs,
width_mult,
ch_div=1,
output_stride=32,
act_layer='swish',
dw_act_layer='relu6',
drop_path_rate=0.,
):
feat_chs = [prev_chs]
feature_info = []
curr_stride = 2
dilation = 1
features = []
num_blocks = len(block_cfg)
for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg):
next_dilation = dilation
if stride > 1:
fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
curr_stride *= stride
if curr_stride >= output_stride:
next_dilation = dilation * stride
stride = 1
block_dpr = drop_path_rate * block_idx / (num_blocks - 1) # stochastic depth linear decay rule
drop_path = DropPath(block_dpr) if block_dpr > 0. else None
features.append(LinearBottleneck(
in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, se_ratio=se_ratio,
ch_div=ch_div, act_layer=act_layer, dw_act_layer=dw_act_layer, drop_path=drop_path))
in_chs=prev_chs,
out_chs=chs,
exp_ratio=exp_ratio,
stride=stride,
dilation=(dilation, next_dilation),
se_ratio=se_ratio,
ch_div=ch_div,
act_layer=act_layer,
dw_act_layer=dw_act_layer,
drop_path=drop_path,
))
curr_stride *= stride
dilation = next_dilation
prev_chs = chs
feat_chs += [features[-1].feat_channels()]
pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
@ -149,23 +169,43 @@ def _build_blocks(
class ReXNetV1(nn.Module):
def __init__(
self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0.
self,
in_chans=3,
num_classes=1000,
global_pool='avg',
output_stride=32,
initial_chs=16,
final_chs=180,
width_mult=1.0,
depth_mult=1.0,
se_ratio=1/12.,
ch_div=1,
act_layer='swish',
dw_act_layer='relu6',
drop_rate=0.2,
drop_path_rate=0.,
):
super(ReXNetV1, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
assert output_stride == 32 # FIXME support dilation
assert output_stride in (32, 16, 8)
stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
self.stem = ConvNormAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer)
block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
features, self.feature_info = _build_blocks(
block_cfg, stem_chs, width_mult, ch_div, act_layer, dw_act_layer, drop_path_rate)
block_cfg,
stem_chs,
width_mult,
ch_div,
output_stride,
act_layer,
dw_act_layer,
drop_path_rate,
)
self.num_features = features[-1].out_channels
self.features = nn.Sequential(*features)
@ -201,7 +241,7 @@ class ReXNetV1(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
@ -212,9 +252,46 @@ class ReXNetV1(nn.Module):
def _create_rexnet(variant, pretrained, **kwargs):
feature_cfg = dict(flatten_sequential=True)
return build_model_with_cfg(
ReXNetV1, variant, pretrained,
ReXNetV1,
variant,
pretrained,
feature_cfg=feature_cfg,
**kwargs)
**kwargs,
)
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
'license': 'mit', **kwargs
}
default_cfgs = generate_default_cfgs({
'rexnet_100.nav_in1k': _cfg(hf_hub_id='timm/'),
'rexnet_130.nav_in1k': _cfg(hf_hub_id='timm/'),
'rexnet_150.nav_in1k': _cfg(hf_hub_id='timm/'),
'rexnet_200.nav_in1k': _cfg(hf_hub_id='timm/'),
'rexnet_300.nav_in1k': _cfg(hf_hub_id='timm/'),
'rexnetr_100.untrained': _cfg(),
'rexnetr_130.untrained': _cfg(),
'rexnetr_150.untrained': _cfg(),
'rexnetr_200.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=1.0, test_input_size=(3, 288, 288), license='apache-2.0'),
'rexnetr_300.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=1.0, test_input_size=(3, 288, 288), license='apache-2.0'),
'rexnetr_200.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821, crop_pct=1.0, license='apache-2.0'),
'rexnetr_300.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821, crop_pct=1.0, license='apache-2.0'),
})
@register_model
@ -241,6 +318,12 @@ def rexnet_200(pretrained=False, **kwargs):
return _create_rexnet('rexnet_200', pretrained, width_mult=2.0, **kwargs)
@register_model
def rexnet_300(pretrained=False, **kwargs):
"""ReXNet V1 3.0x"""
return _create_rexnet('rexnet_300', pretrained, width_mult=3.0, **kwargs)
@register_model
def rexnetr_100(pretrained=False, **kwargs):
"""ReXNet V1 1.0x w/ rounded (mod 8) channels"""
@ -263,3 +346,9 @@ def rexnetr_150(pretrained=False, **kwargs):
def rexnetr_200(pretrained=False, **kwargs):
"""ReXNet V1 2.0x w/ rounded (mod 8) channels"""
return _create_rexnet('rexnetr_200', pretrained, width_mult=2.0, ch_div=8, **kwargs)
@register_model
def rexnetr_300(pretrained=False, **kwargs):
"""ReXNet V1 3.0x w/ rounded (mod 16) channels"""
return _create_rexnet('rexnetr_300', pretrained, width_mult=3.0, ch_div=16, **kwargs)