Add ImageNet-12k ReXNet-R 200 & 300 weights, and push existing ReXNet models to HF hub. Dilation support added to rexnet
parent 8db20dc240
commit c78319adce
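The headline functional change is dilation support: `ReXNetV1` now accepts `output_stride` values of 16 or 8 and trades later stage strides for dilation, as the hunks below show. A minimal usage sketch (hedged: assumes a timm build that contains this commit; the model choice and input size are illustrative):

import timm
import torch

# With output_stride=16, stride-2 blocks past a 1/16 reduction run at
# stride 1 with dilation, so the deepest feature map should stay at
# 14x14 for a 224x224 input instead of 7x7.
model = timm.create_model('rexnetr_200', pretrained=False, output_stride=16, features_only=True)
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape[-1] for f in feats])  # expected: [112, 56, 28, 14, 14]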
timm/models/rexnet.py
@@ -21,48 +21,30 @@ from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath,
 from ._builder import build_model_with_cfg
 from ._efficientnet_builder import efficientnet_init_weights
 from ._manipulate import checkpoint_seq
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['ReXNetV1']  # model_registry will add each entrypoint fn to this
 
 
-def _cfg(url=''):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv', 'classifier': 'head.fc',
-    }
-
-
-default_cfgs = dict(
-    rexnet_100=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_100-1b4dddf4.pth'),
-    rexnet_130=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_130-590d768e.pth'),
-    rexnet_150=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_150-bd1a6aa8.pth'),
-    rexnet_200=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_200-8c0b7f2d.pth'),
-    rexnetr_100=_cfg(
-        url=''),
-    rexnetr_130=_cfg(
-        url=''),
-    rexnetr_150=_cfg(
-        url=''),
-    rexnetr_200=_cfg(
-        url=''),
-)
-
-
 SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d)
 
 
 class LinearBottleneck(nn.Module):
     def __init__(
-            self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1,
-            act_layer='swish', dw_act_layer='relu6', drop_path=None):
+            self,
+            in_chs,
+            out_chs,
+            stride,
+            dilation=(1, 1),
+            exp_ratio=1.0,
+            se_ratio=0.,
+            ch_div=1,
+            act_layer='swish',
+            dw_act_layer='relu6',
+            drop_path=None,
+    ):
         super(LinearBottleneck, self).__init__()
-        self.use_shortcut = stride == 1 and in_chs <= out_chs
+        self.use_shortcut = stride == 1 and dilation[0] == dilation[1] and in_chs <= out_chs
         self.in_channels = in_chs
         self.out_channels = out_chs
 
@@ -73,7 +55,15 @@ class LinearBottleneck(nn.Module):
             dw_chs = in_chs
             self.conv_exp = None
 
-        self.conv_dw = ConvNormAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
+        self.conv_dw = ConvNormAct(
+            dw_chs,
+            dw_chs,
+            kernel_size=3,
+            stride=stride,
+            dilation=dilation[0],
+            groups=dw_chs,
+            apply_act=False,
+        )
         if se_ratio > 0:
             self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div))
         else:
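The depthwise conv above is where the dilation actually lands. A standalone plain-PyTorch sketch of the geometry (this is not timm's `ConvNormAct`, which resolves the equivalent padding internally):

import torch
import torch.nn as nn

# A 3x3 kernel with dilation d spans a (2*d + 1) window; with stride 1 and
# padding equal to the dilation, spatial size is preserved, so a
# stride-reduced stage keeps growing its receptive field.
x = torch.randn(1, 8, 32, 32)
dw = nn.Conv2d(8, 8, kernel_size=3, stride=1, padding=2, dilation=2, groups=8)
print(dw(x).shape)  # torch.Size([1, 8, 32, 32])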
@@ -102,7 +92,14 @@
         return x
 
 
-def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se_ratio=0., ch_div=1):
+def _block_cfg(
+        width_mult=1.0,
+        depth_mult=1.0,
+        initial_chs=16,
+        final_chs=180,
+        se_ratio=0.,
+        ch_div=1,
+):
     layers = [1, 2, 2, 3, 3, 5]
     strides = [1, 2, 2, 2, 1, 2]
     layers = [ceil(element * depth_mult) for element in layers]
@@ -123,22 +120,45 @@ def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se
 
 
 def _build_blocks(
-        block_cfg, prev_chs, width_mult, ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_path_rate=0.):
+        block_cfg,
+        prev_chs,
+        width_mult,
+        ch_div=1,
+        output_stride=32,
+        act_layer='swish',
+        dw_act_layer='relu6',
+        drop_path_rate=0.,
+):
     feat_chs = [prev_chs]
     feature_info = []
     curr_stride = 2
+    dilation = 1
     features = []
     num_blocks = len(block_cfg)
     for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg):
+        next_dilation = dilation
         if stride > 1:
             fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
             feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
-        curr_stride *= stride
+            if curr_stride >= output_stride:
+                next_dilation = dilation * stride
+                stride = 1
         block_dpr = drop_path_rate * block_idx / (num_blocks - 1)  # stochastic depth linear decay rule
         drop_path = DropPath(block_dpr) if block_dpr > 0. else None
         features.append(LinearBottleneck(
-            in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, se_ratio=se_ratio,
-            ch_div=ch_div, act_layer=act_layer, dw_act_layer=dw_act_layer, drop_path=drop_path))
+            in_chs=prev_chs,
+            out_chs=chs,
+            exp_ratio=exp_ratio,
+            stride=stride,
+            dilation=(dilation, next_dilation),
+            se_ratio=se_ratio,
+            ch_div=ch_div,
+            act_layer=act_layer,
+            dw_act_layer=dw_act_layer,
+            drop_path=drop_path,
+        ))
+        curr_stride *= stride
+        dilation = next_dilation
         prev_chs = chs
         feat_chs += [features[-1].feat_channels()]
     pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
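The loop above retires strides into dilation once `curr_stride` reaches `output_stride`. A hypothetical standalone trace over ReXNet's stage strides (`trace_dilations` is not part of timm; it just mirrors the bookkeeping in `_build_blocks`):

# Python sketch of the stride-to-dilation bookkeeping shown above.
def trace_dilations(strides, output_stride=32):
    curr_stride, dilation = 2, 1  # the stem has already reduced by 2
    out = []
    for stride in strides:
        next_dilation = dilation
        if stride > 1 and curr_stride >= output_stride:
            next_dilation = dilation * stride  # trade the stride for dilation
            stride = 1
        curr_stride *= stride
        out.append((curr_stride, next_dilation))
        dilation = next_dilation
    return out

print(trace_dilations([1, 2, 2, 2, 1, 2], output_stride=16))
# [(2, 1), (4, 1), (8, 1), (16, 1), (16, 1), (16, 2)]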
@@ -149,23 +169,43 @@ def _build_blocks(
 
 
 class ReXNetV1(nn.Module):
     def __init__(
-            self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
-            initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
-            ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0.
+            self,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='avg',
+            output_stride=32,
+            initial_chs=16,
+            final_chs=180,
+            width_mult=1.0,
+            depth_mult=1.0,
+            se_ratio=1/12.,
+            ch_div=1,
+            act_layer='swish',
+            dw_act_layer='relu6',
+            drop_rate=0.2,
+            drop_path_rate=0.,
     ):
         super(ReXNetV1, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
 
-        assert output_stride == 32  # FIXME support dilation
+        assert output_stride in (32, 16, 8)
         stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
         stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
         self.stem = ConvNormAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer)
 
         block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
         features, self.feature_info = _build_blocks(
-            block_cfg, stem_chs, width_mult, ch_div, act_layer, dw_act_layer, drop_path_rate)
+            block_cfg,
+            stem_chs,
+            width_mult,
+            ch_div,
+            output_stride,
+            act_layer,
+            dw_act_layer,
+            drop_path_rate,
+        )
         self.num_features = features[-1].out_channels
         self.features = nn.Sequential(*features)
 
@@ -201,7 +241,7 @@ class ReXNetV1(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        return self.head(x, pre_logits=pre_logits)
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)
@@ -212,9 +252,46 @@
 def _create_rexnet(variant, pretrained, **kwargs):
     feature_cfg = dict(flatten_sequential=True)
     return build_model_with_cfg(
-        ReXNetV1, variant, pretrained,
+        ReXNetV1,
+        variant,
+        pretrained,
         feature_cfg=feature_cfg,
-        **kwargs)
+        **kwargs,
+    )
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv', 'classifier': 'head.fc',
+        'license': 'mit', **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'rexnet_100.nav_in1k': _cfg(hf_hub_id='timm/'),
+    'rexnet_130.nav_in1k': _cfg(hf_hub_id='timm/'),
+    'rexnet_150.nav_in1k': _cfg(hf_hub_id='timm/'),
+    'rexnet_200.nav_in1k': _cfg(hf_hub_id='timm/'),
+    'rexnet_300.nav_in1k': _cfg(hf_hub_id='timm/'),
+    'rexnetr_100.untrained': _cfg(),
+    'rexnetr_130.untrained': _cfg(),
+    'rexnetr_150.untrained': _cfg(),
+    'rexnetr_200.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=1.0, test_input_size=(3, 288, 288), license='apache-2.0'),
+    'rexnetr_300.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=1.0, test_input_size=(3, 288, 288), license='apache-2.0'),
+    'rexnetr_200.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821, crop_pct=1.0, license='apache-2.0'),
+    'rexnetr_300.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821, crop_pct=1.0, license='apache-2.0'),
+})
+
+
 @register_model
@@ -241,6 +318,12 @@ def rexnet_200(pretrained=False, **kwargs):
     return _create_rexnet('rexnet_200', pretrained, width_mult=2.0, **kwargs)
 
 
+@register_model
+def rexnet_300(pretrained=False, **kwargs):
+    """ReXNet V1 3.0x"""
+    return _create_rexnet('rexnet_300', pretrained, width_mult=3.0, **kwargs)
+
+
 @register_model
 def rexnetr_100(pretrained=False, **kwargs):
     """ReXNet V1 1.0x w/ rounded (mod 8) channels"""
@@ -263,3 +346,9 @@ def rexnetr_150(pretrained=False, **kwargs):
 def rexnetr_200(pretrained=False, **kwargs):
     """ReXNet V1 2.0x w/ rounded (mod 8) channels"""
     return _create_rexnet('rexnetr_200', pretrained, width_mult=2.0, ch_div=8, **kwargs)
+
+
+@register_model
+def rexnetr_300(pretrained=False, **kwargs):
+    """ReXNet V1 3.0x w/ rounded (mod 16) channels"""
+    return _create_rexnet('rexnetr_300', pretrained, width_mult=3.0, ch_div=16, **kwargs)