diff --git a/tests/test_models.py b/tests/test_models.py
index 652ea355..ace88690 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -52,7 +52,7 @@ FEAT_INTER_FILTERS = [
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
     'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit',
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
@@ -60,7 +60,7 @@ NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'poolformer_*', 'volo_*',
     'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)
diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index de077797..3f023572 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -1,9 +1,10 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
-from .blur_pool import BlurPool2d
+from .blur_pool import BlurPool2d, create_aa
 from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
diff --git a/timm/layers/attention2d.py b/timm/layers/attention2d.py
new file mode 100644
index 00000000..d1d38fb3
--- /dev/null
+++ b/timm/layers/attention2d.py
@@ -0,0 +1,337 @@
+from typing import List, Optional, Union
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from .config import use_fused_attn
+from .create_conv2d import create_conv2d
+from .helpers import to_2tuple
+from .pool2d_same import create_pool2d
+
+
+class MultiQueryAttentionV2(nn.Module):
+    """Multi Query Attention.
+
+    Fast Transformer Decoding: One Write-Head is All You Need
+    https://arxiv.org/pdf/1911.02150.pdf
+
+    This is an accelerator optimized version - removing multiple unnecessary
+    tensor transposes by re-arranging indices according to the following rules: 1)
+    contracted indices are at the end, 2) other indices have the same order in the
+    input and output tensors.
+
+    Compared to V1, this gives a 3x speed up.
+ """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + num_heads: int = 8, + key_dim: int = 64, + value_dim: int = 64, + attn_drop: float = 0., + proj_drop: float = 0., + ): + """Initializer.""" + super().__init__() + dim_out = dim_out or dim + self.num_heads = num_heads + self.key_dim = key_dim + self.value_dim = value_dim + self.scale = key_dim ** -0.5 + + self.query_proj = nn.Parameter(torch.randn([self.num_heads, self.key_dim, dim])) + self.key_proj = nn.Parameter(torch.randn([dim, self.key_dim])) + self.value_proj = nn.Parameter(torch.randn([dim, self.value_dim])) + self.attn_drop = nn.Dropout(attn_drop) + self.out_proj = nn.Parameter(torch.randn([dim_out, self.num_heads, self.value_dim])) + self.proj_drop = nn.Dropout(proj_drop) + + def _reshape_input(self, t): + """Reshapes a tensor to three dimensions, keeping the first and last.""" + s = t.shape + # Propagate the shape statically where possible. + #num = t.shape[1:-1].numel() + #return t.reshape(s[0], num, s[-1]) + return t.reshape(s[0], s[1], -1).transpose(1, 2) + + def forward(self, x, m: Optional[torch.Tensor] = None): + """Run layer computation.""" + s = x.shape + m = m or x + + reshaped_x = self._reshape_input(x) + reshaped_m = self._reshape_input(m) + + q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj) + k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj) + + attn = torch.einsum('bnhk,bmk->bnhm', q, k) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj) + o = torch.einsum('bnhm,bmv->bnhv', attn, v) + result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj) + result = self.proj_drop(result) + return result.reshape(s) + + +class MultiQueryAttention2d(nn.Module): + """Multi Query Attention with spatial downsampling. + + 3 parameters are introduced for the spatial downsampling: + 1. kv_stride: downsampling factor on Key and Values only. + 2. query_strides: horizontal & vertical strides on Query only. + + This is an optimized version. + 1. Projections in Attention is explict written out as 1x1 Conv2D. + 2. Additional reshapes are introduced to bring a up to 3x speed up. + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + num_heads: int = 8, + key_dim: Optional[int] = None, + value_dim: Optional[int] = None, + query_strides: int = 1, + kv_stride: int = 1, + dw_kernel_size: int = 3, + dilation: int = 1, + padding: Union[str, int, List[int]] = '', + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = nn.BatchNorm2d, + use_bias: bool = False, + ): + """Initializer. + + Args: + num_heads: Number of attention heads. + key_dim: Size of the attention key dimension. + value_dim: Size of the attention value dimension. + query_strides: Vertical stride size for query only. + kv_stride: Key and value stride size. + dw_kernel_size: Spatial dimension of the depthwise kernel. 
+ """ + super().__init__() + dim_out = dim_out or dim + self.num_heads = num_heads + self.key_dim = key_dim or dim // num_heads + self.value_dim = value_dim or dim // num_heads + self.query_strides = to_2tuple(query_strides) + self.kv_stride = kv_stride + self.has_query_strides = any([s > 1 for s in self.query_strides]) + self.scale = self.key_dim ** -0.5 + self.fused_attn = use_fused_attn() + self.drop = attn_drop + + self.query = nn.Sequential() + if self.has_query_strides: + # FIXME dilation + self.query.add_module('down_pool', create_pool2d( + 'avg', + kernel_size=self.query_strides, + padding=padding, + )) + self.query.add_module('norm', norm_layer(dim)) + self.query.add_module('proj', create_conv2d( + dim, + self.num_heads * self.key_dim, + kernel_size=1, + bias=use_bias, + )) + + self.key = nn.Sequential() + if kv_stride > 1: + self.key.add_module('down_conv', create_conv2d( + dim, + dim, + kernel_size=dw_kernel_size, + stride=kv_stride, + dilation=dilation, + padding=padding, + depthwise=True, + )) + self.key.add_module('norm', norm_layer(dim)) + self.key.add_module('proj', create_conv2d( + dim, + self.key_dim, + kernel_size=1, + padding=padding, + bias=use_bias, + )) + + self.value = nn.Sequential() + if kv_stride > 1: + self.value.add_module('down_conv', create_conv2d( + dim, + dim, + kernel_size=dw_kernel_size, + stride=kv_stride, + dilation=dilation, + padding=padding, + depthwise=True, + )) + self.value.add_module('norm', norm_layer(dim)) + self.value.add_module('proj', create_conv2d( + dim, + self.value_dim, + kernel_size=1, + bias=use_bias, + )) + + self.attn_drop = nn.Dropout(attn_drop) + + self.output = nn.Sequential() + if self.has_query_strides: + self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False)) + self.output.add_module('proj', create_conv2d( + self.value_dim * self.num_heads, + dim_out, + kernel_size=1, + bias=use_bias, + )) + self.output.add_module('drop', nn.Dropout(proj_drop)) + + self.einsum = False + + def _reshape_input(self, t: torch.Tensor): + """Reshapes a tensor to three dimensions, keeping the batch and channels.""" + s = t.shape + t = t.reshape(s[0], s[1], -1).transpose(1, 2) + if self.einsum: + return t + else: + return t.unsqueeze(1).contiguous() + + def _reshape_projected_query(self, t: torch.Tensor, num_heads: int, key_dim: int): + """Reshapes projected query: [b, n, n, h x k] -> [b, n x n, h, k].""" + s = t.shape + t = t.reshape(s[0], num_heads, key_dim, -1) + if self.einsum: + return t.permute(0, 3, 1, 2).contiguous() + else: + return t.transpose(-1, -2).contiguous() + + def _reshape_output(self, t: torch.Tensor, num_heads: int, h_px: int, w_px: int): + """Reshape output:[b, n x n x h, k] -> [b, n, n, hk].""" + s = t.shape + feat_dim = s[-1] * num_heads + if not self.einsum: + t = t.transpose(1, 2) + return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous() + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + """Run layer computation.""" + B, C, H, W = s = x.shape + + q = self.query(x) + # desired q shape: [b, h, k, n x n] - [b, l, h, k] + q = self._reshape_projected_query(q, self.num_heads, self.key_dim) + + k = self.key(x) + # output shape of k: [b, k, p], p = m x m + k = self._reshape_input(k) + + v = self.value(x) + # output shape of v: [ b, p, k], p = m x m + v = self._reshape_input(v) + + # desired q shape: [b, n x n, h, k] + # desired k shape: [b, m x m, k] + # desired logits shape: [b, n x n, h, m x m] + if self.einsum: + attn = 
torch.einsum('blhk,bpk->blhp', q, k) * self.scale + if attn_mask is not None: + # NOTE: assumes mask is float and in correct shape + attn = attn + attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + o = torch.einsum('blhp,bpk->blhk', attn, v) + else: + if self.fused_attn: + o = F.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0. + ) + else: + q = q * self.scale + attn = q @ k.transpose(-1, -2) + if attn_mask is not None: + # NOTE: assumes mask is float and in correct shape + attn = attn + attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + o = attn @ v + + # reshape o into [b, hk, n, n,] + o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1]) + x = self.output(o) + return x + + +class Attention2d(nn.Module): + fused_attn: torch.jit.Final[bool] + + """ multi-head attention for 2D NCHW tensors""" + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + num_heads: int = 32, + bias: bool = True, + expand_first: bool = False, + head_first: bool = False, + attn_drop: float = 0., + proj_drop: float = 0. + ): + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first else dim + self.num_heads = num_heads + self.dim_head = dim_attn // num_heads + self.head_first = head_first + self.scale = num_heads ** -0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + B, C, H, W = x.shape + + if self.head_first: + q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2) + else: + q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1) + + if self.fused_attn: + x = torch.nn.functional.scaled_dot_product_attention( + q.transpose(-1, -2).contiguous(), + k.transpose(-1, -2).contiguous(), + v.transpose(-1, -2).contiguous(), + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0., + ).transpose(-1, -2).reshape(B, -1, H, W) + else: + q = q * self.scale + attn = q.transpose(-2, -1) @ k + if attn_mask is not None: + # NOTE: assumes mask is float and in correct shape + attn = attn + attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/timm/layers/blur_pool.py b/timm/layers/blur_pool.py index e73d8863..6a4b668c 100644 --- a/timm/layers/blur_pool.py +++ b/timm/layers/blur_pool.py @@ -5,12 +5,16 @@ BlurPool layer inspired by Hacked together by Chris Ha and Ross Wightman """ +from functools import partial +from typing import Optional, Type import torch import torch.nn as nn import torch.nn.functional as F import numpy as np + from .padding import get_padding +from .typing import LayerType class BlurPool2d(nn.Module): @@ -26,17 +30,62 @@ class BlurPool2d(nn.Module): Returns: torch.Tensor: the transformed tensor. 
""" - def __init__(self, channels, filt_size=3, stride=2) -> None: + def __init__( + self, + channels: Optional[int] = None, + filt_size: int = 3, + stride: int = 2, + pad_mode: str = 'reflect', + ) -> None: super(BlurPool2d, self).__init__() assert filt_size > 1 self.channels = channels self.filt_size = filt_size self.stride = stride + self.pad_mode = pad_mode self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 + coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) - blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) + blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :] + if channels is not None: + blur_filter = blur_filter.repeat(self.channels, 1, 1, 1) self.register_buffer('filt', blur_filter, persistent=False) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = F.pad(x, self.padding, 'reflect') - return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) + x = F.pad(x, self.padding, mode=self.pad_mode) + if self.channels is None: + channels = x.shape[1] + weight = self.filt.expand(channels, 1, self.filt_size, self.filt_size) + else: + channels = self.channels + weight = self.filt + return F.conv2d(x, weight, stride=self.stride, groups=channels) + + +def create_aa( + aa_layer: LayerType, + channels: Optional[int] = None, + stride: int = 2, + enable: bool = True, + noop: Optional[Type[nn.Module]] = nn.Identity +) -> nn.Module: + """ Anti-aliasing """ + if not aa_layer or not enable: + return noop() if noop is not None else None + + if isinstance(aa_layer, str): + aa_layer = aa_layer.lower().replace('_', '').replace('-', '') + if aa_layer == 'avg' or aa_layer == 'avgpool': + aa_layer = nn.AvgPool2d + elif aa_layer == 'blur' or aa_layer == 'blurpool': + aa_layer = BlurPool2d + elif aa_layer == 'blurpc': + aa_layer = partial(BlurPool2d, pad_mode='constant') + + else: + assert False, f"Unknown anti-aliasing layer ({aa_layer})." 
+ + try: + return aa_layer(channels=channels, stride=stride) + except TypeError as e: + return aa_layer(stride) diff --git a/timm/layers/conv_bn_act.py b/timm/layers/conv_bn_act.py index 84aaf4bf..de738045 100644 --- a/timm/layers/conv_bn_act.py +++ b/timm/layers/conv_bn_act.py @@ -2,9 +2,12 @@ Hacked together by / Copyright 2020 Ross Wightman """ -import functools +from typing import Any, Dict, Optional, Type + from torch import nn as nn +from .typing import LayerType, PadType +from .blur_pool import create_aa from .create_conv2d import create_conv2d from .create_norm_act import get_norm_act_layer @@ -12,41 +15,58 @@ from .create_norm_act import get_norm_act_layer class ConvNormAct(nn.Module): def __init__( self, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding='', - dilation=1, - groups=1, - bias=False, - apply_act=True, - norm_layer=nn.BatchNorm2d, - norm_kwargs=None, - act_layer=nn.ReLU, - act_kwargs=None, - drop_layer=None, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + padding: PadType = '', + dilation: int = 1, + groups: int = 1, + bias: bool = False, + apply_norm: bool = True, + apply_act: bool = True, + norm_layer: LayerType = nn.BatchNorm2d, + act_layer: LayerType = nn.ReLU, + drop_layer: Optional[Type[nn.Module]] = None, + conv_kwargs: Optional[Dict[str, Any]] = None, + norm_kwargs: Optional[Dict[str, Any]] = None, + act_kwargs: Optional[Dict[str, Any]] = None, ): super(ConvNormAct, self).__init__() + conv_kwargs = conv_kwargs or {} norm_kwargs = norm_kwargs or {} act_kwargs = act_kwargs or {} self.conv = create_conv2d( - in_channels, out_channels, kernel_size, stride=stride, - padding=padding, dilation=dilation, groups=groups, bias=bias) - - # NOTE for backwards compatibility with models that use separate norm and act layer definitions - norm_act_layer = get_norm_act_layer(norm_layer, act_layer) - # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` - if drop_layer: - norm_kwargs['drop_layer'] = drop_layer - self.bn = norm_act_layer( + in_channels, out_channels, - apply_act=apply_act, - act_kwargs=act_kwargs, - **norm_kwargs, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **conv_kwargs, ) + if apply_norm: + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` + if drop_layer: + norm_kwargs['drop_layer'] = drop_layer + self.bn = norm_act_layer( + out_channels, + apply_act=apply_act, + act_kwargs=act_kwargs, + **norm_kwargs, + ) + else: + self.bn = nn.Sequential() + if drop_layer: + norm_kwargs['drop_layer'] = drop_layer + self.bn.add_module('drop', drop_layer()) + @property def in_channels(self): return self.conv.in_channels @@ -64,54 +84,61 @@ class ConvNormAct(nn.Module): ConvBnAct = ConvNormAct -def create_aa(aa_layer, channels, stride=2, enable=True): - if not aa_layer or not enable: - return nn.Identity() - if isinstance(aa_layer, functools.partial): - if issubclass(aa_layer.func, nn.AvgPool2d): - return aa_layer() - else: - return aa_layer(channels) - elif issubclass(aa_layer, nn.AvgPool2d): - return aa_layer(stride) - else: - return aa_layer(channels=channels, stride=stride) - - class ConvNormActAa(nn.Module): def __init__( self, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding='', - dilation=1, - groups=1, - bias=False, - 
apply_act=True, - norm_layer=nn.BatchNorm2d, - norm_kwargs=None, - act_layer=nn.ReLU, - act_kwargs=None, - aa_layer=None, - drop_layer=None, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + padding: PadType = '', + dilation: int = 1, + groups: int = 1, + bias: bool = False, + apply_norm: bool = True, + apply_act: bool = True, + norm_layer: LayerType = nn.BatchNorm2d, + act_layer: LayerType = nn.ReLU, + aa_layer: Optional[LayerType] = None, + drop_layer: Optional[Type[nn.Module]] = None, + conv_kwargs: Optional[Dict[str, Any]] = None, + norm_kwargs: Optional[Dict[str, Any]] = None, + act_kwargs: Optional[Dict[str, Any]] = None, ): super(ConvNormActAa, self).__init__() use_aa = aa_layer is not None and stride == 2 + conv_kwargs = conv_kwargs or {} norm_kwargs = norm_kwargs or {} act_kwargs = act_kwargs or {} self.conv = create_conv2d( - in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, - padding=padding, dilation=dilation, groups=groups, bias=bias) + in_channels, out_channels, kernel_size, + stride=1 if use_aa else stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **conv_kwargs, + ) + + if apply_norm: + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` + if drop_layer: + norm_kwargs['drop_layer'] = drop_layer + self.bn = norm_act_layer( + out_channels, + apply_act=apply_act, + act_kwargs=act_kwargs, + **norm_kwargs, + ) + else: + self.bn = nn.Sequential() + if drop_layer: + norm_kwargs['drop_layer'] = drop_layer + self.bn.add_module('drop', drop_layer()) - # NOTE for backwards compatibility with models that use separate norm and act layer definitions - norm_act_layer = get_norm_act_layer(norm_layer, act_layer) - # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` - if drop_layer: - norm_kwargs['drop_layer'] = drop_layer - self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs) self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) @property diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py index 49505c58..496efcfd 100644 --- a/timm/layers/norm_act.py +++ b/timm/layers/norm_act.py @@ -19,21 +19,18 @@ from torch import nn as nn from torch.nn import functional as F from torchvision.ops.misc import FrozenBatchNorm2d -from .create_act import get_act_layer +from .create_act import create_act_layer from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm from .trace_utils import _assert def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True): - act_layer = get_act_layer(act_layer) # string -> nn.Module act_kwargs = act_kwargs or {} - if act_layer is not None and apply_act: - if inplace: - act_kwargs['inplace'] = inplace - act = act_layer(**act_kwargs) - else: - act = nn.Identity() - return act + act_kwargs.setdefault('inplace', inplace) + act = None + if apply_act: + act = create_act_layer(act_layer, **act_kwargs) + return nn.Identity() if act is None else act class BatchNormAct2d(nn.BatchNorm2d): @@ -421,7 +418,6 @@ class LayerNormAct(nn.LayerNorm): ): super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine) self.drop = drop_layer() if drop_layer is not None else nn.Identity() - act_layer = get_act_layer(act_layer) # string -> nn.Module self.act = 
_create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() diff --git a/timm/models/__init__.py b/timm/models/__init__.py index ed4df651..959b0a61 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -71,6 +71,7 @@ from .vision_transformer import * from .vision_transformer_hybrid import * from .vision_transformer_relpos import * from .vision_transformer_sam import * +from .vitamin import * from .volo import * from .vovnet import * from .xception import * diff --git a/timm/models/_efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py index a5a6f30b..f33dacd5 100644 --- a/timm/models/_efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -2,18 +2,24 @@ Hacked together by / Copyright 2019, Ross Wightman """ +from typing import Callable, Dict, Optional, Type import torch import torch.nn as nn from torch.nn import functional as F -from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer +from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType,\ + ConvNormAct, ConvNormActAa, get_norm_act_layer, MultiQueryAttention2d, Attention2d __all__ = [ - 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual'] + 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', + 'UniversalInvertedResidual', 'MobileAttention' +] + +ModuleType = Type[nn.Module] -def num_groups(group_size, channels): +def num_groups(group_size: Optional[int], channels: int): if not group_size: # 0 or None return 1 # normal conv with 1 group else: @@ -35,8 +41,15 @@ class SqueezeExcite(nn.Module): """ def __init__( - self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU, - gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None): + self, + in_chs: int, + rd_ratio: float = 0.25, + rd_channels: Optional[int] = None, + act_layer: LayerType = nn.ReLU, + gate_layer: LayerType = nn.Sigmoid, + force_act_layer: Optional[LayerType] = None, + rd_round_fn: Optional[Callable] = None, + ): super(SqueezeExcite, self).__init__() if rd_channels is None: rd_round_fn = rd_round_fn or round @@ -59,16 +72,32 @@ class ConvBnAct(nn.Module): """ Conv + Norm Layer + Activation w/ optional skip connection """ def __init__( - self, in_chs, out_chs, kernel_size, stride=1, dilation=1, group_size=0, pad_type='', - skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.): + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + group_size: int = 0, + pad_type: str = '', + skip: bool = False, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + drop_path_rate: float = 0., + ): super(ConvBnAct, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) groups = num_groups(group_size, in_chs) self.has_skip = skip and stride == 1 and in_chs == out_chs + use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation self.conv = create_conv2d( - in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type) + in_chs, out_chs, kernel_size, + stride=1 if use_aa else stride, + dilation=dilation, groups=groups, padding=pad_type) self.bn1 = norm_act_layer(out_chs, inplace=True) + self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa) self.drop_path = 
DropPath(drop_path_rate) if drop_path_rate else nn.Identity() def feature_info(self, location): @@ -81,29 +110,64 @@ class ConvBnAct(nn.Module): shortcut = x x = self.conv(x) x = self.bn1(x) + x = self.aa(x) if self.has_skip: x = self.drop_path(x) + shortcut return x class DepthwiseSeparableConv(nn.Module): - """ DepthwiseSeparable block + """ Depthwise-separable block Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion (factor of 1.0). This is an alternative to having a IR with an optional first pw conv. """ def __init__( - self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', - noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - se_layer=None, drop_path_rate=0.): + self, + in_chs: int, + out_chs: int, + dw_kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + group_size: int = 1, + pad_type: str = '', + noskip: bool = False, + pw_kernel_size: int = 1, + pw_act: bool = False, + s2d: int = 0, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[ModuleType] = None, + drop_path_rate: float = 0., + ): super(DepthwiseSeparableConv, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) - groups = num_groups(group_size, in_chs) self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip self.has_pw_act = pw_act # activation after point-wise conv + use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation + + # Space to depth + if s2d == 1: + sd_chs = int(in_chs * 4) + self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same') + self.bn_s2d = norm_act_layer(sd_chs, sd_chs) + dw_kernel_size = (dw_kernel_size + 1) // 2 + dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type + in_chs = sd_chs + use_aa = False # disable AA + else: + self.conv_s2d = None + self.bn_s2d = None + dw_pad_type = pad_type + + groups = num_groups(group_size, in_chs) self.conv_dw = create_conv2d( - in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, groups=groups) + in_chs, in_chs, dw_kernel_size, + stride=1 if use_aa else stride, + dilation=dilation, padding=dw_pad_type, groups=groups) self.bn1 = norm_act_layer(in_chs, inplace=True) + self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa) # Squeeze-and-excitation self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity() @@ -120,8 +184,12 @@ class DepthwiseSeparableConv(nn.Module): def forward(self, x): shortcut = x + if self.conv_s2d is not None: + x = self.conv_s2d(x) + x = self.bn_s2d(x) x = self.conv_dw(x) x = self.bn1(x) + x = self.aa(x) x = self.se(x) x = self.conv_pw(x) x = self.bn2(x) @@ -141,15 +209,48 @@ class InvertedResidual(nn.Module): """ def __init__( - self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', - noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.): + self, + in_chs: int, + out_chs: int, + dw_kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + group_size: int = 1, + pad_type: str = '', + noskip: bool = False, + exp_ratio: float = 1.0, + exp_kernel_size: int = 1, + pw_kernel_size: int = 1, + s2d: int = 0, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + se_layer: 
Optional[ModuleType] = None, + conv_kwargs: Optional[Dict] = None, + drop_path_rate: float = 0., + ): super(InvertedResidual, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) conv_kwargs = conv_kwargs or {} + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation + + # Space to depth + if s2d == 1: + sd_chs = int(in_chs * 4) + self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same') + self.bn_s2d = norm_act_layer(sd_chs, sd_chs) + dw_kernel_size = (dw_kernel_size + 1) // 2 + dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type + in_chs = sd_chs + use_aa = False # disable AA + else: + self.conv_s2d = None + self.bn_s2d = None + dw_pad_type = pad_type + mid_chs = make_divisible(in_chs * exp_ratio) groups = num_groups(group_size, mid_chs) - self.has_skip = (in_chs == out_chs and stride == 1) and not noskip # Point-wise expansion self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) @@ -157,9 +258,11 @@ class InvertedResidual(nn.Module): # Depth-wise convolution self.conv_dw = create_conv2d( - mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, - groups=groups, padding=pad_type, **conv_kwargs) + mid_chs, mid_chs, dw_kernel_size, + stride=1 if use_aa else stride, + dilation=dilation, groups=groups, padding=dw_pad_type, **conv_kwargs) self.bn2 = norm_act_layer(mid_chs, inplace=True) + self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa) # Squeeze-and-excitation self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() @@ -177,10 +280,14 @@ class InvertedResidual(nn.Module): def forward(self, x): shortcut = x + if self.conv_s2d is not None: + x = self.conv_s2d(x) + x = self.bn_s2d(x) x = self.conv_pw(x) x = self.bn1(x) x = self.conv_dw(x) x = self.bn2(x) + x = self.aa(x) x = self.se(x) x = self.conv_pwl(x) x = self.bn3(x) @@ -189,23 +296,317 @@ class InvertedResidual(nn.Module): return x +class LayerScale2d(nn.Module): + def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + + +class UniversalInvertedResidual(nn.Module): + """ Universal Inverted Residual Block (aka Universal Inverted Bottleneck, UIB) + + For MobileNetV4 - https://arxiv.org/abs/, referenced from + https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778 + """ + + def __init__( + self, + in_chs: int, + out_chs: int, + dw_kernel_size_start: int = 0, + dw_kernel_size_mid: int = 3, + dw_kernel_size_end: int = 0, + stride: int = 1, + dilation: int = 1, + group_size: int = 1, + pad_type: str = '', + noskip: bool = False, + exp_ratio: float = 1.0, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[ModuleType] = None, + conv_kwargs: Optional[Dict] = None, + drop_path_rate: float = 0., + layer_scale_init_value: Optional[float] = 1e-5, + ): + super(UniversalInvertedResidual, self).__init__() + conv_kwargs = conv_kwargs or {} + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + if stride > 1: + assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end + + # 
FIXME dilation isn't right w/ extra ks > 1 convs + if dw_kernel_size_start: + dw_start_stride = stride if not dw_kernel_size_mid else 1 + dw_start_groups = num_groups(group_size, in_chs) + self.dw_start = ConvNormActAa( + in_chs, in_chs, dw_kernel_size_start, + stride=dw_start_stride, + dilation=dilation, # FIXME + groups=dw_start_groups, + padding=pad_type, + apply_act=False, + act_layer=act_layer, + norm_layer=norm_layer, + aa_layer=aa_layer, + **conv_kwargs, + ) + else: + self.dw_start = nn.Identity() + + # Point-wise expansion + mid_chs = make_divisible(in_chs * exp_ratio) + self.pw_exp = ConvNormAct( + in_chs, mid_chs, 1, + padding=pad_type, + act_layer=act_layer, + norm_layer=norm_layer, + **conv_kwargs, + ) + + # Middle depth-wise convolution + if dw_kernel_size_mid: + groups = num_groups(group_size, mid_chs) + self.dw_mid = ConvNormActAa( + mid_chs, mid_chs, dw_kernel_size_mid, + stride=stride, + dilation=dilation, # FIXME + groups=groups, + padding=pad_type, + act_layer=act_layer, + norm_layer=norm_layer, + aa_layer=aa_layer, + **conv_kwargs, + ) + else: + # keeping mid as identity so it can be hooked more easily for features + self.dw_mid = nn.Identity() + + # Squeeze-and-excitation + self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() + + # Point-wise linear projection + self.pw_proj = ConvNormAct( + mid_chs, out_chs, 1, + padding=pad_type, + apply_act=False, + act_layer=act_layer, + norm_layer=norm_layer, + **conv_kwargs, + ) + + if dw_kernel_size_end: + dw_end_stride = stride if not dw_kernel_size_start and not dw_kernel_size_mid else 1 + dw_end_groups = num_groups(group_size, out_chs) + if dw_end_stride > 1: + assert not aa_layer + self.dw_end = ConvNormAct( + out_chs, out_chs, dw_kernel_size_end, + stride=dw_end_stride, + dilation=dilation, + groups=dw_end_groups, + padding=pad_type, + apply_act=False, + act_layer=act_layer, + norm_layer=norm_layer, + **conv_kwargs, + ) + else: + self.dw_end = nn.Identity() + + if layer_scale_init_value is not None: + self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value) + else: + self.layer_scale = nn.Identity() + self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() + + def feature_info(self, location): + if location == 'expansion': # after SE, input to PWL + return dict(module='pw_proj.conv', hook_type='forward_pre', num_chs=self.pw_proj.conv.in_channels) + else: # location == 'bottleneck', block output + return dict(module='', num_chs=self.pw_proj.conv.out_channels) + + def forward(self, x): + shortcut = x + x = self.dw_start(x) + x = self.pw_exp(x) + x = self.dw_mid(x) + x = self.se(x) + x = self.pw_proj(x) + x = self.dw_end(x) + x = self.layer_scale(x) + if self.has_skip: + x = self.drop_path(x) + shortcut + return x + + +class MobileAttention(nn.Module): + """ Mobile Attention Block + + For MobileNetV4 - https://arxiv.org/abs/, referenced from + https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L1504 + """ + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + dw_kernel_size: int = 3, + dilation: int = 1, + group_size: int = 1, + pad_type: str = '', + num_heads: int = 8, + key_dim: int = 64, + value_dim: int = 64, + use_multi_query: bool = False, + query_strides: int = (1, 1), + kv_stride: int = 1, + cpe_dw_kernel_size: int = 3, + noskip: bool = False, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + 
drop_path_rate: float = 0., + attn_drop: float = 0.0, + proj_drop: float = 0.0, + layer_scale_init_value: Optional[float] = 1e-5, + use_bias: bool = False, + use_cpe: bool = False, + ): + super(MobileAttention, self).__init__() + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip + self.query_strides = to_2tuple(query_strides) + self.kv_stride = kv_stride + self.has_query_stride = any([s > 1 for s in self.query_strides]) + + # This CPE is different than the one suggested in the original paper. + # https://arxiv.org/abs/2102.10882 + # 1. Rather than adding one CPE before the attention blocks, we add a CPE + # into every attention block. + # 2. We replace the expensive Conv2D by a Seperable DW Conv. + if use_cpe: + self.conv_cpe_dw = create_conv2d( + in_chs, in_chs, + kernel_size=cpe_dw_kernel_size, + dilation=dilation, + depthwise=True, + bias=True, + ) + else: + self.conv_cpe_dw = None + + self.norm = norm_act_layer(in_chs, apply_act=False) + + if num_heads is None: + assert in_chs % key_dim == 0 + num_heads = in_chs // key_dim + + if use_multi_query: + self.attn = MultiQueryAttention2d( + in_chs, + dim_out=out_chs, + num_heads=num_heads, + key_dim=key_dim, + value_dim=value_dim, + query_strides=query_strides, + kv_stride=kv_stride, + dilation=dilation, + padding=pad_type, + dw_kernel_size=dw_kernel_size, + attn_drop=attn_drop, + proj_drop=proj_drop, + #bias=use_bias, # why not here if used w/ mhsa? + ) + else: + self.attn = Attention2d( + in_chs, + dim_out=out_chs, + num_heads=num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + bias=use_bias, + ) + + if layer_scale_init_value is not None: + self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value) + else: + self.layer_scale = nn.Identity() + + self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() + + def feature_info(self, location): + if location == 'expansion': # after SE, input to PW + return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) + else: # location == 'bottleneck', block output + return dict(module='', num_chs=self.conv_pw.out_channels) + + def forward(self, x): + if self.conv_cpe_dw is not None: + x_cpe = self.conv_cpe_dw(x) + x = x + x_cpe + + shortcut = x + x = self.norm(x) + x = self.attn(x) + x = self.layer_scale(x) + if self.has_skip: + x = self.drop_path(x) + shortcut + + return x + + class CondConvResidual(InvertedResidual): """ Inverted residual block w/ CondConv routing""" def __init__( - self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', - noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.): + self, + in_chs: int, + out_chs: int, + dw_kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + group_size: int = 1, + pad_type: str = '', + noskip: bool = False, + exp_ratio: float = 1.0, + exp_kernel_size: int = 1, + pw_kernel_size: int = 1, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[ModuleType] = None, + num_experts: int = 0, + drop_path_rate: float = 0., + ): self.num_experts = num_experts conv_kwargs = dict(num_experts=self.num_experts) - super(CondConvResidual, self).__init__( - in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, group_size=group_size, - pad_type=pad_type, act_layer=act_layer, 
noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, - pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs, - drop_path_rate=drop_path_rate) - + in_chs, + out_chs, + dw_kernel_size=dw_kernel_size, + stride=stride, + dilation=dilation, + group_size=group_size, + pad_type=pad_type, + noskip=noskip, + exp_ratio=exp_ratio, + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + act_layer=act_layer, + norm_layer=norm_layer, + aa_layer=aa_layer, + se_layer=se_layer, + conv_kwargs=conv_kwargs, + drop_path_rate=drop_path_rate, + ) self.routing_fn = nn.Linear(in_chs, self.num_experts) def forward(self, x): @@ -237,9 +638,24 @@ class EdgeResidual(nn.Module): """ def __init__( - self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, group_size=0, pad_type='', - force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.): + self, + in_chs: int, + out_chs: int, + exp_kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + group_size: int = 0, + pad_type: str = '', + force_in_chs: int = 0, + noskip: bool = False, + exp_ratio: float = 1.0, + pw_kernel_size: int = 1, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[ModuleType] = None, + drop_path_rate: float = 0., + ): super(EdgeResidual, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) if force_in_chs > 0: @@ -248,12 +664,17 @@ class EdgeResidual(nn.Module): mid_chs = make_divisible(in_chs * exp_ratio) groups = num_groups(group_size, in_chs) self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation # Expansion convolution self.conv_exp = create_conv2d( - in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type) + in_chs, mid_chs, exp_kernel_size, + stride=1 if use_aa else stride, + dilation=dilation, groups=groups, padding=pad_type) self.bn1 = norm_act_layer(mid_chs, inplace=True) + self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa) + # Squeeze-and-excitation self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() @@ -272,6 +693,7 @@ class EdgeResidual(nn.Module): shortcut = x x = self.conv_exp(x) x = self.bn1(x) + x = self.aa(x) x = self.se(x) x = self.conv_pwl(x) x = self.bn2(x) diff --git a/timm/models/_efficientnet_builder.py b/timm/models/_efficientnet_builder.py index 1e3161d6..57bde323 100644 --- a/timm/models/_efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -5,6 +5,7 @@ Handles stride, dilation calculations, and selects feature extraction points. 
Hacked together by / Copyright 2019, Ross Wightman """ +from typing import Callable, Optional import logging import math @@ -16,7 +17,7 @@ from typing import Any, Dict, List import torch.nn as nn from ._efficientnet_blocks import * -from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible +from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible, LayerType __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights", 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'] @@ -139,8 +140,8 @@ def _decode_block_str(block_str): # if act_layer is None, the model default (passed to model init) will be used act_layer = options['n'] if 'n' in options else None - exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 - pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 + start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 + end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def num_repeat = int(options['r']) @@ -154,29 +155,31 @@ def _decode_block_str(block_str): if block_type == 'ir': block_args.update(dict( dw_kernel_size=_parse_ksize(options['k']), - exp_kernel_size=exp_kernel_size, - pw_kernel_size=pw_kernel_size, + exp_kernel_size=start_kernel_size, + pw_kernel_size=end_kernel_size, exp_ratio=float(options['e']), - se_ratio=float(options['se']) if 'se' in options else 0., + se_ratio=float(options.get('se', 0.)), noskip=skip is False, + s2d=int(options.get('d', 0)) > 0, )) if 'cc' in options: block_args['num_experts'] = int(options['cc']) elif block_type == 'ds' or block_type == 'dsa': block_args.update(dict( dw_kernel_size=_parse_ksize(options['k']), - pw_kernel_size=pw_kernel_size, - se_ratio=float(options['se']) if 'se' in options else 0., + pw_kernel_size=end_kernel_size, + se_ratio=float(options.get('se', 0.)), pw_act=block_type == 'dsa', noskip=block_type == 'dsa' or skip is False, + s2d=int(options.get('d', 0)) > 0, )) elif block_type == 'er': block_args.update(dict( exp_kernel_size=_parse_ksize(options['k']), - pw_kernel_size=pw_kernel_size, + pw_kernel_size=end_kernel_size, exp_ratio=float(options['e']), force_in_chs=force_in_chs, - se_ratio=float(options['se']) if 'se' in options else 0., + se_ratio=float(options.get('se', 0.)), noskip=skip is False, )) elif block_type == 'cn': @@ -184,6 +187,38 @@ def _decode_block_str(block_str): kernel_size=int(options['k']), skip=skip is True, )) + elif block_type == 'uir': + # override exp / proj kernels for start/end in uir block + start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 0 + end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 0 + block_args.update(dict( + dw_kernel_size_start=start_kernel_size, # overload exp ks arg for dw start + dw_kernel_size_mid=_parse_ksize(options['k']), + dw_kernel_size_end=end_kernel_size, # overload pw ks arg for dw end + exp_ratio=float(options['e']), + se_ratio=float(options.get('se', 0.)), + noskip=skip is False, + )) + elif block_type == 'mha': + kv_dim = int(options['d']) + block_args.update(dict( + dw_kernel_size=_parse_ksize(options['k']), + num_heads=int(options['h']), + key_dim=kv_dim, + value_dim=kv_dim, + kv_stride=int(options.get('v', 1)), + noskip=skip is False, + )) + elif block_type == 'mqa': + kv_dim = int(options['d']) + 
block_args.update(dict( + dw_kernel_size=_parse_ksize(options['k']), + num_heads=int(options['h']), + key_dim=kv_dim, + value_dim=kv_dim, + kv_stride=int(options.get('v', 1)), + noskip=skip is False, + )) else: assert False, 'Unknown block type (%s)' % block_type if 'gs' in options: @@ -285,14 +320,27 @@ class EfficientNetBuilder: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py """ - def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False, - act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''): + def __init__( + self, + output_stride: int = 32, + pad_type: str = '', + round_chs_fn: Callable = round_channels, + se_from_exp: bool = False, + act_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[LayerType] = None, + drop_path_rate: float = 0., + layer_scale_init_value: Optional[float] = None, + feature_location: str = '', + ): self.output_stride = output_stride self.pad_type = pad_type self.round_chs_fn = round_chs_fn self.se_from_exp = se_from_exp # calculate se channel reduction from expanded (mid) chs self.act_layer = act_layer self.norm_layer = norm_layer + self.aa_layer = aa_layer self.se_layer = get_attn(se_layer) try: self.se_layer(8, rd_ratio=1.0) # test if attn layer accepts rd_ratio arg @@ -300,6 +348,7 @@ class EfficientNetBuilder: except TypeError: self.se_has_ratio = False self.drop_path_rate = drop_path_rate + self.layer_scale_init_value = layer_scale_init_value if feature_location == 'depthwise': # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'") @@ -317,6 +366,10 @@ class EfficientNetBuilder: bt = ba.pop('block_type') ba['in_chs'] = self.in_chs ba['out_chs'] = self.round_chs_fn(ba['out_chs']) + s2d = ba.get('s2d', 0) + if s2d > 0: + # adjust while space2depth active + ba['out_chs'] *= 4 if 'force_in_chs' in ba and ba['force_in_chs']: # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs']) @@ -326,16 +379,22 @@ class EfficientNetBuilder: assert ba['act_layer'] is not None ba['norm_layer'] = self.norm_layer ba['drop_path_rate'] = drop_path_rate - if bt != 'cn': - se_ratio = ba.pop('se_ratio') - if se_ratio and self.se_layer is not None: - if not self.se_from_exp: - # adjust se_ratio by expansion ratio if calculating se channels from block input - se_ratio /= ba.get('exp_ratio', 1.0) - if self.se_has_ratio: - ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio) - else: - ba['se_layer'] = self.se_layer + + if self.aa_layer is not None: + ba['aa_layer'] = self.aa_layer + + se_ratio = ba.pop('se_ratio', None) + if se_ratio and self.se_layer is not None: + if not self.se_from_exp: + # adjust se_ratio by expansion ratio if calculating se channels from block input + se_ratio /= ba.get('exp_ratio', 1.0) + if s2d == 1: + # adjust for start of space2depth + se_ratio /= 4 + if self.se_has_ratio: + ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio) + else: + ba['se_layer'] = self.se_layer if bt == 'ir': _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) @@ -349,8 +408,17 @@ class EfficientNetBuilder: elif bt == 'cn': _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose) block = 
ConvBnAct(**ba) + elif bt == 'uir': + _log_info_if(' UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = UniversalInvertedResidual(**ba, layer_scale_init_value=self.layer_scale_init_value) + elif bt == 'mqa': + _log_info_if(' MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = MobileAttention(**ba, use_multi_query=True, layer_scale_init_value=self.layer_scale_init_value) + elif bt == 'mha': + _log_info_if(' MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = MobileAttention(**ba, layer_scale_init_value=self.layer_scale_init_value) else: - assert False, 'Uknkown block type (%s) while building model.' % bt + assert False, 'Unknown block type (%s) while building model.' % bt self.in_chs = ba['out_chs'] # update in_chs for arg of next block return block @@ -377,6 +445,7 @@ class EfficientNetBuilder: self.features.append(feature_info) # outer list of block_args defines the stacks + space2depth = 0 for stack_idx, stack_args in enumerate(model_block_args): last_stack = stack_idx + 1 == len(model_block_args) _log_info_if('Stack: {}'.format(stack_idx), self.verbose) @@ -392,6 +461,20 @@ class EfficientNetBuilder: if block_idx >= 1: # only the first block in any stack can have a stride > 1 block_args['stride'] = 1 + if not space2depth and block_args.pop('s2d', False): + assert block_args['stride'] == 1 + space2depth = 1 + + if space2depth > 0: + # FIXME s2d is a WIP + if space2depth == 2 and block_args['stride'] == 2: + block_args['stride'] = 1 + # to end s2d region, need to correct expansion and se ratio relative to input + block_args['exp_ratio'] /= 4 + space2depth = 0 + else: + block_args['s2d'] = space2depth + extract_features = False if last_block: next_stack_idx = stack_idx + 1 @@ -416,6 +499,9 @@ class EfficientNetBuilder: block = self._make_block(block_args, total_block_idx, total_block_count) blocks.append(block) + if space2depth == 1: + space2depth = 2 + # stash feature module name and channel info for model feature extraction if extract_features: feature_info = dict( diff --git a/timm/models/beit.py b/timm/models/beit.py index 63b6db54..922d15e7 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -591,7 +591,7 @@ default_cfgs = generate_default_cfgs({ }) -def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True): +def checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True): state_dict = state_dict.get('model', state_dict) state_dict = state_dict.get('module', state_dict) # beit v2 didn't strip module @@ -637,7 +637,7 @@ def _create_beit(variant, pretrained=False, **kwargs): out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( Beit, variant, pretrained, - pretrained_filter_fn=_beit_checkpoint_filter_fn, + pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index c28538bc..32630683 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -556,7 +556,7 @@ class EfficientFormer(nn.Module): return x -def _checkpoint_filter_fn(state_dict, model): +def checkpoint_filter_fn(state_dict, model): """ Remap original checkpoints -> timm """ if 'stem.0.weight' in state_dict: return state_dict # non-original checkpoint, no remapping needed @@ -611,7 +611,7 @@ def _create_efficientformer(variant, pretrained=False, **kwargs): 
out_indices = kwargs.pop('out_indices', 4) model = build_model_with_cfg( EfficientFormer, variant, pretrained, - pretrained_filter_fn=_checkpoint_filter_fn, + pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 44a77506..46c4e81e 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -36,7 +36,7 @@ the models and weights open source! Hacked together by / Copyright 2019, Ross Wightman """ from functools import partial -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -44,10 +44,10 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct +from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct, LayerType from ._builder import build_model_with_cfg, pretrained_cfg_for_features from ._efficientnet_blocks import SqueezeExcite -from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ +from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from ._features import FeatureInfo, FeatureHooks, feature_take_indices from ._manipulate import checkpoint_seq @@ -74,21 +74,22 @@ class EfficientNet(nn.Module): def __init__( self, - block_args, - num_classes=1000, - num_features=1280, - in_chans=3, - stem_size=32, - fix_stem=False, - output_stride=32, - pad_type='', - round_chs_fn=round_channels, - act_layer=None, - norm_layer=None, - se_layer=None, - drop_rate=0., - drop_path_rate=0., - global_pool='avg' + block_args: BlockArgs, + num_classes: int = 1000, + num_features: int = 1280, + in_chans: int = 3, + stem_size: int = 32, + fix_stem: bool = False, + output_stride: int = 32, + pad_type: str = '', + act_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[LayerType] = None, + round_chs_fn: Callable = round_channels, + drop_rate: float = 0., + drop_path_rate: float = 0., + global_pool: str = 'avg' ): super(EfficientNet, self).__init__() act_layer = act_layer or nn.ReLU @@ -113,6 +114,7 @@ class EfficientNet(nn.Module): round_chs_fn=round_chs_fn, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, ) @@ -270,20 +272,21 @@ class EfficientNetFeatures(nn.Module): def __init__( self, - block_args, - out_indices=(0, 1, 2, 3, 4), - feature_location='bottleneck', - in_chans=3, - stem_size=32, - fix_stem=False, - output_stride=32, - pad_type='', - round_chs_fn=round_channels, - act_layer=None, - norm_layer=None, - se_layer=None, - drop_rate=0., - drop_path_rate=0. + block_args: BlockArgs, + out_indices: Tuple[int, ...] 
= (0, 1, 2, 3, 4), + feature_location: str = 'bottleneck', + in_chans: int = 3, + stem_size: int = 32, + fix_stem: bool = False, + output_stride: int = 32, + pad_type: str = '', + act_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[LayerType] = None, + round_chs_fn: Callable = round_channels, + drop_rate: float = 0., + drop_path_rate: float = 0., ): super(EfficientNetFeatures, self).__init__() act_layer = act_layer or nn.ReLU @@ -306,6 +309,7 @@ class EfficientNetFeatures(nn.Module): round_chs_fn=round_chs_fn, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, feature_location=feature_location, @@ -879,6 +883,88 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0 return model +def _gen_efficientnet_x( + variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8, + group_size=None, version=1, pretrained=False, **kwargs): + """Creates an EfficientNet model. + + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-x-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-x-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-x-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-x-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-x-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-x-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-x-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-x-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-x-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + + """ + """ + if version == 1: + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25_d1_a0', + 'r2_k3_s22_e6_i16_o24_se0.25_f1_d2_a1', + 'r2_k5_s22_e6_i24_o40_se0.25_f1_a1', + 'r3_k3_s22_e6_i40_o80_se0.25_a0', + 'r3_k5_s11_e6_i80_o112_se0.25_a0', + 'r4_k5_s22_e6_i112_o192_se0.25_a0', + 'r1_k3_s11_e6_i192_o320_se0.25_a0', + ] + elif version == 2: + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25_d1_a0', + 'r2_k3_s22_e4_i16_o24_se0.25_f1_d2_a1', + 'r2_k5_s22_e4_i24_o40_se0.25_f1_a1', + 'r3_k3_s22_e4_i40_o80_se0.25_a0', + 'r3_k5_s11_e6_i80_o112_se0.25_a0', + 'r4_k5_s22_e6_i112_o192_se0.25_a0', + 'r1_k3_s11_e6_i192_o320_se0.25_a0', + ] + """ + if version == 1: + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25_d1'], + ['er_r2_k3_s2_e6_c24_se0.25_nre'], + ['er_r2_k5_s2_e6_c40_se0.25_nre'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + else: + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25_d1'], + ['er_r2_k3_s2_e4_c24_se0.25_nre'], + ['er_r2_k5_s2_e4_c40_se0.25_nre'], + ['ir_r3_k3_s2_e4_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier, divisor=channel_divisor) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + act_layer=resolve_act_layer(kwargs, 'silu'), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs, + ) + model = 
_create_effnet(variant, pretrained, **model_kwargs) + return model + + def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a MixNet Small model. @@ -1072,6 +1158,7 @@ default_cfgs = generate_default_cfgs({ input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), 'efficientnet_b3_g8_gn.untrained': _cfg( input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), + 'efficientnet_blur_b0.untrained': _cfg(), 'efficientnet_es.ra_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth', @@ -1768,6 +1855,17 @@ def efficientnet_b3_g8_gn(pretrained=False, **kwargs) -> EfficientNet: return model +@register_model +def efficientnet_blur_b0(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B0 w/ BlurPool """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_blur_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, + aa_layer='blurpc', **kwargs + ) + return model + + @register_model def efficientnet_es(pretrained=False, **kwargs) -> EfficientNet: """ EfficientNet-Edge Small. """ @@ -2277,6 +2375,31 @@ def tf_efficientnetv2_b3(pretrained=False, **kwargs) -> EfficientNet: return model +@register_model +def efficientnet_x_b3(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet_x( + 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_x_b5(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B5 """ + model = _gen_efficientnet_x( + 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_h_b5(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B5 """ + model = _gen_efficientnet_x( + 'efficientnet_b5', channel_multiplier=1.92, depth_multiplier=2.2, version=2, pretrained=pretrained, **kwargs) + return model + + @register_model def mixnet_s(pretrained=False, **kwargs) -> EfficientNet: """Creates a MixNet Small model. 
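A minimal usage sketch for the EfficientNet changes above, assuming this branch is installed. `efficientnet_blur_b0` and the `aa_layer` constructor argument come from the hunks above; the shapes in the comments are expected values, not verified output, and `aa_layer='avg'` pass-through is an assumption based on how the experimental MobileNet-V4 variants later in this diff are registered.

import torch
import timm

# Registered above with aa_layer='blurpc' (a BlurPool2d anti-aliasing variant).
model = timm.create_model('efficientnet_blur_b0', pretrained=False)
print(model(torch.randn(1, 3, 224, 224)).shape)  # expect torch.Size([1, 1000])

# aa_layer is now an EfficientNet constructor kwarg, so it should also be passable via create_model;
# 'avg' names an average-pool anti-aliasing layer as used by mobilenetv4_conv_aa_medium below.
feat_model = timm.create_model('efficientnet_b0', aa_layer='avg', num_classes=0)
print(feat_model(torch.randn(1, 3, 224, 224)).shape)  # expect torch.Size([1, 1280]) pooled features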
diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py index 74b6cc28..ef7ec3c9 100644 --- a/timm/models/fastvit.py +++ b/timm/models/fastvit.py @@ -7,7 +7,7 @@ # import os from functools import partial -from typing import Tuple, Optional, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \ ClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -40,19 +41,19 @@ class MobileOneBlock(nn.Module): """ def __init__( - self, - in_chs: int, - out_chs: int, - kernel_size: int, - stride: int = 1, - dilation: int = 1, - group_size: int = 0, - inference_mode: bool = False, - use_se: bool = False, - use_act: bool = True, - use_scale_branch: bool = True, - num_conv_branches: int = 1, - act_layer: nn.Module = nn.GELU, + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + group_size: int = 0, + inference_mode: bool = False, + use_se: bool = False, + use_act: bool = True, + use_scale_branch: bool = True, + num_conv_branches: int = 1, + act_layer: nn.Module = nn.GELU, ) -> None: """Construct a MobileOneBlock module. @@ -280,15 +281,16 @@ class ReparamLargeKernelConv(nn.Module): """ def __init__( - self, - in_chs: int, - out_chs: int, - kernel_size: int, - stride: int, - group_size: int, - small_kernel: Optional[int] = None, - inference_mode: bool = False, - act_layer: Optional[nn.Module] = None, + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int, + group_size: int, + small_kernel: Optional[int] = None, + use_se: bool = False, + act_layer: Optional[nn.Module] = None, + inference_mode: bool = False, ) -> None: """Construct a ReparamLargeKernelConv module. @@ -299,8 +301,8 @@ class ReparamLargeKernelConv(nn.Module): stride: Stride size. Default: 1 group_size: Group size. Default: 1 small_kernel: Kernel size of small kernel conv branch. - inference_mode: If True, instantiates model in inference mode. Default: ``False`` act_layer: Activation module. Default: ``nn.GELU`` + inference_mode: If True, instantiates model in inference mode. Default: ``False`` """ super(ReparamLargeKernelConv, self).__init__() self.stride = stride @@ -342,6 +344,7 @@ class ReparamLargeKernelConv(nn.Module): groups=self.groups, apply_act=False, ) + self.se = SqueezeExcite(out_chs, rd_ratio=0.25) if use_se else nn.Identity() # FIXME output of this act was not used in original impl, likely due to bug self.act = act_layer() if act_layer is not None else nn.Identity() @@ -352,6 +355,7 @@ class ReparamLargeKernelConv(nn.Module): out = self.large_conv(x) if self.small_conv is not None: out = out + self.small_conv(x) + out = self.se(out) out = self.act(out) return out @@ -472,12 +476,12 @@ class Attention(nn.Module): fused_attn: torch.jit.Final[bool] def __init__( - self, - dim: int, - head_dim: int = 32, - qkv_bias: bool = False, - attn_drop: float = 0.0, - proj_drop: float = 0.0, + self, + dim: int, + head_dim: int = 32, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, ) -> None: """Build MHSA module that can handle 3D or 4D input tensors. 
@@ -535,14 +539,15 @@ class PatchEmbed(nn.Module): """Convolutional patch embedding layer.""" def __init__( - self, - patch_size: int, - stride: int, - in_chs: int, - embed_dim: int, - act_layer: nn.Module = nn.GELU, - lkc_use_act: bool = False, - inference_mode: bool = False, + self, + patch_size: int, + stride: int, + in_chs: int, + embed_dim: int, + act_layer: nn.Module = nn.GELU, + lkc_use_act: bool = False, + use_se: bool = False, + inference_mode: bool = False, ) -> None: """Build patch embedding layer. @@ -562,14 +567,16 @@ class PatchEmbed(nn.Module): stride=stride, group_size=1, small_kernel=3, - inference_mode=inference_mode, + use_se=use_se, act_layer=act_layer if lkc_use_act else None, # NOTE original weights didn't use this act + inference_mode=inference_mode, ), MobileOneBlock( in_chs=embed_dim, out_chs=embed_dim, kernel_size=1, stride=1, + use_se=False, act_layer=act_layer, inference_mode=inference_mode, ) @@ -598,11 +605,11 @@ class RepMixer(nn.Module): """ def __init__( - self, - dim, - kernel_size=3, - layer_scale_init_value=1e-5, - inference_mode: bool = False, + self, + dim, + kernel_size=3, + layer_scale_init_value=1e-5, + inference_mode: bool = False, ): """Build RepMixer Module. @@ -648,7 +655,7 @@ class RepMixer(nn.Module): if layer_scale_init_value is not None: self.layer_scale = LayerScale2d(dim, layer_scale_init_value) else: - self.layer_scale = nn.Identity + self.layer_scale = nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: if self.reparam_conv is not None: @@ -706,12 +713,12 @@ class ConvMlp(nn.Module): """Convolutional FFN Module.""" def __init__( - self, - in_chs: int, - hidden_channels: Optional[int] = None, - out_chs: Optional[int] = None, - act_layer: nn.Module = nn.GELU, - drop: float = 0.0, + self, + in_chs: int, + hidden_channels: Optional[int] = None, + out_chs: Optional[int] = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, ) -> None: """Build convolutional FFN module. @@ -764,11 +771,11 @@ class RepConditionalPosEnc(nn.Module): """ def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - spatial_shape: Union[int, Tuple[int, int]] = (7, 7), - inference_mode=False, + self, + dim: int, + dim_out: Optional[int] = None, + spatial_shape: Union[int, Tuple[int, int]] = (7, 7), + inference_mode=False, ) -> None: """Build reparameterizable conditional positional encoding @@ -878,15 +885,15 @@ class RepMixerBlock(nn.Module): """ def __init__( - self, - dim: int, - kernel_size: int = 3, - mlp_ratio: float = 4.0, - act_layer: nn.Module = nn.GELU, - proj_drop: float = 0.0, - drop_path: float = 0.0, - layer_scale_init_value: float = 1e-5, - inference_mode: bool = False, + self, + dim: int, + kernel_size: int = 3, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + proj_drop: float = 0.0, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-5, + inference_mode: bool = False, ): """Build RepMixer Block. @@ -936,14 +943,14 @@ class AttentionBlock(nn.Module): """ def __init__( - self, - dim: int, - mlp_ratio: float = 4.0, - act_layer: nn.Module = nn.GELU, - norm_layer: nn.Module = nn.BatchNorm2d, - proj_drop: float = 0.0, - drop_path: float = 0.0, - layer_scale_init_value: float = 1e-5, + self, + dim: int, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.BatchNorm2d, + proj_drop: float = 0.0, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-5, ): """Build Attention Block. 
@@ -993,6 +1000,7 @@ class FastVitStage(nn.Module): depth: int, token_mixer_type: str, downsample: bool = True, + se_downsample: bool = False, down_patch_size: int = 7, down_stride: int = 2, pos_emb_layer: Optional[nn.Module] = None, @@ -1030,6 +1038,7 @@ class FastVitStage(nn.Module): stride=down_stride, in_chs=dim, embed_dim=dim_out, + use_se=se_downsample, act_layer=act_layer, lkc_use_act=lkc_use_act, inference_mode=inference_mode, @@ -1090,29 +1099,30 @@ class FastVit(nn.Module): """ def __init__( - self, - in_chans: int = 3, - layers: Tuple[int, ...] = (2, 2, 6, 2), - token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"), - embed_dims: Tuple[int, ...] = (64, 128, 256, 512), - mlp_ratios: Tuple[float, ...] = (4,) * 4, - downsamples: Tuple[bool, ...] = (False, True, True, True), - repmixer_kernel_size: int = 3, - num_classes: int = 1000, - pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4, - down_patch_size: int = 7, - down_stride: int = 2, - drop_rate: float = 0.0, - proj_drop_rate: float = 0.0, - drop_path_rate: float = 0.0, - layer_scale_init_value: float = 1e-5, - fork_feat: bool = False, - cls_ratio: float = 2.0, - global_pool: str = 'avg', - norm_layer: nn.Module = nn.BatchNorm2d, - act_layer: nn.Module = nn.GELU, - lkc_use_act: bool = False, - inference_mode: bool = False, + self, + in_chans: int = 3, + layers: Tuple[int, ...] = (2, 2, 6, 2), + token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"), + embed_dims: Tuple[int, ...] = (64, 128, 256, 512), + mlp_ratios: Tuple[float, ...] = (4,) * 4, + downsamples: Tuple[bool, ...] = (False, True, True, True), + se_downsamples: Tuple[bool, ...] = (False, False, False, False), + repmixer_kernel_size: int = 3, + num_classes: int = 1000, + pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4, + down_patch_size: int = 7, + down_stride: int = 2, + drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + layer_scale_init_value: float = 1e-5, + lkc_use_act: bool = False, + fork_feat: bool = False, + cls_ratio: float = 2.0, + global_pool: str = 'avg', + norm_layer: nn.Module = nn.BatchNorm2d, + act_layer: nn.Module = nn.GELU, + inference_mode: bool = False, ) -> None: super().__init__() self.num_classes = 0 if fork_feat else num_classes @@ -1140,6 +1150,7 @@ class FastVit(nn.Module): dim_out=embed_dims[i], depth=layers[i], downsample=downsample, + se_downsample=se_downsamples[i], down_patch_size=down_patch_size, down_stride=down_stride, pos_emb_layer=pos_embs[i], @@ -1160,6 +1171,7 @@ class FastVit(nn.Module): scale *= 2 self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) + self.num_stages = len(self.stages) self.num_features = prev_dim # For segmentation and detection, extract intermediate output @@ -1236,6 +1248,66 @@ class FastVit(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. 
+ + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + last_idx = self.num_stages - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + feat_idx = 0 + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.final_conv(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x: torch.Tensor) -> torch.Tensor: # input embedding x = self.stem(x) @@ -1297,8 +1369,7 @@ default_cfgs = generate_default_cfgs({ "fastvit_ma36.apple_in1k": _cfg( hf_hub_id='timm/', - crop_pct=0.95 - ), + crop_pct=0.95), "fastvit_t8.apple_dist_in1k": _cfg( hf_hub_id='timm/'), @@ -1318,15 +1389,111 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='timm/', crop_pct=0.95 ), + + "fastvit_mci0.apple_mclip": _cfg( + #hf_hub_id='timm/', + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt', + crop_pct=0.95, + num_classes=512, # CLIP proj dim + mean=(0., 0., 0.), std=(1., 1., 1.) + ), + "fastvit_mci1.apple_mclip": _cfg( + # hf_hub_id='timm/', + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt', + crop_pct=0.95, + num_classes=512, # CLIP proj dim + mean=(0., 0., 0.), std=(1., 1., 1.) + ), + "fastvit_mci2.apple_mclip": _cfg( + # hf_hub_id='timm/', + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt', + crop_pct=0.95, + num_classes=512, # CLIP proj dim + mean=(0., 0., 0.), std=(1., 1., 1.) + ), }) +def checkpoint_filter_fn(state_dict, model): + """ Remap original checkpoints -> timm """ + if 'stem.0.conv_kxk.0.conv.weight' in state_dict: + return state_dict # non-original checkpoint, no remapping needed + + state_dict = state_dict.get('state_dict', state_dict) + if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict: + # remap MobileCLIP checkpoints + prefix = 'image_encoder.model.' 
+ else: + prefix = '' + + import re + import bisect + + # find stage ends by locating downsample layers + stage_ends = [] + for k, v in state_dict.items(): + match = re.match(r'^(.*?)network\.(\d+)\.proj.*', k) + if match: + stage_ends.append(int(match.group(2))) + stage_ends = list(sorted(set(stage_ends))) + + out_dict = {} + for k, v in state_dict.items(): + if prefix: + if prefix not in k: + continue + k = k.replace(prefix, '') + + # remap renamed layers + k = k.replace('patch_embed', 'stem') + k = k.replace('rbr_conv', 'conv_kxk') + k = k.replace('rbr_scale', 'conv_scale') + k = k.replace('rbr_skip', 'identity') + k = k.replace('conv_exp', 'final_conv') # to match byobnet, regnet, nfnet + k = k.replace('lkb_origin', 'large_conv') + k = k.replace('convffn', 'mlp') + k = k.replace('se.reduce', 'se.fc1') + k = k.replace('se.expand', 'se.fc2') + k = re.sub(r'layer_scale_([0-9])', r'layer_scale_\1.gamma', k) + if k.endswith('layer_scale'): + k = k.replace('layer_scale', 'layer_scale.gamma') + k = k.replace('dist_head', 'head_dist') + if k.startswith('head.'): + if k == 'head.proj' and hasattr(model.head, 'fc') and isinstance(model.head.fc, nn.Linear): + # if CLIP projection, map to head.fc w/ bias = zeros + k = k.replace('head.proj', 'head.fc.weight') + v = v.T + out_dict['head.fc.bias'] = torch.zeros(v.shape[0]) + else: + k = k.replace('head.', 'head.fc.') + + # remap flat sequential network to stages + match = re.match(r'^network\.(\d+)', k) + stage_idx, net_idx = None, None + if match: + net_idx = int(match.group(1)) + stage_idx = bisect.bisect_right(stage_ends, net_idx) + if stage_idx is not None: + net_prefix = f'network.{net_idx}' + stage_prefix = f'stages.{stage_idx}' + if net_prefix + '.proj' in k: + k = k.replace(net_prefix + '.proj', stage_prefix + '.downsample.proj') + elif net_prefix + '.pe' in k: + k = k.replace(net_prefix + '.pe', stage_prefix + '.pos_emb.pos_enc') + else: + k = k.replace(net_prefix, stage_prefix + '.blocks') + + out_dict[k] = v + return out_dict + + def _create_fastvit(variant, pretrained=False, **kwargs): out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) model = build_model_with_cfg( FastVit, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs ) @@ -1419,3 +1586,48 @@ def fastvit_ma36(pretrained=False, **kwargs): token_mixers=("repmixer", "repmixer", "repmixer", "attention") ) return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_mci0(pretrained=False, **kwargs): + """Instantiate MCi0 model variant.""" + model_args = dict( + layers=(2, 6, 10, 2), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(3, 3, 3, 3), + se_downsamples=(False, False, True, True), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + lkc_use_act=True, + ) + return _create_fastvit('fastvit_mci0', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_mci1(pretrained=False, **kwargs): + """Instantiate MCi1 model variant.""" + model_args = dict( + layers=(4, 12, 20, 4), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(3, 3, 3, 3), + se_downsamples=(False, False, True, True), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + lkc_use_act=True, + ) + return _create_fastvit('fastvit_mci1', 
pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_mci2(pretrained=False, **kwargs): + """Instantiate MCi2 model variant.""" + model_args = dict( + layers=(4, 12, 24, 4), + embed_dims=(80, 160, 320, 640), + mlp_ratios=(3, 3, 3, 3), + se_downsamples=(False, False, True, True), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + lkc_use_act=True, + ) + return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 2b4053e0..b25d87ba 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -40,6 +40,7 @@ class MobileNetV3(nn.Module): * HardCoRe-NAS - https://arxiv.org/abs/2102.11646 (defn in hardcorenas.py uses this class) * FBNet-V3 - https://arxiv.org/abs/2006.02049 * LCNet - https://arxiv.org/abs/2109.15099 + * MobileNet-V4 - https://arxiv.org/abs/2404.10518 """ def __init__( @@ -51,14 +52,17 @@ class MobileNetV3(nn.Module): fix_stem: bool = False, num_features: int = 1280, head_bias: bool = True, - pad_type: PadType = '', + head_norm: bool = False, + pad_type: str = '', act_layer: Optional[LayerType] = None, norm_layer: Optional[LayerType] = None, + aa_layer: Optional[LayerType] = None, se_layer: Optional[LayerType] = None, se_from_exp: bool = True, round_chs_fn: Callable = round_channels, drop_rate: float = 0., drop_path_rate: float = 0., + layer_scale_init_value: Optional[float] = None, global_pool: str = 'avg', ): """ @@ -73,11 +77,13 @@ class MobileNetV3(nn.Module): pad_type: Type of padding to use for convolution layers. act_layer: Type of activation layer. norm_layer: Type of normalization layer. + aa_layer: Type of anti-aliasing layer. se_layer: Type of Squeeze-and-Excite layer. se_from_exp: If True, calculate SE channel reduction from expanded mid channels. round_chs_fn: Callable to round number of filters based on depth multiplier. drop_rate: Dropout rate. drop_path_rate: Stochastic depth rate. + layer_scale_init_value: Enable layer scale on compatible blocks if not None. global_pool: Type of pooling to use for global pooling features of the FC head. 
""" super(MobileNetV3, self).__init__() @@ -104,8 +110,10 @@ class MobileNetV3(nn.Module): se_from_exp=se_from_exp, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features @@ -115,8 +123,16 @@ class MobileNetV3(nn.Module): # Head + Pooling self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) num_pooled_chs = head_chs * self.global_pool.feat_mult() - self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) - self.act2 = act_layer(inplace=True) + if head_norm: + # mobilenet-v4 post-pooling PW conv is followed by a norm+act layer + self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type) # never bias + self.norm_head = norm_act_layer(self.num_features) + self.act2 = nn.Identity() + else: + # mobilenet-v3 and others only have an activation after final PW conv + self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) + self.norm_head = nn.Identity() + self.act2 = act_layer(inplace=True) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() @@ -125,7 +141,7 @@ class MobileNetV3(nn.Module): def as_sequential(self): layers = [self.conv_stem, self.bn1] layers.extend(self.blocks) - layers.extend([self.global_pool, self.conv_head, self.act2]) + layers.extend([self.global_pool, self.conv_head, self.norm_head, self.act2]) layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) return nn.Sequential(*layers) @@ -224,8 +240,10 @@ class MobileNetV3(nn.Module): self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0 if max_index < len(self.blocks): self.conv_head = nn.Identity() + self.norm_head = nn.Identity() if prune_head: self.conv_head = nn.Identity() + self.norm_head = nn.Identity() self.reset_classifier(0, '') return take_indices @@ -241,6 +259,7 @@ class MobileNetV3(nn.Module): def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: x = self.global_pool(x) x = self.conv_head(x) + x = self.norm_head(x) x = self.act2(x) x = self.flatten(x) if pre_logits: @@ -276,9 +295,11 @@ class MobileNetV3Features(nn.Module): se_from_exp: bool = True, act_layer: Optional[LayerType] = None, norm_layer: Optional[LayerType] = None, + aa_layer: Optional[LayerType] = None, se_layer: Optional[LayerType] = None, drop_rate: float = 0., drop_path_rate: float = 0., + layer_scale_init_value: Optional[float] = None, ): """ Args: @@ -297,6 +318,7 @@ class MobileNetV3Features(nn.Module): se_layer: Type of Squeeze-and-Excite layer. drop_rate: Dropout rate. drop_path_rate: Stochastic depth rate. + layer_scale_init_value: Enable layer scale on compatible blocks if not None. 
""" super(MobileNetV3Features, self).__init__() act_layer = act_layer or nn.ReLU @@ -320,8 +342,10 @@ class MobileNetV3Features(nn.Module): se_from_exp=se_from_exp, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, feature_location=feature_location, ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) @@ -370,7 +394,7 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV if 'feature_cfg' in kwargs or 'feature_cls' in kwargs: features_mode = 'cfg' else: - kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool') + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'head_norm', 'global_pool') model_cls = MobileNetV3Features features_mode = 'cls' @@ -622,6 +646,252 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = return model +def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3: + """Creates a MobileNet-V4 model. + + Ref impl: ? + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + num_features = 1280 + if 'hybrid' in variant: + layer_scale_init_value = 1e-5 + if 'medium' in variant: + stem_size = 32 + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + [ + 'er_r1_k3_s2_e4_c48' # FusedIB (EdgeResidual) + ], + # stage 1, 56x56 in + [ + 'uir_r1_a3_k5_s2_e4_c80', # ExtraDW + 'uir_r1_a3_k3_s1_e2_c80', # ExtraDW + ], + # stage 2, 28x28 in + [ + 'uir_r1_a3_k5_s2_e6_c160', # ExtraDW + 'uir_r1_a0_k0_s1_e2_c160', # FFN + 'uir_r1_a3_k3_s1_e4_c160', # ExtraDW + 'uir_r1_a3_k5_s1_e4_c160', # ExtraDW + 'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample + 'uir_r1_a3_k3_s1_e4_c160', # ExtraDW + 'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample + 'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt + 'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample + 'uir_r1_a3_k3_s1_e4_c160', # ExtraDW + 'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample + 'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt + ], + # stage 3, 14x14in + [ + 'uir_r1_a5_k5_s2_e6_c256', # ExtraDW + 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW + 'uir_r2_a3_k5_s1_e4_c256', # ExtraDW + 'uir_r1_a0_k0_s1_e2_c256', # FFN + 'uir_r1_a3_k5_s1_e2_c256', # ExtraDW + 'uir_r1_a0_k0_s1_e2_c256', # FFN + 'uir_r1_a0_k0_s1_e4_c256', # FFN + 'mqa_r1_k3_h4_s1_d64_c256', # MQA + 'uir_r1_a3_k0_s1_e4_c256', # ConvNeXt + 'mqa_r1_k3_h4_s1_d64_c256', # MQA + 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW + 'mqa_r1_k3_h4_s1_d64_c256', # MQA + 'uir_r1_a5_k0_s1_e4_c256', # ConvNeXt + 'mqa_r1_k3_h4_s1_d64_c256', # MQA + 'uir_r1_a5_k0_s1_e4_c256', # ConvNeXt + ], + # stage 4, 7x7 in + [ + 'cn_r1_k1_s1_c960' # Conv + ], + ] + elif 'large' in variant: + stem_size = 24 + act_layer = resolve_act_layer(kwargs, 'gelu') + arch_def = [ + # stage 0, 112x112 in + [ + 'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual) + ], + # stage 1, 56x56 in + [ + 'uir_r1_a3_k5_s2_e4_c96', # ExtraDW + 'uir_r1_a3_k3_s1_e4_c96', # ExtraDW + ], + # stage 2, 28x28 in + [ + 'uir_r1_a3_k5_s2_e4_c192', # ExtraDW + 'uir_r3_a3_k3_s1_e4_c192', # ExtraDW + 'uir_r1_a3_k5_s1_e4_c192', # ExtraDW + 'uir_r2_a5_k3_s1_e4_c192', # ExtraDW + 'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample + 'uir_r1_a5_k3_s1_e4_c192', # ExtraDW + 'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample + 'uir_r1_a5_k3_s1_e4_c192', # ExtraDW + 
'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample + 'uir_r1_a5_k3_s1_e4_c192', # ExtraDW + 'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample + 'uir_r1_a3_k0_s1_e4_c192', # ConvNeXt + ], + # stage 3, 14x14in + [ + 'uir_r4_a5_k5_s2_e4_c512', # ExtraDW + 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt + 'uir_r1_a5_k3_s1_e4_c512', # ExtraDW + 'uir_r2_a5_k0_s1_e4_c512', # ConvNeXt + 'uir_r1_a5_k3_s1_e4_c512', # ExtraDW + 'uir_r1_a5_k5_s1_e4_c512', # ExtraDW + 'mqa_r1_k3_h8_s1_d64_c512', # MQA + 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt + 'mqa_r1_k3_h8_s1_d64_c512', # MQA + 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt + 'mqa_r1_k3_h8_s1_d64_c512', # MQA + 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt + 'mqa_r1_k3_h8_s1_d64_c512', # MQA + 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt + ], + # stage 4, 7x7 in + [ + 'cn_r1_k1_s1_c960', # Conv + ], + ] + else: + assert False, f'Unknown variant {variant}.' + else: + layer_scale_init_value = None + if 'small' in variant: + stem_size = 32 + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + [ + 'cn_r1_k3_s2_e1_c32', # Conv + 'cn_r1_k1_s1_e1_c32', # Conv + ], + # stage 1, 56x56 in + [ + 'cn_r1_k3_s2_e1_c96', # Conv + 'cn_r1_k1_s1_e1_c64', # Conv + ], + # stage 2, 28x28 in + [ + 'uir_r1_a5_k5_s2_e3_c96', # ExtraDW + 'uir_r4_a0_k3_s1_e2_c96', # IR + 'uir_r1_a3_k0_s1_e4_c96', # ConvNeXt + ], + # stage 3, 14x14 in + [ + 'uir_r1_a3_k3_s2_e6_c128', # ExtraDW + 'uir_r1_a5_k5_s1_e4_c128', # ExtraDW + 'uir_r1_a0_k5_s1_e4_c128', # IR + 'uir_r1_a0_k5_s1_e3_c128', # IR + 'uir_r2_a0_k3_s1_e4_c128', # IR + ], + # stage 4, 7x7 in + [ + 'cn_r1_k1_s1_c960', # Conv + ], + ] + elif 'medium' in variant: + stem_size = 32 + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + [ + 'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual) + ], + # stage 1, 56x56 in + [ + 'uir_r1_a3_k5_s2_e4_c80', # ExtraDW + 'uir_r1_a3_k3_s1_e2_c80', # ExtraDW + ], + # stage 2, 28x28 in + [ + 'uir_r1_a3_k5_s2_e6_c160', # ExtraDW + 'uir_r2_a3_k3_s1_e4_c160', # ExtraDW + 'uir_r1_a3_k5_s1_e4_c160', # ExtraDW + 'uir_r1_a3_k3_s1_e4_c160', # ExtraDW + 'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt + 'uir_r1_a0_k0_s1_e2_c160', # ExtraDW + 'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt + ], + # stage 3, 14x14in + [ + 'uir_r1_a5_k5_s2_e6_c256', # ExtraDW + 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW + 'uir_r2_a3_k5_s1_e4_c256', # ExtraDW + 'uir_r1_a0_k0_s1_e4_c256', # FFN + 'uir_r1_a3_k0_s1_e4_c256', # ConvNeXt + 'uir_r1_a3_k5_s1_e2_c256', # ExtraDW + 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW + 'uir_r2_a0_k0_s1_e4_c256', # FFN + 'uir_r1_a5_k0_s1_e2_c256', # ConvNeXt + ], + # stage 4, 7x7 in + [ + 'cn_r1_k1_s1_c960', # Conv + ], + ] + elif 'large' in variant: + stem_size = 24 + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + [ + 'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual) + ], + # stage 1, 56x56 in + [ + 'uir_r1_a3_k5_s2_e4_c96', # ExtraDW + 'uir_r1_a3_k3_s1_e4_c96', # ExtraDW + ], + # stage 2, 28x28 in + [ + 'uir_r1_a3_k5_s2_e4_c192', # ExtraDW + 'uir_r3_a3_k3_s1_e4_c192', # ExtraDW + 'uir_r1_a3_k5_s1_e4_c192', # ExtraDW + 'uir_r5_a5_k3_s1_e4_c192', # ExtraDW + 'uir_r1_a3_k0_s1_e4_c192', # ConvNeXt + ], + # stage 3, 14x14in + [ + 'uir_r4_a5_k5_s2_e4_c512', # ExtraDW + 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt + 'uir_r1_a5_k3_s1_e4_c512', # ExtraDW + 'uir_r2_a5_k0_s1_e4_c512', # ConvNeXt + 'uir_r1_a5_k3_s1_e4_c512', # ExtraDW + 'uir_r1_a5_k5_s1_e4_c512', # ExtraDW + 'uir_r3_a5_k0_s1_e4_c512', # ConvNeXt + + ], + # stage 4, 7x7 in + [ + 'cn_r1_k1_s1_c960', # Conv + 
], + ] + else: + assert False, f'Unknown variant {variant}.' + + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + head_bias=False, + head_norm=True, + num_features=num_features, + stem_size=stem_size, + fix_stem=channel_multiplier < 1.0, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=act_layer, + layer_scale_init_value=layer_scale_init_value, + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + def _cfg(url: str = '', **kwargs): return { @@ -725,6 +995,52 @@ default_cfgs = generate_default_cfgs({ interpolation='bicubic', ), "lcnet_150.untrained": _cfg(), + + 'mobilenetv4_conv_small': _cfg( + # hf_hub_id='timm/', + interpolation='bicubic'), + 'mobilenetv4_conv_medium.r224': _cfg( + # hf_hub_id='timm/', + crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_conv_medium.r256': _cfg( + # hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_conv_large.r256': _cfg( + # hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_conv_large.r384': _cfg( + # hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=0.95, interpolation='bicubic'), + + 'mobilenetv4_hybrid_small': _cfg( + # hf_hub_id='timm/', + interpolation='bicubic'), + 'mobilenetv4_hybrid_medium.r224': _cfg( + # hf_hub_id='timm/', + crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_hybrid_medium.r256': _cfg( + # hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_hybrid_large.r256': _cfg( + # hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_hybrid_large.r384': _cfg( + # hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=0.95, interpolation='bicubic'), + + # experimental + 'mobilenetv4_conv_aa_medium.r256': _cfg( + # hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_conv_blur_medium.r256': _cfg( + # hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_hybrid_medium_075': _cfg( + # hf_hub_id='timm/', + crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_hybrid_large_075.r256': _cfg( + # hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), }) @@ -881,6 +1197,69 @@ def lcnet_150(pretrained: bool = False, **kwargs) -> MobileNetV3: return model +@register_model +def mobilenetv4_conv_small(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 """ + model = _gen_mobilenet_v4('mobilenetv4_conv_small', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv4_conv_medium(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 """ + model = _gen_mobilenet_v4('mobilenetv4_conv_medium', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv4_conv_large(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 """ + model = _gen_mobilenet_v4('mobilenetv4_conv_large', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv4_hybrid_medium(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 Hybrid """ + model = 
_gen_mobilenet_v4('mobilenetv4_hybrid_medium', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv4_hybrid_large(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 Hybrid """ + model = _gen_mobilenet_v4('mobilenetv4_hybrid_large', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv4_conv_aa_medium(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 w/ AvgPool AA """ + model = _gen_mobilenet_v4('mobilenetv4_conv_aa_medium', 1.0, pretrained=pretrained, aa_layer='avg', **kwargs) + return model + + +@register_model +def mobilenetv4_conv_blur_medium(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 Conv w/ Blur AA """ + model = _gen_mobilenet_v4('mobilenetv4_conv_blur_medium', 1.0, pretrained=pretrained, aa_layer='blurpc', **kwargs) + return model + + +@register_model +def mobilenetv4_hybrid_medium_075(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 Hybrid """ + model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv4_hybrid_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 Hybrid """ + model = _gen_mobilenet_v4('mobilenetv4_hybrid_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + register_model_deprecations(__name__, { 'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k', 'mobilenetv3_large_100_miil_in21k': 'mobilenetv3_large_100.miil_in21k', diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index 1d9c6842..90ebfe7a 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -403,7 +403,7 @@ class PyramidVisionTransformerV2(nn.Module): return x -def _checkpoint_filter_fn(state_dict, model): +def checkpoint_filter_fn(state_dict, model): """ Remap original checkpoints -> timm """ if 'patch_embed.proj.weight' in state_dict: return state_dict # non-original checkpoint, no remapping needed @@ -430,7 +430,7 @@ def _create_pvt2(variant, pretrained=False, **kwargs): PyramidVisionTransformerV2, variant, pretrained, - pretrained_filter_fn=_checkpoint_filter_fn, + pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs, ) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 53dfab9c..15f16997 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \ - get_attn, get_act_layer, get_norm_layer, create_classifier + get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq @@ -31,15 +31,6 @@ def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int: return padding -def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module: - if not aa_layer or not enable: - return nn.Identity() - if issubclass(aa_layer, nn.AvgPool2d): - return aa_layer(stride) - else: - return aa_layer(channels=channels, stride=stride) - - class BasicBlock(nn.Module): expansion = 1 diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 2dd4754a..e3f1b8f2 100644 ---
a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -409,6 +409,7 @@ class VisionTransformer(nn.Module): qk_norm: bool = False, init_values: Optional[float] = None, class_token: bool = True, + pos_embed: str = 'learn', no_embed_class: bool = False, reg_tokens: int = 0, pre_norm: bool = False, @@ -460,6 +461,7 @@ class VisionTransformer(nn.Module): super().__init__() assert global_pool in ('', 'avg', 'token', 'map') assert class_token or global_pool != 'token' + assert pos_embed in ('', 'none', 'learn') use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) act_layer = get_act_layer(act_layer) or nn.GELU @@ -494,7 +496,10 @@ class VisionTransformer(nn.Module): self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens - self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) + if not pos_embed or pos_embed == 'none': + self.pos_embed = None + else: + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_drop = nn.Dropout(p=pos_drop_rate) if patch_drop_rate > 0: self.patch_drop = PatchDropout( @@ -556,7 +561,8 @@ class VisionTransformer(nn.Module): def init_weights(self, mode: str = '') -> None: assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. - trunc_normal_(self.pos_embed, std=.02) + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) if self.cls_token is not None: nn.init.normal_(self.cls_token, std=1e-6) named_apply(get_init_weights_vit(mode, head_bias), self) @@ -583,6 +589,8 @@ class VisionTransformer(nn.Module): @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True) -> None: self.grad_checkpointing = enable + if hasattr(self.patch_embed, 'set_grad_checkpointing'): + self.patch_embed.set_grad_checkpointing(enable) @torch.jit.ignore def get_classifier(self) -> nn.Module: @@ -600,6 +608,9 @@ class VisionTransformer(nn.Module): self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.pos_embed is None: + return x.view(x.shape[0], -1, x.shape[-1]) + if self.dynamic_img_size: B, H, W, C = x.shape pos_embed = resample_abs_pos_embed( @@ -1066,10 +1077,13 @@ def checkpoint_filter_fn( # IJEPA, vit in an 'encoder' submodule state_dict = state_dict['encoder'] prefix = 'module.' - elif 'visual.trunk.pos_embed' in state_dict: + elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict: # OpenCLIP model with timm vision encoder - # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj) prefix = 'visual.trunk.' 
+ if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear): + # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj) + out_dict['head.weight'] = state_dict['visual.head.proj.weight'] + out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0]) if prefix: # filter on & remove prefix string from keys diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 25dd9c27..0c690c35 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -15,18 +15,20 @@ Hacked together by / Copyright 2020, Ross Wightman """ import math from functools import partial -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to +from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to + +from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model, register_model_deprecations from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2, create_resnetv2_stem -from .vision_transformer import _create_vision_transformer, VisionTransformer +from .vision_transformer import VisionTransformer class HybridEmbed(nn.Module): @@ -38,14 +40,15 @@ class HybridEmbed(nn.Module): def __init__( self, - backbone, - img_size=224, - patch_size=1, - feature_size=None, - feature_ratio=None, - in_chans=3, - embed_dim=768, - bias=True, + backbone: nn.Module, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 1, + feature_size: Optional[Union[int, Tuple[int, int]]] = None, + feature_ratio: Optional[Union[int, Tuple[int, int]]] = None, + in_chans: int = 3, + embed_dim: int = 768, + bias: bool = True, + proj: bool = True, flatten: bool = True, output_fmt: Optional[str] = None, strict_img_size: bool = True, @@ -95,7 +98,18 @@ class HybridEmbed(nn.Module): self.strict_img_size = strict_img_size self.dynamic_img_pad = dynamic_img_pad - self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + if proj: + self.proj = nn.Conv2d( + feature_dim, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + ) + else: + assert feature_dim == embed_dim,\ + f'The feature dim ({feature_dim} must match embed dim ({embed_dim}) when projection disabled.' 
+ self.proj = nn.Identity() def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]: total_reduction = ( @@ -116,6 +130,13 @@ class HybridEmbed(nn.Module): else: return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1] + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + if hasattr(self.backbone, 'set_grad_checkpointing'): + self.backbone.set_grad_checkpointing(enable=enable) + elif hasattr(self.backbone, 'grad_checkpointing'): + self.backbone.grad_checkpointing = enable + def forward(self, x): x = self.backbone(x) if isinstance(x, (list, tuple)): @@ -139,24 +160,35 @@ class HybridEmbedWithSize(nn.Module): """ def __init__( self, - backbone, - img_size=224, - patch_size=1, - feature_size=None, - in_chans=3, - embed_dim=768, + backbone: nn.Module, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 1, + feature_size: Optional[Union[int, Tuple[int, int]]] = None, + feature_ratio: Optional[Union[int, Tuple[int, int]]] = None, + in_chans: int = 3, + embed_dim: int = 768, bias=True, + proj=True, ): super().__init__( backbone=backbone, img_size=img_size, patch_size=patch_size, feature_size=feature_size, + feature_ratio=feature_ratio, in_chans=in_chans, embed_dim=embed_dim, bias=bias, + proj=proj, ) + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + if hasattr(self.backbone, 'set_grad_checkpointing'): + self.backbone.set_grad_checkpointing(enable=enable) + elif hasattr(self.backbone, 'grad_checkpointing'): + self.backbone.grad_checkpointing = enable + def forward(self, x) -> Tuple[torch.Tensor, List[int]]: x = self.backbone(x) if isinstance(x, (list, tuple)): @@ -165,10 +197,43 @@ class HybridEmbedWithSize(nn.Module): return x.flatten(2).transpose(1, 2), x.shape[-2:] -def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): - embed_layer = partial(HybridEmbed, backbone=backbone) - kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set - return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs) +class ConvStem(nn.Sequential): + def __init__( + self, + in_chans: int = 3, + depth: int = 3, + channels: Union[int, Tuple[int, ...]] = 64, + kernel_size: Union[int, Tuple[int, ...]] = 3, + stride: Union[int, Tuple[int, ...]] = (2, 2, 2), + padding: Union[str, int, Tuple[int, ...]] = "", + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.ReLU, + ): + super().__init__() + if isinstance(channels, int): + # a default tiered channel strategy + channels = tuple([channels // 2**i for i in range(depth)][::-1]) + + kernel_size = to_ntuple(depth)(kernel_size) + padding = to_ntuple(depth)(padding) + assert depth == len(stride) == len(kernel_size) == len(channels) + + in_chs = in_chans + for i in range(len(channels)): + last_conv = i == len(channels) - 1 + self.add_module(f'{i}', ConvNormAct( + in_chs, + channels[i], + kernel_size=kernel_size[i], + stride=stride[i], + padding=padding[i], + bias=last_conv, + apply_norm=not last_conv, + apply_act=not last_conv, + norm_layer=norm_layer, + act_layer=act_layer, + )) + in_chs = channels[i] def _resnetv2(layers=(3, 4, 9), **kwargs): @@ -186,6 +251,66 @@ def _resnetv2(layers=(3, 4, 9), **kwargs): return backbone +def _convert_mobileclip(state_dict, model, prefix='image_encoder.model.'): + out = {} + for k, v in state_dict.items(): + if not k.startswith(prefix): + continue + k = k.replace(prefix, '') + k = 
k.replace('patch_emb.', 'patch_embed.backbone.') + k = k.replace('block.conv', 'conv') + k = k.replace('block.norm', 'bn') + k = k.replace('post_transformer_norm.', 'norm.') + k = k.replace('pre_norm_mha.0', 'norm1') + k = k.replace('pre_norm_mha.1', 'attn') + k = k.replace('pre_norm_ffn.0', 'norm2') + k = k.replace('pre_norm_ffn.1', 'mlp.fc1') + k = k.replace('pre_norm_ffn.4', 'mlp.fc2') + k = k.replace('qkv_proj.', 'qkv.') + k = k.replace('out_proj.', 'proj.') + k = k.replace('transformer.', 'blocks.') + if k == 'pos_embed.pos_embed.pos_embed': + k = 'pos_embed' + v = v.squeeze(0) + if 'classifier.proj' in k: + bias_k = k.replace('classifier.proj', 'head.bias') + k = k.replace('classifier.proj', 'head.weight') + v = v.T + out[bias_k] = torch.zeros(v.shape[0]) + out[k] = v + return out + + +def checkpoint_filter_fn( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, + interpolation: str = 'bicubic', + antialias: bool = True, +) -> Dict[str, torch.Tensor]: + from .vision_transformer import checkpoint_filter_fn as _filter_fn + + if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict: + state_dict = _convert_mobileclip(state_dict, model) + + return _filter_fn(state_dict, model, interpolation=interpolation, antialias=antialias) + + +def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', 3) + embed_args = embed_args or {} + embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args) + kwargs.setdefault('embed_layer', embed_layer) + kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set + return build_model_with_cfg( + VisionTransformer, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) + + def _cfg(url='', **kwargs): return { 'url': url, @@ -260,6 +385,17 @@ default_cfgs = generate_default_cfgs({ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), 'vit_base_resnet50d_224.untrained': _cfg( mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + + 'vit_base_mci_224.apple_mclip': _cfg( + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt', + num_classes=512, + mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv', + ), + 'vit_base_mci_224.apple_mclip_lt': _cfg( + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt', + num_classes=512, + mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv', + ), }) @@ -407,6 +543,26 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs) -> VisionTransformer: return model + + +@register_model +def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer: + """ Custom ViT base hybrid w/ a simple 3-conv stem (overall stride 16), used as the MobileCLIP-B image tower.
+ """ + backbone = ConvStem( + channels=(768//4, 768//4, 768), + stride=(4, 2, 2), + kernel_size=(4, 2, 2), + padding=0, + in_chans=kwargs.get('in_chans', 3), + act_layer=nn.GELU, + ) + model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True) + model = _create_vision_transformer_hybrid( + 'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False), + pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + register_model_deprecations(__name__, { 'vit_tiny_r_s16_p8_224_in21k': 'vit_tiny_r_s16_p8_224.augreg_in21k', 'vit_small_r26_s32_224_in21k': 'vit_small_r26_s32_224.augreg_in21k', diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py new file mode 100644 index 00000000..6e0c28f0 --- /dev/null +++ b/timm/models/vitamin.py @@ -0,0 +1,603 @@ +""" ViTamin + +Paper: Designing Scalable Vison Models in the Vision-Language Era +A family of model weights on Huggingface: https://huggingface.co/collections/jienengchen/vitamin-family-661048126b72debdaca060bf + +@inproceedings{chen2024vitamin, + title={ViTamin: Designing Scalable Vision Models in the Vision-language Era}, + author={Chen, Jieneng and Yu, Qihang and Shen, Xiaohui and Yuille, Alan and Chen, Liang-Chieh}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + year={2024} +} + +Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin + +Modifications and timm support by Jieneng Chen 2024 + +Reference: +https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py +https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py +""" + +import math +from dataclasses import dataclass, field +from functools import partial +from typing import Optional, Union, Tuple + +import torch +import torch.nn as nn + +from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \ + make_divisible, DropPath +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model, generate_default_cfgs +from .vision_transformer import VisionTransformer, checkpoint_filter_fn +from .vision_transformer_hybrid import HybridEmbed + + +@dataclass +class VitConvCfg: + expand_ratio: float = 4.0 + expand_output: bool = True # calculate expansion channels from output (vs input chs) + kernel_size: int = 3 + group_size: int = 1 # 1 == depthwise + pre_norm_act: bool = False # activation after pre-norm + stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw' + pool_type: str = 'avg2' + downsample_pool_type: str = 'avg2' + act_layer: str = 'gelu' # stem & stage 1234 + norm_layer: str = '' + norm_eps: float = 1e-5 + down_shortcut: Optional[bool] = True + mlp: str = 'mlp' + + +@dataclass +class VitCfg: + embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768) + depths: Tuple[Union[int, Tuple[int, ...]], ...] 
= (2, 3, 5, 2) + stem_width: int = 64 + conv_cfg: VitConvCfg = field(default_factory=VitConvCfg) + head_type: str = "" + + +def _init_conv(module, name, scheme=''): + if isinstance(module, nn.Conv2d): + fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels + fan_out //= module.groups + nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +class Stem(nn.Module): + def __init__( + self, + in_chs: int, + out_chs: int, + act_layer: str = 'gelu', + norm_layer: str = 'layernorm2d', + norm_eps: float = 1e-6, + bias: bool = True, + ): + super().__init__() + norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) + self.out_chs = out_chs + + self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias) + self.norm1 = norm_act_layer(out_chs) + self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias) + + named_apply(_init_conv, self) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + return x + + +class Downsample2d(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + pool_type: str = 'avg2', + bias: bool = True, + ): + super().__init__() + self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) + + if dim != dim_out: + self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) # 1x1 conv + else: + self.expand = nn.Identity() + + def forward(self, x): + x = self.pool(x) # spatial downsample + x = self.expand(x) # expand chs + return x + + +class StridedConv(nn.Module): + """ downsample 2d as well + """ + def __init__( + self, + kernel_size=3, + stride=2, + padding=1, + in_chans=3, + embed_dim=768 + ): + super().__init__() + norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6) + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + self.norm = norm_layer(in_chans) # affine over C + + def forward(self, x): + x = self.norm(x) + x = self.proj(x) + return x + + +class MbConvLNBlock(nn.Module): + """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand) + """ + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + drop_path: float = 0., + kernel_size: int = 3, + norm_layer: str = 'layernorm2d', + norm_eps: float = 1e-6, + act_layer: str = 'gelu', + expand_ratio: float = 4.0, + ): + super(MbConvLNBlock, self).__init__() + self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs + mid_chs = make_divisible(out_chs * expand_ratio) + prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) + + if stride == 2: + self.shortcut = Downsample2d(in_chs, out_chs, pool_type='avg', bias=True) + elif in_chs != out_chs: + self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True) + else: + self.shortcut = nn.Identity() + + self.pre_norm = prenorm_act_layer(in_chs, apply_act=False) + self.down = nn.Identity() + self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True) + self.act1 = create_act_layer(act_layer, inplace=True) + self.conv2_kxk = create_conv2d( + mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True) + self.act2 = create_act_layer(act_layer, inplace=True) + self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + + def init_weights(self, scheme=''): + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x): + shortcut = self.shortcut(x) + + x = self.pre_norm(x) + x = self.down(x) # nn.Identity() + + # 1x1 expansion conv & act + x = self.conv1_1x1(x) + x = self.act1(x) + + # (strided) depthwise 3x3 conv & act + x = self.conv2_kxk(x) + x = self.act2(x) + + # 1x1 linear projection to output width + x = self.conv3_1x1(x) + x = self.drop_path(x) + shortcut + + return x + + +class MbConvStages(nn.Module): + """ MobileConv for stage 1 and stage 2 of ViTamin + """ + def __init__( + self, + cfg: VitCfg, + img_size: Union[int, Tuple[int, int]] = 224, # place holder + in_chans: int = 3, + ): + super().__init__() + self.grad_checkpointing = False + + self.stem = Stem( + in_chs=in_chans, + out_chs=cfg.stem_width, + ) + + stages = [] + self.num_stages = len(cfg.embed_dim) + for s, dim in enumerate(cfg.embed_dim[:2]): # stage + stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width + blocks = [ + MbConvLNBlock( + in_chs = stage_in_chs if d==0 else dim, + out_chs = dim, + stride = 2 if d == 0 else 1, + ) + for d in range(cfg.depths[s]) + ] + stages += [nn.Sequential(*blocks)] + self.stages = nn.Sequential(*stages) + + self.pool = StridedConv( + stride=2, + in_chans=cfg.embed_dim[1], + embed_dim=cfg.embed_dim[2] + ) + + def forward(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + x = self.pool(x) + return x + + +class GeGluMlp(nn.Module): + def __init__( + self, + in_features, + hidden_features, + act_layer = 'gelu', + drop = 0.0, + ): + super().__init__() + norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6) + + self.norm = norm_layer(in_features) + self.w0 = nn.Linear(in_features, hidden_features) + self.act = create_act_layer(act_layer) + self.w1 = nn.Linear(in_features, hidden_features) + self.w2 = nn.Linear(hidden_features, in_features) + + def forward(self, x): + x = self.norm(x) + x = self.act(self.w0(x)) * self.w1(x) + x = self.w2(x) + return x + + +def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs): + out_indices = kwargs.pop('out_indices', 3) + assert embed_cfg is not None + backbone = MbConvStages(cfg=embed_cfg) + kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False) + kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set + + return build_model_with_cfg( + VisionTransformer, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD, + 'first_conv': 'patch_embed.backbone.stem.conv1', + 'classifier': 'head', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'vitamin_small_224.datacomp1b_clip_ltt': _cfg( + hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=384), + 'vitamin_small_224.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-S', num_classes=384), + 'vitamin_base_224.datacomp1b_clip_ltt': _cfg( + hf_hub_id='jienengchen/ViTamin-B-LTT', num_classes=768), + 'vitamin_base_224.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-B', num_classes=768), + 'vitamin_large_224.datacomp1b_clip': _cfg( + 
hf_hub_id='jienengchen/ViTamin-L-224px', num_classes=768), + 'vitamin_large_256.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-256px', num_classes=768, + input_size=(3, 256, 256), crop_pct=1.0), + 'vitamin_large_336.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-336px', num_classes=768, + input_size=(3, 336, 336), crop_pct=1.0), + 'vitamin_large_384.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-384px', num_classes=768, + input_size=(3, 384, 384), crop_pct=1.0), + 'vitamin_large2_224.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L2-224px', num_classes=1024), + 'vitamin_large2_256.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L2-256px', num_classes=1024, + input_size=(3, 256, 256), crop_pct=1.0), + 'vitamin_large2_336.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L2-336px', num_classes=1024, + input_size=(3, 336, 336), crop_pct=1.0), + 'vitamin_large2_384.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L2-384px', num_classes=1024, + input_size=(3, 384, 384), crop_pct=1.0), + 'vitamin_xlarge_256.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-XL-256px', num_classes=1152, + input_size=(3, 256, 256), crop_pct=1.0), + 'vitamin_xlarge_336.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-XL-336px', num_classes=1152, + input_size=(3, 336, 336), crop_pct=1.0), + 'vitamin_xlarge_384.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-XL-384px', num_classes=1152, + input_size=(3, 384, 384), crop_pct=1.0), +}) + + +@register_model +def vitamin_small_224(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(64, 128, 384), + depths=(2, 4, 1), + stem_width=64, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg + ) + model = _create_vitamin('vitamin_small_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_base_224(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(128, 256, 768), + depths=(2, 4, 1), + stem_width=128, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_base_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_large_224(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg, + ) + model = _create_vitamin('vitamin_large_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, 
mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_large_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg + ) + model = _create_vitamin('vitamin_large_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_large_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_large2_224(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg, + ) + model = _create_vitamin('vitamin_large2_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_large2_256(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_large2_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_large2_336(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg + ) + model = _create_vitamin('vitamin_large2_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_large2_384(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_large2_384', pretrained=pretrained, 
**dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(192, 384, 1152), + depths=(2, 4, 1), + stem_width=192, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg) + model = _create_vitamin( + 'vitamin_xlarge_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(192, 384, 1152), + depths=(2, 4, 1), + stem_width=192, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_xlarge_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(192, 384, 1152), + depths=(2, 4, 1), + stem_width=192, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_xlarge_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model \ No newline at end of file
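
Usage note (illustrative, not part of the diff): the functions above register the ViTamin variants with timm's model registry, so once this file is merged they should be reachable through the standard timm.create_model entry point. Below is a minimal sketch under that assumption; the model name and the 384-dim output follow from the 'vitamin_small_224' default_cfgs entries (num_classes=384) above, and pretrained=True would instead pull the 'datacomp1b_clip' weights from the listed hf_hub_id.

    import timm
    import torch

    # Instantiate a registered ViTamin variant without downloading weights.
    model = timm.create_model('vitamin_small_224', pretrained=False)
    model.eval()

    # The CLIP-trained heads are embedding projections rather than 1000-class
    # classifiers, so the small model produces a 384-dim vector per image.
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 384])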