""" Poolformer from MetaFormer is Actually What You Need for Vision https://arxiv.org/abs/2111.11418

IdentityFormer, RandFormer, PoolFormerV2, ConvFormer, and CAFormer
from MetaFormer Baselines for Vision https://arxiv.org/abs/2210.13452

All implemented models support feature extraction and variable input resolution.

Original implementation by Weihao Yu et al.,
adapted for timm by Fredo Guan and Ross Wightman.

Adapted from https://github.com/sail-sg/metaformer, original copyright below
"""

# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.jit import Final

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \
    use_fused_attn
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model

__all__ = ['MetaFormer']


class Stem(nn.Module):
    """
    Stem implemented by a layer of convolution.
    Conv2d params constant across all models.

    The convolution is 7x7 with stride 4 and padding 2, followed by an optional norm
    (``nn.Identity`` when ``norm_layer`` is None).
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            norm_layer=None,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=7,
            stride=4,
            padding=2
        )
        self.norm = norm_layer(out_channels) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class Downsampling(nn.Module):
    """
    Downsampling implemented by a layer of convolution.

    Unlike ``Stem``, the (optional) norm is applied to the input, before the convolution.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            norm_layer=None,
    ):
        super().__init__()
        self.norm = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding
        )

    def forward(self, x):
        x = self.norm(x)
        x = self.conv(x)
        return x


class Scale(nn.Module):
    """
    Scale vector by element multiplications.

    The learnable vector has ``dim`` entries. It is viewed as ``(dim, 1, 1)`` when ``use_nchw``
    is True so it broadcasts over the two trailing dims, and as ``(dim,)`` otherwise so it
    broadcasts over the last dim.
    """

    def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True):
        super().__init__()
        self.shape = (dim, 1, 1) if use_nchw else (dim,)
        self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        return x * self.scale.view(self.shape)


class SquaredReLU(nn.Module):
    """
    Squared ReLU: https://arxiv.org/abs/2109.08668
    """

    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        return torch.square(self.relu(x))


class StarReLU(nn.Module):
    """
    StarReLU: s * relu(x) ** 2 + b

    ``s`` and ``b`` are single-element parameters; each can be frozen via
    ``scale_learnable`` / ``bias_learnable``. ``mode`` is accepted but not used.
    """

    def __init__(
            self,
            scale_value=1.0,
            bias_value=0.0,
            scale_learnable=True,
            bias_learnable=True,
            mode=None,
            inplace=False
    ):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1), requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1), requires_grad=bias_learnable)

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias


class Attention(nn.Module):
    """
    Vanilla self-attention from Transformer: https://arxiv.org/abs/1706.03762.
    Modified from timm.

    The number of heads defaults to ``dim // head_dim`` (minimum 1). The internal attention
    width is ``num_heads * head_dim`` and may differ from ``dim``; ``proj`` maps it back to ``dim``.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim,
            head_dim=32,
            num_heads=None,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
            proj_bias=False,
            **kwargs
    ):
        super().__init__()

        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.num_heads = num_heads if num_heads else dim // head_dim
        if self.num_heads == 0:
            self.num_heads = 1

        self.attention_dim = self.num_heads * self.head_dim

        self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to a 3-dim input unpacked as ``(B, N, C)``; returns ``(B, N, dim)``."""
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # Merge heads back to the attention width (num_heads * head_dim), which is what `proj`
        # consumes. Using C here breaks whenever attention_dim != dim.
        x = x.transpose(1, 2).reshape(B, N, self.attention_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# custom norm modules that disable the bias term, since the original models defs
# used a custom norm with a weight term but no bias term.

class GroupNorm1NoBias(GroupNorm1):
    """``GroupNorm1`` with the bias removed and ``eps`` defaulting to 1e-6."""

    def __init__(self, num_channels, **kwargs):
        super().__init__(num_channels, **kwargs)
        self.eps = kwargs.get('eps', 1e-6)
        self.bias = None


class LayerNorm2dNoBias(LayerNorm2d):
    """``LayerNorm2d`` with the bias removed and ``eps`` defaulting to 1e-6."""

    def __init__(self, num_channels, **kwargs):
        super().__init__(num_channels, **kwargs)
        self.eps = kwargs.get('eps', 1e-6)
        self.bias = None


class LayerNormNoBias(nn.LayerNorm):
    """``nn.LayerNorm`` with the bias removed and ``eps`` defaulting to 1e-6."""

    def __init__(self, num_channels, **kwargs):
        super().__init__(num_channels, **kwargs)
        self.eps = kwargs.get('eps', 1e-6)
        self.bias = None


class SepConv(nn.Module):
    r"""
    Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.

    Pointwise expand (``expansion_ratio * dim`` channels) -> act1 -> depthwise conv -> act2
    -> pointwise project back to ``dim``.
    """

    def __init__(
            self,
            dim,
            expansion_ratio=2,
            act1_layer=StarReLU,
            act2_layer=nn.Identity,
            bias=False,
            kernel_size=7,
            padding=3,
            **kwargs
    ):
        super().__init__()
        mid_channels = int(expansion_ratio * dim)
        self.pwconv1 = nn.Conv2d(dim, mid_channels, kernel_size=1, bias=bias)
        self.act1 = act1_layer()
        self.dwconv = nn.Conv2d(
            mid_channels, mid_channels, kernel_size=kernel_size,
            padding=padding, groups=mid_channels, bias=bias)  # depthwise conv
        self.act2 = act2_layer()
        self.pwconv2 = nn.Conv2d(mid_channels, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.pwconv1(x)
        x = self.act1(x)
        x = self.dwconv(x)
        x = self.act2(x)
        x = self.pwconv2(x)
        return x


class Pooling(nn.Module):
    """
    Implementation of pooling for PoolFormer: https://arxiv.org/abs/2111.11418

    Returns ``avg_pool(x) - x``; the pool uses stride 1 and ``pool_size // 2`` padding.
    """

    def __init__(self, pool_size=3, **kwargs):
        super().__init__()
        self.pool = nn.AvgPool2d(
            pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)

    def forward(self, x):
        y = self.pool(x)
        return y - x


class MlpHead(nn.Module):
    """ MLP classification head

    fc1 -> act -> norm -> dropout -> fc2, with a hidden width of ``mlp_ratio * dim``.
    """

    def __init__(
            self,
            dim,
            num_classes=1000,
            mlp_ratio=4,
            act_layer=SquaredReLU,
            norm_layer=LayerNorm,
            drop_rate=0.,
            bias=True
    ):
        super().__init__()
        hidden_features = int(mlp_ratio * dim)
        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
        self.head_drop = nn.Dropout(drop_rate)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.norm(x)
        x = self.head_drop(x)
        x = self.fc2(x)
        return x


class MetaFormerBlock(nn.Module):
    """
    Implementation of one MetaFormer block.

    Two residual sub-blocks (token mixer, then MLP with 4x hidden width). Each computes
    ``res_scale(x) + layer_scale(drop_path(f(norm(x))))``; the scale and drop-path modules
    fall back to ``nn.Identity`` when disabled. Extra ``kwargs`` are forwarded to the token mixer.
    """

    def __init__(
            self,
            dim,
            token_mixer=Pooling,
            mlp_act=StarReLU,
            mlp_bias=False,
            norm_layer=LayerNorm2d,
            proj_drop=0.,
            drop_path=0.,
            use_nchw=True,
            layer_scale_init_value=None,
            res_scale_init_value=None,
            **kwargs
    ):
        super().__init__()
        ls_layer = partial(Scale, dim=dim, init_value=layer_scale_init_value, use_nchw=use_nchw)
        rs_layer = partial(Scale, dim=dim, init_value=res_scale_init_value, use_nchw=use_nchw)

        self.norm1 = norm_layer(dim)
        self.token_mixer = token_mixer(dim=dim, proj_drop=proj_drop, **kwargs)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale1 = ls_layer() if layer_scale_init_value is not None else nn.Identity()
        self.res_scale1 = rs_layer() if res_scale_init_value is not None else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            dim,
            int(4 * dim),
            act_layer=mlp_act,
            bias=mlp_bias,
            drop=proj_drop,
            use_conv=use_nchw,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale2 = ls_layer() if layer_scale_init_value is not None else nn.Identity()
        self.res_scale2 = rs_layer() if res_scale_init_value is not None else nn.Identity()

    def forward(self, x):
        x = self.res_scale1(x) + \
            self.layer_scale1(
                self.drop_path1(
                    self.token_mixer(self.norm1(x))
                )
            )
        x = self.res_scale2(x) + \
            self.layer_scale2(
                self.drop_path2(
                    self.mlp(self.norm2(x))
                )
            )
        return x


class MetaFormerStage(nn.Module):
    """
    One MetaFormer stage: an optional stride-2 downsampling conv followed by ``depth`` blocks.

    When ``token_mixer`` is a subclass of ``Attention`` the blocks run on a flattened
    ``(B, H*W, C)`` sequence; the stage reshapes on the way in and back to ``(B, C, H, W)``
    on the way out.
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            depth=2,
            token_mixer=nn.Identity,
            mlp_act=StarReLU,
            mlp_bias=False,
            downsample_norm=LayerNorm2d,
            norm_layer=LayerNorm2d,
            proj_drop=0.,
            dp_rates=(0., 0.),  # immutable default; indexed per block below
            layer_scale_init_value=None,
            res_scale_init_value=None,
            **kwargs,
    ):
        super().__init__()

        self.grad_checkpointing = False
        self.use_nchw = not issubclass(token_mixer, Attention)

        # don't downsample if in_chs and out_chs are the same
        self.downsample = nn.Identity() if in_chs == out_chs else Downsampling(
            in_chs,
            out_chs,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_layer=downsample_norm,
        )

        self.blocks = nn.Sequential(*[MetaFormerBlock(
            dim=out_chs,
            token_mixer=token_mixer,
            mlp_act=mlp_act,
            mlp_bias=mlp_bias,
            norm_layer=norm_layer,
            proj_drop=proj_drop,
            drop_path=dp_rates[i],
            layer_scale_init_value=layer_scale_init_value,
            res_scale_init_value=res_scale_init_value,
            use_nchw=self.use_nchw,
            **kwargs,
        ) for i in range(depth)])

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, x: Tensor):
        x = self.downsample(x)
        B, C, H, W = x.shape

        if not self.use_nchw:
            # flatten spatial dims into a token sequence: (B, C, H, W) -> (B, H*W, C)
            x = x.reshape(B, C, -1).transpose(1, 2)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)

        if not self.use_nchw:
            # restore the spatial layout: (B, H*W, C) -> (B, C, H, W)
            x = x.transpose(1, 2).reshape(B, C, H, W)

        return x


class MetaFormer(nn.Module):
    r""" MetaFormer
        A PyTorch impl of : `MetaFormer Baselines for Vision`  -
          https://arxiv.org/abs/2210.13452

    Args:
        in_chans (int): Number of input image channels.
        num_classes (int): Number of classes for classification head.
        global_pool: Pooling for classifier head.
        depths (list or tuple): Number of blocks at each stage.
        dims (list or tuple): Feature dimension at each stage.
        token_mixers (list, tuple or token_fcn): Token mixer for each stage.
        mlp_act: Activation layer for MLP.
        mlp_bias (boolean): Enable or disable mlp bias term.
        drop_path_rate (float): Stochastic depth rate.
        proj_drop_rate (float): Dropout rate passed to each block as ``proj_drop``.
        drop_rate (float): Dropout rate.
        layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale.
            None means not use the layer scale. From: https://arxiv.org/abs/2103.17239.
        res_scale_init_values (list, tuple, float or None): Init value for res Scale on residual connections.
            None means not use the res scale. From: https://arxiv.org/abs/2110.09456.
        downsample_norm (nn.Module): Norm layer used in stem and downsampling layers.
        norm_layers (list, tuple or norm_fcn): Norm layers for each stage.
        output_norm: Norm layer before classifier head.
        use_mlp_head: Use MLP classification head.
    """

    def __init__(
            self,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            depths=(2, 2, 6, 2),
            dims=(64, 128, 320, 512),
            token_mixers=Pooling,
            mlp_act=StarReLU,
            mlp_bias=False,
            drop_path_rate=0.,
            proj_drop_rate=0.,
            drop_rate=0.0,
            layer_scale_init_values=None,
            res_scale_init_values=(None, None, 1.0, 1.0),
            downsample_norm=LayerNorm2dNoBias,
            norm_layers=LayerNorm2dNoBias,
            output_norm=LayerNorm2d,
            use_mlp_head=True,
            **kwargs,
    ):
        super().__init__()

        # convert depths / dims to lists first if they aren't indexable, so that the
        # len() / [-1] lookups below also work for a single-stage (scalar) configuration
        if not isinstance(depths, (list, tuple)):
            depths = [depths]  # it means the model has only one stage
        if not isinstance(dims, (list, tuple)):
            dims = [dims]

        self.num_classes = num_classes
        self.num_features = dims[-1]
        self.drop_rate = drop_rate
        self.use_mlp_head = use_mlp_head
        self.num_stages = len(depths)

        # broadcast per-stage settings that were given as a single value
        if not isinstance(token_mixers, (list, tuple)):
            token_mixers = [token_mixers] * self.num_stages
        if not isinstance(norm_layers, (list, tuple)):
            norm_layers = [norm_layers] * self.num_stages
        if not isinstance(layer_scale_init_values, (list, tuple)):
            layer_scale_init_values = [layer_scale_init_values] * self.num_stages
        if not isinstance(res_scale_init_values, (list, tuple)):
            res_scale_init_values = [res_scale_init_values] * self.num_stages

        self.grad_checkpointing = False
        self.feature_info = []

        self.stem = Stem(
            in_chans,
            dims[0],
            norm_layer=downsample_norm
        )

        stages = []
        prev_dim = dims[0]
        # drop-path rate grows linearly from 0 to drop_path_rate over all blocks, split per stage
        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        for i in range(self.num_stages):
            stages += [MetaFormerStage(
                prev_dim,
                dims[i],
                depth=depths[i],
                token_mixer=token_mixers[i],
                mlp_act=mlp_act,
                mlp_bias=mlp_bias,
                proj_drop=proj_drop_rate,
                dp_rates=dp_rates[i],
                layer_scale_init_value=layer_scale_init_values[i],
                res_scale_init_value=res_scale_init_values[i],
                downsample_norm=downsample_norm,
                norm_layer=norm_layers[i],
                **kwargs,
            )]
            prev_dim = dims[i]
            self.feature_info += [dict(num_chs=dims[i], reduction=2**(i+2), module=f'stages.{i}')]

        self.stages = nn.Sequential(*stages)

        # FIXME not actually returning mlp hidden state right now as pre-logits.
        # Set unconditionally so the attribute also exists for headless (num_classes <= 0) models.
        self.head_hidden_size = self.num_features
        if num_classes > 0:
            if self.use_mlp_head:
                final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
            else:
                final = nn.Linear(self.num_features, num_classes)
        else:
            final = nn.Identity()

        self.head = nn.Sequential(OrderedDict([
            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
            ('norm', output_norm(self.num_features)),
            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
            # if using MlpHead, dropout is handled by MlpHead (its own head_drop), so only the
            # plain linear head gets a dropout layer here
            ('drop', nn.Identity() if self.use_mlp_head else nn.Dropout(drop_rate)),
            ('fc', final)
        ]))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal (std=0.02) init for conv/linear weights; zero their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable
        for stage in self.stages:
            stage.set_grad_checkpointing(enable=enable)

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace ``head.fc`` for ``num_classes`` outputs; optionally rebuild the pooling/flatten layers."""
        self.num_classes = num_classes
        if global_pool is not None:
            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
        if num_classes > 0:
            if self.use_mlp_head:
                final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
            else:
                final = nn.Linear(self.num_features, num_classes)
        else:
            final = nn.Identity()
        self.head.fc = final

    def forward_head(self, x: Tensor, pre_logits: bool = False):
        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
        x = self.head.global_pool(x)
        x = self.head.norm(x)
        x = self.head.flatten(x)
        x = self.head.drop(x)
        return x if pre_logits else self.head.fc(x)

    def forward_features(self, x: Tensor):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward(self, x: Tensor):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


# this works but it's long and breaks backwards compatibility with weights from the poolformer-only impl
def checkpoint_filter_fn(state_dict, model):
    """Remap checkpoint keys from the original MetaFormer / PoolFormer layouts to this module's layout.

    State dicts that already contain ``stem.conv.weight`` are returned untouched. Otherwise each
    key is renamed through the substitutions below, and a tensor whose element count matches the
    model's parameter but whose shape differs is reshaped to the model's shape.
    """
    if 'stem.conv.weight' in state_dict:
        return state_dict

    import re
    out_dict = {}
    is_poolformerv1 = 'network.0.0.mlp.fc1.weight' in state_dict
    model_state_dict = model.state_dict()
    for k, v in state_dict.items():
        if is_poolformerv1:
            k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
            k = k.replace('network.1', 'downsample_layers.1')
            k = k.replace('network.3', 'downsample_layers.2')
            k = k.replace('network.5', 'downsample_layers.3')
            k = k.replace('network.2', 'network.1')
            k = k.replace('network.4', 'network.2')
            k = k.replace('network.6', 'network.3')
            k = k.replace('network', 'stages')

        k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
        k = k.replace('downsample.proj', 'downsample.conv')
        k = k.replace('patch_embed.proj', 'patch_embed.conv')
        k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
        k = k.replace('stages.0.downsample', 'patch_embed')
        k = k.replace('patch_embed', 'stem')
        k = k.replace('post_norm', 'norm')
        k = k.replace('pre_norm', 'norm')
        k = re.sub(r'^head', 'head.fc', k)
        k = re.sub(r'^norm', 'head.norm', k)

        # Compare shape to shape (not shape to tensor). Keys the model does not know are passed
        # through unchanged so load_state_dict can report them instead of raising KeyError here.
        target = model_state_dict.get(k)
        if target is not None and v.shape != target.shape and v.numel() == target.numel():
            v = v.reshape(target.shape)

        out_dict[k] = v
    return out_dict


def _create_metaformer(variant, pretrained=False, **kwargs):
    """Build a ``MetaFormer`` variant; ``out_indices`` defaults to one index per stage in ``depths``."""
    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2))))
    out_indices = kwargs.pop('out_indices', default_out_indices)

    model = build_model_with_cfg(
        MetaFormer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs,
    )

    return model


def _cfg(url='', **kwargs):
    """Return the default pretrained-config dict, with ``kwargs`` overriding the defaults."""
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 1.0, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'classifier': 'head.fc', 'first_conv': 'stem.conv',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'poolformer_s12.sail_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.9),
    'poolformer_s24.sail_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.9),
    'poolformer_s36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.9),
    'poolformer_m36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95),
    'poolformer_m48.sail_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95),

    'poolformerv2_s12.sail_in1k': _cfg(hf_hub_id='timm/'),
    'poolformerv2_s24.sail_in1k': _cfg(hf_hub_id='timm/'),
    'poolformerv2_s36.sail_in1k': _cfg(hf_hub_id='timm/'),
    'poolformerv2_m36.sail_in1k': _cfg(hf_hub_id='timm/'),
    'poolformerv2_m48.sail_in1k': _cfg(hf_hub_id='timm/'),

    'convformer_s18.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_s18.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_s18.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_s18.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_s18.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'convformer_s36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_s36.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_s36.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_s36.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_s36.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'convformer_m36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_m36.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_m36.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_m36.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_m36.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'convformer_b36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_b36.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_b36.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_b36.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_b36.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'caformer_s18.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_s18.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_s18.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_s18.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_s18.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'caformer_s36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_s36.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_s36.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_s36.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_s36.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'caformer_m36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_m36.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_m36.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_m36.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_m36.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'caformer_b36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_b36.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_b36.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_b36.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_b36.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),
})


@register_model
def poolformer_s12(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormer S12: depths (2, 2, 6, 2), dims (64, 128, 320, 512), linear head."""
    model_kwargs = dict(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        downsample_norm=None,
        mlp_act=nn.GELU,
        mlp_bias=True,
        norm_layers=GroupNorm1,
        layer_scale_init_values=1e-5,
        res_scale_init_values=None,
        use_mlp_head=False,
        **kwargs)
    return _create_metaformer('poolformer_s12', pretrained=pretrained, **model_kwargs)


@register_model
def poolformer_s24(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormer S24: depths (4, 4, 12, 4), dims (64, 128, 320, 512), linear head."""
    model_kwargs = dict(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        downsample_norm=None,
        mlp_act=nn.GELU,
        mlp_bias=True,
        norm_layers=GroupNorm1,
        layer_scale_init_values=1e-5,
        res_scale_init_values=None,
        use_mlp_head=False,
        **kwargs)
    return _create_metaformer('poolformer_s24', pretrained=pretrained, **model_kwargs)


@register_model
def poolformer_s36(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormer S36: depths (6, 6, 18, 6), dims (64, 128, 320, 512), linear head."""
    model_kwargs = dict(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        downsample_norm=None,
        mlp_act=nn.GELU,
        mlp_bias=True,
        norm_layers=GroupNorm1,
        layer_scale_init_values=1e-6,
        res_scale_init_values=None,
        use_mlp_head=False,
        **kwargs)
    return _create_metaformer('poolformer_s36', pretrained=pretrained, **model_kwargs)


@register_model
def poolformer_m36(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormer M36: depths (6, 6, 18, 6), dims (96, 192, 384, 768), linear head."""
    model_kwargs = dict(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        downsample_norm=None,
        mlp_act=nn.GELU,
        mlp_bias=True,
        norm_layers=GroupNorm1,
        layer_scale_init_values=1e-6,
        res_scale_init_values=None,
        use_mlp_head=False,
        **kwargs)
    return _create_metaformer('poolformer_m36', pretrained=pretrained, **model_kwargs)


@register_model
def poolformer_m48(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormer M48: depths (8, 8, 24, 8), dims (96, 192, 384, 768), linear head."""
    model_kwargs = dict(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        downsample_norm=None,
        mlp_act=nn.GELU,
        mlp_bias=True,
        norm_layers=GroupNorm1,
        layer_scale_init_values=1e-6,
        res_scale_init_values=None,
        use_mlp_head=False,
        **kwargs)
    return _create_metaformer('poolformer_m48', pretrained=pretrained, **model_kwargs)


@register_model
def poolformerv2_s12(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormerV2 S12: depths (2, 2, 6, 2), dims (64, 128, 320, 512), linear head."""
    model_kwargs = dict(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        norm_layers=GroupNorm1NoBias,
        use_mlp_head=False,
        **kwargs)
    return _create_metaformer('poolformerv2_s12', pretrained=pretrained, **model_kwargs)


@register_model
def poolformerv2_s24(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormerV2 S24: depths (4, 4, 12, 4), dims (64, 128, 320, 512), linear head."""
    model_kwargs = dict(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        norm_layers=GroupNorm1NoBias,
        use_mlp_head=False,
        **kwargs)
    return _create_metaformer('poolformerv2_s24', pretrained=pretrained, **model_kwargs)


@register_model
def poolformerv2_s36(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormerV2 S36: depths (6, 6, 18, 6), dims (64, 128, 320, 512), linear head."""
    model_kwargs = dict(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        norm_layers=GroupNorm1NoBias,
        use_mlp_head=False,
        **kwargs)
    return _create_metaformer('poolformerv2_s36', pretrained=pretrained, **model_kwargs)


@register_model
def poolformerv2_m36(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormerV2 M36: depths (6, 6, 18, 6), dims (96, 192, 384, 768), linear head."""
    model_kwargs = dict(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        norm_layers=GroupNorm1NoBias,
        use_mlp_head=False,
        **kwargs)
    return _create_metaformer('poolformerv2_m36', pretrained=pretrained, **model_kwargs)


@register_model
def poolformerv2_m48(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormerV2 M48: depths (8, 8, 24, 8), dims (96, 192, 384, 768), linear head."""
    model_kwargs = dict(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        norm_layers=GroupNorm1NoBias,
        use_mlp_head=False,
        **kwargs)
    return _create_metaformer('poolformerv2_m48', pretrained=pretrained, **model_kwargs)


@register_model
def convformer_s18(pretrained=False, **kwargs) -> MetaFormer:
    """ConvFormer S18: SepConv mixers, depths (3, 3, 9, 3), dims (64, 128, 320, 512)."""
    model_kwargs = dict(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        norm_layers=LayerNorm2dNoBias,
        **kwargs)
    return _create_metaformer('convformer_s18', pretrained=pretrained, **model_kwargs)


@register_model
def convformer_s36(pretrained=False, **kwargs) -> MetaFormer:
    """ConvFormer S36: SepConv mixers, depths (3, 12, 18, 3), dims (64, 128, 320, 512)."""
    model_kwargs = dict(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        norm_layers=LayerNorm2dNoBias,
        **kwargs)
    return _create_metaformer('convformer_s36', pretrained=pretrained, **model_kwargs)


@register_model
def convformer_m36(pretrained=False, **kwargs) -> MetaFormer:
    """ConvFormer M36: SepConv mixers, depths (3, 12, 18, 3), dims (96, 192, 384, 576)."""
    model_kwargs = dict(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        norm_layers=LayerNorm2dNoBias,
        **kwargs)
    return _create_metaformer('convformer_m36', pretrained=pretrained, **model_kwargs)


@register_model
def convformer_b36(pretrained=False, **kwargs) -> MetaFormer:
    """ConvFormer B36: SepConv mixers, depths (3, 12, 18, 3), dims (128, 256, 512, 768)."""
    model_kwargs = dict(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        norm_layers=LayerNorm2dNoBias,
        **kwargs)
    return _create_metaformer('convformer_b36', pretrained=pretrained, **model_kwargs)


@register_model
def caformer_s18(pretrained=False, **kwargs) -> MetaFormer:
    """CAFormer S18: SepConv x2 then Attention x2, depths (3, 3, 9, 3), dims (64, 128, 320, 512)."""
    model_kwargs = dict(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
        **kwargs)
    return _create_metaformer('caformer_s18', pretrained=pretrained, **model_kwargs)


@register_model
def caformer_s36(pretrained=False, **kwargs) -> MetaFormer:
    """CAFormer S36: SepConv x2 then Attention x2, depths (3, 12, 18, 3), dims (64, 128, 320, 512)."""
    model_kwargs = dict(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
        **kwargs)
    return _create_metaformer('caformer_s36', pretrained=pretrained, **model_kwargs)


@register_model
def caformer_m36(pretrained=False, **kwargs) -> MetaFormer:
    """CAFormer M36: SepConv x2 then Attention x2, depths (3, 12, 18, 3), dims (96, 192, 384, 576)."""
    model_kwargs = dict(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
        **kwargs)
    return _create_metaformer('caformer_m36', pretrained=pretrained, **model_kwargs)


@register_model
def caformer_b36(pretrained=False, **kwargs) -> MetaFormer:
    """CAFormer B36: SepConv x2 then Attention x2, depths (3, 12, 18, 3), dims (128, 256, 512, 768)."""
    model_kwargs = dict(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
        **kwargs)
    return _create_metaformer('caformer_b36', pretrained=pretrained, **model_kwargs)