From 3d8d7450add6d3272614447cd64e7f477288c1d5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 23 Aug 2023 14:43:42 -0700 Subject: [PATCH] InceptionNeXt using timm builder, more cleanup --- timm/models/inception_next.py | 146 ++++++++++++++++++---------------- 1 file changed, 78 insertions(+), 68 deletions(-) diff --git a/timm/models/inception_next.py b/timm/models/inception_next.py index f51161ae..cd34953d 100644 --- a/timm/models/inception_next.py +++ b/timm/models/inception_next.py @@ -1,7 +1,5 @@ """ InceptionNeXt implementation, paper: https://arxiv.org/abs/2303.16900 - -Some code is borrowed from timm: https://github.com/huggingface/pytorch-image-models """ from functools import partial @@ -11,24 +9,31 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_, DropPath, to_2tuple +from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs class InceptionDWConv2d(nn.Module): """ Inception depthweise convolution """ - def __init__(self, in_channels, square_kernel_size=3, band_kernel_size=11, branch_ratio=0.125): + def __init__( + self, + in_chs, + square_kernel_size=3, + band_kernel_size=11, + branch_ratio=0.125 + ): super().__init__() - gc = int(in_channels * branch_ratio) # channel numbers of a convolution branch + gc = int(in_chs * branch_ratio) # channel numbers of a convolution branch self.dwconv_hw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_kernel_size // 2, groups=gc) self.dwconv_w = nn.Conv2d( gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size // 2), groups=gc) self.dwconv_h = nn.Conv2d( gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size // 2, 0), groups=gc) - self.split_indexes = (in_channels - 3 * gc, gc, gc, gc) + self.split_indexes = (in_chs - 3 * gc, gc, gc, gc) def forward(self, x): x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1) @@ -47,8 +52,15 @@ class ConvMlp(nn.Module): """ def __init__( - self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, - norm_layer=None, bias=True, drop=0.): + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.ReLU, + norm_layer=None, + bias=True, + drop=0., + ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -69,13 +81,20 @@ class ConvMlp(nn.Module): return x -class MlpHead(nn.Module): +class MlpClassifierHead(nn.Module): """ MLP classification head """ def __init__( - self, dim, num_classes=1000, mlp_ratio=3, act_layer=nn.GELU, - norm_layer=partial(nn.LayerNorm, eps=1e-6), drop=0., bias=True): + self, + dim, + num_classes=1000, + mlp_ratio=3, + act_layer=nn.GELU, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0., + bias=True + ): super().__init__() hidden_features = int(mlp_ratio * dim) self.fc1 = nn.Linear(dim, hidden_features, bias=bias) @@ -168,7 +187,6 @@ class MetaNeXtStage(nn.Module): norm_layer=norm_layer, mlp_ratio=mlp_ratio, )) - in_chs = out_chs self.blocks = nn.Sequential(*stage_blocks) def forward(self, x): @@ -209,11 +227,10 @@ class MetaNeXt(nn.Module): norm_layer=nn.BatchNorm2d, act_layer=nn.GELU, mlp_ratios=(4, 4, 4, 3), - head_fn=MlpHead, + head_fn=MlpClassifierHead, drop_rate=0., drop_path_rate=0., ls_init_value=1e-6, - **kwargs, ): super().__init__() @@ -255,6 +272,30 @@ class MetaNeXt(nn.Module): self.head = head_fn(self.num_features, num_classes, drop=drop_rate) self.apply(self._init_weights) + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.downsample', (0,)), # blocks + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + ] + ) + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc2 + + def reset_classifier(self, num_classes=0, global_pool=None): + # FIXME + self.head.reset(num_classes, global_pool) + @torch.jit.ignore def set_grad_checkpointing(self, enable=True): for s in self.stages: @@ -262,7 +303,7 @@ class MetaNeXt(nn.Module): @torch.jit.ignore def no_weight_decay(self): - return {'norm'} + return set() def forward_features(self, x): x = self.stem(x) @@ -278,12 +319,6 @@ class MetaNeXt(nn.Module): x = self.forward_head(x) return x - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - def _cfg(url='', **kwargs): return { @@ -291,84 +326,59 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.0', 'classifier': 'head.fc', + 'first_conv': 'stem.0', 'classifier': 'head.fc2', **kwargs } -default_cfgs = dict( - inception_next_tiny=_cfg( +default_cfgs = generate_default_cfgs({ + 'inception_next_tiny.sail_in1k': _cfg( url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth', ), - inception_next_small=_cfg( + 'inception_next_small.sail_in1k': _cfg( url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth', ), - inception_next_base=_cfg( + 'inception_next_base.sail_in1k': _cfg( url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth', + crop_pct=0.95, ), - inception_next_base_384=_cfg( + 'inception_next_base.sail_in1k_384': _cfg( url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth', input_size=(3, 384, 384), crop_pct=1.0, ), -) +}) + + +def _create_inception_next(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + MetaNeXt, variant, pretrained, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs) + return model @register_model def inception_next_tiny(pretrained=False, **kwargs): - model = MetaNeXt( + model_args = dict( depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), token_mixers=InceptionDWConv2d, - **kwargs ) - model.default_cfg = default_cfgs['inception_next_tiny'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url=model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_inception_next('inception_next_tiny', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def inception_next_small(pretrained=False, **kwargs): - model = MetaNeXt( + model_args = dict( depths=(3, 3, 27, 3), dims=(96, 192, 384, 768), token_mixers=InceptionDWConv2d, - **kwargs ) - model.default_cfg = default_cfgs['inception_next_small'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url=model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_inception_next('inception_next_small', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def inception_next_base(pretrained=False, **kwargs): - model = MetaNeXt( + model_args = dict( depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024), token_mixers=InceptionDWConv2d, - **kwargs ) - model.default_cfg = default_cfgs['inception_next_base'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url=model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def inception_next_base_384(pretrained=False, **kwargs): - model = MetaNeXt( - depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], - mlp_ratios=[4, 4, 4, 3], - token_mixers=InceptionDWConv2d, - **kwargs - ) - model.default_cfg = default_cfgs['inception_next_base_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url=model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_inception_next('inception_next_base', pretrained=pretrained, **dict(model_args, **kwargs))