mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
InceptionNeXt using timm builder, more cleanup
This commit is contained in:
parent
f4cf9775c3
commit
3d8d7450ad
@ -1,7 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
InceptionNeXt implementation, paper: https://arxiv.org/abs/2303.16900
|
InceptionNeXt implementation, paper: https://arxiv.org/abs/2303.16900
|
||||||
|
|
||||||
Some code is borrowed from timm: https://github.com/huggingface/pytorch-image-models
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -11,24 +9,31 @@ import torch.nn as nn
|
|||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import trunc_normal_, DropPath, to_2tuple
|
from timm.layers import trunc_normal_, DropPath, to_2tuple
|
||||||
|
from ._builder import build_model_with_cfg
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
|
||||||
|
|
||||||
class InceptionDWConv2d(nn.Module):
|
class InceptionDWConv2d(nn.Module):
|
||||||
""" Inception depthweise convolution
|
""" Inception depthweise convolution
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, square_kernel_size=3, band_kernel_size=11, branch_ratio=0.125):
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_chs,
|
||||||
|
square_kernel_size=3,
|
||||||
|
band_kernel_size=11,
|
||||||
|
branch_ratio=0.125
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
gc = int(in_channels * branch_ratio) # channel numbers of a convolution branch
|
gc = int(in_chs * branch_ratio) # channel numbers of a convolution branch
|
||||||
self.dwconv_hw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_kernel_size // 2, groups=gc)
|
self.dwconv_hw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_kernel_size // 2, groups=gc)
|
||||||
self.dwconv_w = nn.Conv2d(
|
self.dwconv_w = nn.Conv2d(
|
||||||
gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size // 2), groups=gc)
|
gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size // 2), groups=gc)
|
||||||
self.dwconv_h = nn.Conv2d(
|
self.dwconv_h = nn.Conv2d(
|
||||||
gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size // 2, 0), groups=gc)
|
gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size // 2, 0), groups=gc)
|
||||||
self.split_indexes = (in_channels - 3 * gc, gc, gc, gc)
|
self.split_indexes = (in_chs - 3 * gc, gc, gc, gc)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1)
|
x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1)
|
||||||
@ -47,8 +52,15 @@ class ConvMlp(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
|
self,
|
||||||
norm_layer=None, bias=True, drop=0.):
|
in_features,
|
||||||
|
hidden_features=None,
|
||||||
|
out_features=None,
|
||||||
|
act_layer=nn.ReLU,
|
||||||
|
norm_layer=None,
|
||||||
|
bias=True,
|
||||||
|
drop=0.,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
out_features = out_features or in_features
|
out_features = out_features or in_features
|
||||||
hidden_features = hidden_features or in_features
|
hidden_features = hidden_features or in_features
|
||||||
@ -69,13 +81,20 @@ class ConvMlp(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class MlpHead(nn.Module):
|
class MlpClassifierHead(nn.Module):
|
||||||
""" MLP classification head
|
""" MLP classification head
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, dim, num_classes=1000, mlp_ratio=3, act_layer=nn.GELU,
|
self,
|
||||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), drop=0., bias=True):
|
dim,
|
||||||
|
num_classes=1000,
|
||||||
|
mlp_ratio=3,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||||
|
drop=0.,
|
||||||
|
bias=True
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_features = int(mlp_ratio * dim)
|
hidden_features = int(mlp_ratio * dim)
|
||||||
self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
|
self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
|
||||||
@ -168,7 +187,6 @@ class MetaNeXtStage(nn.Module):
|
|||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
))
|
))
|
||||||
in_chs = out_chs
|
|
||||||
self.blocks = nn.Sequential(*stage_blocks)
|
self.blocks = nn.Sequential(*stage_blocks)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -209,11 +227,10 @@ class MetaNeXt(nn.Module):
|
|||||||
norm_layer=nn.BatchNorm2d,
|
norm_layer=nn.BatchNorm2d,
|
||||||
act_layer=nn.GELU,
|
act_layer=nn.GELU,
|
||||||
mlp_ratios=(4, 4, 4, 3),
|
mlp_ratios=(4, 4, 4, 3),
|
||||||
head_fn=MlpHead,
|
head_fn=MlpClassifierHead,
|
||||||
drop_rate=0.,
|
drop_rate=0.,
|
||||||
drop_path_rate=0.,
|
drop_path_rate=0.,
|
||||||
ls_init_value=1e-6,
|
ls_init_value=1e-6,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -255,6 +272,30 @@ class MetaNeXt(nn.Module):
|
|||||||
self.head = head_fn(self.num_features, num_classes, drop=drop_rate)
|
self.head = head_fn(self.num_features, num_classes, drop=drop_rate)
|
||||||
self.apply(self._init_weights)
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, m):
|
||||||
|
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||||
|
trunc_normal_(m.weight, std=.02)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def group_matcher(self, coarse=False):
|
||||||
|
return dict(
|
||||||
|
stem=r'^stem',
|
||||||
|
blocks=r'^stages\.(\d+)' if coarse else [
|
||||||
|
(r'^stages\.(\d+)\.downsample', (0,)), # blocks
|
||||||
|
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def get_classifier(self):
|
||||||
|
return self.head.fc2
|
||||||
|
|
||||||
|
def reset_classifier(self, num_classes=0, global_pool=None):
|
||||||
|
# FIXME
|
||||||
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
@torch.jit.ignore
|
@torch.jit.ignore
|
||||||
def set_grad_checkpointing(self, enable=True):
|
def set_grad_checkpointing(self, enable=True):
|
||||||
for s in self.stages:
|
for s in self.stages:
|
||||||
@ -262,7 +303,7 @@ class MetaNeXt(nn.Module):
|
|||||||
|
|
||||||
@torch.jit.ignore
|
@torch.jit.ignore
|
||||||
def no_weight_decay(self):
|
def no_weight_decay(self):
|
||||||
return {'norm'}
|
return set()
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
@ -278,12 +319,6 @@ class MetaNeXt(nn.Module):
|
|||||||
x = self.forward_head(x)
|
x = self.forward_head(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def _init_weights(self, m):
|
|
||||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
|
||||||
trunc_normal_(m.weight, std=.02)
|
|
||||||
if m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
|
|
||||||
def _cfg(url='', **kwargs):
|
def _cfg(url='', **kwargs):
|
||||||
return {
|
return {
|
||||||
@ -291,84 +326,59 @@ def _cfg(url='', **kwargs):
|
|||||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||||
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
'first_conv': 'stem.0', 'classifier': 'head.fc',
|
'first_conv': 'stem.0', 'classifier': 'head.fc2',
|
||||||
**kwargs
|
**kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
default_cfgs = dict(
|
default_cfgs = generate_default_cfgs({
|
||||||
inception_next_tiny=_cfg(
|
'inception_next_tiny.sail_in1k': _cfg(
|
||||||
url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth',
|
url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth',
|
||||||
),
|
),
|
||||||
inception_next_small=_cfg(
|
'inception_next_small.sail_in1k': _cfg(
|
||||||
url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth',
|
url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth',
|
||||||
),
|
),
|
||||||
inception_next_base=_cfg(
|
'inception_next_base.sail_in1k': _cfg(
|
||||||
url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth',
|
url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth',
|
||||||
|
crop_pct=0.95,
|
||||||
),
|
),
|
||||||
inception_next_base_384=_cfg(
|
'inception_next_base.sail_in1k_384': _cfg(
|
||||||
url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth',
|
url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth',
|
||||||
input_size=(3, 384, 384), crop_pct=1.0,
|
input_size=(3, 384, 384), crop_pct=1.0,
|
||||||
),
|
),
|
||||||
)
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _create_inception_next(variant, pretrained=False, **kwargs):
|
||||||
|
model = build_model_with_cfg(
|
||||||
|
MetaNeXt, variant, pretrained,
|
||||||
|
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
|
||||||
|
**kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def inception_next_tiny(pretrained=False, **kwargs):
|
def inception_next_tiny(pretrained=False, **kwargs):
|
||||||
model = MetaNeXt(
|
model_args = dict(
|
||||||
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768),
|
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768),
|
||||||
token_mixers=InceptionDWConv2d,
|
token_mixers=InceptionDWConv2d,
|
||||||
**kwargs
|
|
||||||
)
|
)
|
||||||
model.default_cfg = default_cfgs['inception_next_tiny']
|
return _create_inception_next('inception_next_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
if pretrained:
|
|
||||||
state_dict = torch.hub.load_state_dict_from_url(
|
|
||||||
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def inception_next_small(pretrained=False, **kwargs):
|
def inception_next_small(pretrained=False, **kwargs):
|
||||||
model = MetaNeXt(
|
model_args = dict(
|
||||||
depths=(3, 3, 27, 3), dims=(96, 192, 384, 768),
|
depths=(3, 3, 27, 3), dims=(96, 192, 384, 768),
|
||||||
token_mixers=InceptionDWConv2d,
|
token_mixers=InceptionDWConv2d,
|
||||||
**kwargs
|
|
||||||
)
|
)
|
||||||
model.default_cfg = default_cfgs['inception_next_small']
|
return _create_inception_next('inception_next_small', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
if pretrained:
|
|
||||||
state_dict = torch.hub.load_state_dict_from_url(
|
|
||||||
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def inception_next_base(pretrained=False, **kwargs):
|
def inception_next_base(pretrained=False, **kwargs):
|
||||||
model = MetaNeXt(
|
model_args = dict(
|
||||||
depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024),
|
depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024),
|
||||||
token_mixers=InceptionDWConv2d,
|
token_mixers=InceptionDWConv2d,
|
||||||
**kwargs
|
|
||||||
)
|
)
|
||||||
model.default_cfg = default_cfgs['inception_next_base']
|
return _create_inception_next('inception_next_base', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
if pretrained:
|
|
||||||
state_dict = torch.hub.load_state_dict_from_url(
|
|
||||||
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
|
||||||
def inception_next_base_384(pretrained=False, **kwargs):
|
|
||||||
model = MetaNeXt(
|
|
||||||
depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024],
|
|
||||||
mlp_ratios=[4, 4, 4, 3],
|
|
||||||
token_mixers=InceptionDWConv2d,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
model.default_cfg = default_cfgs['inception_next_base_384']
|
|
||||||
if pretrained:
|
|
||||||
state_dict = torch.hub.load_state_dict_from_url(
|
|
||||||
url=model.default_cfg['url'], map_location="cpu", check_hash=True)
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
return model
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user