# # For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main # # Original work is copyright (C) 2023 Apple Inc. All Rights Reserved. # import copy import os from functools import partial from typing import List, Tuple, Optional, Union import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn from ._builder import build_model_with_cfg from ._registry import register_model, generate_default_cfgs def num_groups(group_size, channels): if not group_size: # 0 or None return 1 # normal conv with 1 group else: # NOTE group_size == 1 -> depthwise conv assert channels % group_size == 0 return channels // group_size class MobileOneBlock(nn.Module): """MobileOne building block. This block has a multi-branched architecture at train-time and plain-CNN style architecture at inference time For more details, please refer to our paper: `An Improved One millisecond Mobile Backbone` - https://arxiv.org/pdf/2206.04040.pdf """ def __init__( self, in_chs: int, out_chs: int, kernel_size: int, stride: int = 1, dilation: int = 1, group_size: int = 0, inference_mode: bool = False, use_se: bool = False, use_act: bool = True, use_scale_branch: bool = True, num_conv_branches: int = 1, act_layer: nn.Module = nn.GELU, ) -> None: """Construct a MobileOneBlock module. Args: in_chs: Number of channels in the input. out_chs: Number of channels produced by the block. kernel_size: Size of the convolution kernel. stride: Stride size. dilation: Kernel dilation factor. group_size: Convolution group size. inference_mode: If True, instantiates model in inference mode. use_se: Whether to use SE-ReLU activations. use_act: Whether to use activation. Default: ``True`` use_scale_branch: Whether to use scale branch. Default: ``True`` num_conv_branches: Number of linear conv branches. """ super(MobileOneBlock, self).__init__() self.inference_mode = inference_mode self.groups = num_groups(group_size, in_chs) self.stride = stride self.dilation = dilation self.kernel_size = kernel_size self.in_chs = in_chs self.out_chs = out_chs self.num_conv_branches = num_conv_branches # Check if SE-ReLU is requested self.se = SqueezeExcite(out_chs, rd_divisor=1) if use_se else nn.Identity() if inference_mode: self.reparam_conv = create_conv2d( in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=self.groups, bias=True, ) else: # Re-parameterizable skip connection self.reparam_conv = None self.rbr_skip = ( nn.BatchNorm2d(num_features=in_chs) if out_chs == in_chs and stride == 1 else None ) # Re-parameterizable conv branches if num_conv_branches > 0: self.rbr_conv = nn.ModuleList([ ConvNormAct( self.in_chs, self.out_chs, kernel_size=kernel_size, stride=self.stride, groups=self.groups, apply_act=False, ) for _ in range(self.num_conv_branches) ]) else: self.rbr_conv = None # Re-parameterizable scale branch self.rbr_scale = None if kernel_size > 1 and use_scale_branch: self.rbr_scale = ConvNormAct( self.in_chs, self.out_chs, kernel_size=1, stride=self.stride, groups=self.groups, apply_act=False ) self.act = act_layer() if use_act else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply forward pass.""" # Inference mode forward pass. if self.reparam_conv is not None: return self.act(self.se(self.reparam_conv(x))) # Multi-branched train-time forward pass. 
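        # The remaining branches are summed before the SE / activation stage:
        #   out = BN(x)                  (identity branch, only if out_chs == in_chs and stride == 1)
        #       + conv_1x1 + BN (x)      (scale branch, only if kernel_size > 1 and use_scale_branch)
        #       + sum_i conv_kxk + BN (x)  (num_conv_branches parallel conv branches)
        # reparameterize() later folds exactly this sum into the single reparam_conv used above.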
# Skip branch output identity_out = 0 if self.rbr_skip is not None: identity_out = self.rbr_skip(x) # Scale branch output scale_out = 0 if self.rbr_scale is not None: scale_out = self.rbr_scale(x) # Other branches out = scale_out + identity_out if self.rbr_conv is not None: for rc in self.rbr_conv: out += rc(x) return self.act(self.se(out)) def reparameterize(self): """Following works like `RepVGG: Making VGG-style ConvNets Great Again` - https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched architecture used at training time to obtain a plain CNN-like structure for inference. """ if self.reparam_conv is not None: return kernel, bias = self._get_kernel_bias() self.reparam_conv = create_conv2d( in_channels=self.in_chs, out_channels=self.out_chs, kernel_size=self.kernel_size, stride=self.stride, dilation=self.dilation, groups=self.groups, bias=True, ) self.reparam_conv.weight.data = kernel self.reparam_conv.bias.data = bias # Delete un-used branches for para in self.parameters(): para.detach_() self.__delattr__("rbr_conv") self.__delattr__("rbr_scale") if hasattr(self, "rbr_skip"): self.__delattr__("rbr_skip") self.inference_mode = True def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: """Method to obtain re-parameterized kernel and bias. Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 Returns: Tuple of (kernel, bias) after fusing branches. """ # get weights and bias of scale branch kernel_scale = 0 bias_scale = 0 if self.rbr_scale is not None: kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale) # Pad scale branch kernel to match conv branch kernel size. pad = self.kernel_size // 2 kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad]) # get weights and bias of skip branch kernel_identity = 0 bias_identity = 0 if self.rbr_skip is not None: kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip) # get weights and bias of conv branches kernel_conv = 0 bias_conv = 0 if self.rbr_conv is not None: for ix in range(self.num_conv_branches): _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix]) kernel_conv += _kernel bias_conv += _bias kernel_final = kernel_conv + kernel_scale + kernel_identity bias_final = bias_conv + bias_scale + bias_identity return kernel_final, bias_final def _fuse_bn_tensor( self, branch: Union[nn.Sequential, nn.BatchNorm2d] ) -> Tuple[torch.Tensor, torch.Tensor]: """Method to fuse batchnorm layer with preceeding conv layer. Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 Args: branch: Sequence of ops to be fused. Returns: Tuple of (kernel, bias) after fusing batchnorm. 
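
        The fusion is the standard conv+BN folding, applied per output channel::

            w_fused = w * gamma / sqrt(running_var + eps)
            b_fused = beta - running_mean * gamma / sqrt(running_var + eps)

        For the identity (BatchNorm-only) branch, ``w`` is an identity kernel built on the fly.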
""" if isinstance(branch, nn.Sequential): kernel = branch.conv.weight running_mean = branch.bn.running_mean running_var = branch.bn.running_var gamma = branch.bn.weight beta = branch.bn.bias eps = branch.bn.eps else: assert isinstance(branch, nn.BatchNorm2d) if not hasattr(self, "id_tensor"): input_dim = self.in_chs // self.groups kernel_value = torch.zeros( (self.in_chs, input_dim, self.kernel_size, self.kernel_size), dtype=branch.weight.dtype, device=branch.weight.device, ) for i in range(self.in_chs): kernel_value[ i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2 ] = 1 self.id_tensor = kernel_value kernel = self.id_tensor running_mean = branch.running_mean running_var = branch.running_var gamma = branch.weight beta = branch.bias eps = branch.eps std = (running_var + eps).sqrt() t = (gamma / std).reshape(-1, 1, 1, 1) return kernel * t, beta - running_mean * gamma / std class ReparamLargeKernelConv(nn.Module): """Building Block of RepLKNet This class defines overparameterized large kernel conv block introduced in `RepLKNet `_ Reference: https://github.com/DingXiaoH/RepLKNet-pytorch """ def __init__( self, in_chs: int, out_chs: int, kernel_size: int, stride: int, group_size: int, small_kernel: Optional[int] = None, inference_mode: bool = False, act_layer: Optional[nn.Module] = None, ) -> None: """Construct a ReparamLargeKernelConv module. Args: in_chs: Number of input channels. out_chs: Number of output channels. kernel_size: Kernel size of the large kernel conv branch. stride: Stride size. Default: 1 group_size: Group size. Default: 1 small_kernel: Kernel size of small kernel conv branch. inference_mode: If True, instantiates model in inference mode. Default: ``False`` act_layer: Activation module. Default: ``nn.GELU`` """ super(ReparamLargeKernelConv, self).__init__() self.stride = stride self.groups = num_groups(group_size, in_chs) self.in_chs = in_chs self.out_chs = out_chs self.kernel_size = kernel_size self.small_kernel = small_kernel if inference_mode: self.lkb_reparam = create_conv2d( in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=1, groups=self.groups, bias=True, ) else: self.lkb_reparam = None self.lkb_origin = ConvNormAct( in_chs, out_chs, kernel_size=kernel_size, stride=self.stride, groups=self.groups, apply_act=False, ) if small_kernel is not None: assert ( small_kernel <= kernel_size ), "The kernel size for re-param cannot be larger than the large kernel!" self.small_conv = ConvNormAct( in_chs, out_chs, kernel_size=small_kernel, stride=self.stride, groups=self.groups, apply_act=False, ) # FIXME output of this act was not used in original impl, likely due to bug self.act = act_layer() if act_layer is not None else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: if self.lkb_reparam is not None: out = self.lkb_reparam(x) else: out = self.lkb_origin(x) if self.small_conv is not None: out = out + self.small_conv(x) out = self.act(out) return out def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: """Method to obtain re-parameterized kernel and bias. Reference: https://github.com/DingXiaoH/RepLKNet-pytorch Returns: Tuple of (kernel, bias) after fusing branches. 
""" eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) if hasattr(self, "small_conv"): small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn) eq_b += small_b eq_k += nn.functional.pad( small_k, [(self.kernel_size - self.small_kernel) // 2] * 4 ) return eq_k, eq_b def reparameterize(self) -> None: """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` - https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched architecture used at training time to obtain a plain CNN-like structure for inference. """ eq_k, eq_b = self.get_kernel_bias() self.lkb_reparam = create_conv2d( self.in_chs, self.out_chs, kernel_size=self.kernel_size, stride=self.stride, dilation=self.lkb_origin.conv.dilation, groups=self.groups, bias=True, ) self.lkb_reparam.weight.data = eq_k self.lkb_reparam.bias.data = eq_b self.__delattr__("lkb_origin") if hasattr(self, "small_conv"): self.__delattr__("small_conv") @staticmethod def _fuse_bn( conv: torch.Tensor, bn: nn.BatchNorm2d ) -> Tuple[torch.Tensor, torch.Tensor]: """Method to fuse batchnorm layer with conv layer. Args: conv: Convolutional kernel weights. bn: Batchnorm 2d layer. Returns: Tuple of (kernel, bias) after fusing batchnorm. """ kernel = conv.weight running_mean = bn.running_mean running_var = bn.running_var gamma = bn.weight beta = bn.bias eps = bn.eps std = (running_var + eps).sqrt() t = (gamma / std).reshape(-1, 1, 1, 1) return kernel * t, beta - running_mean * gamma / std def convolutional_stem( in_chs: int, out_chs: int, inference_mode: bool = False ) -> nn.Sequential: """Build convolutional stem with MobileOne blocks. Args: in_chs: Number of input channels. out_chs: Number of output channels. inference_mode: Flag to instantiate model in inference mode. Default: ``False`` Returns: nn.Sequential object with stem elements. """ return nn.Sequential( MobileOneBlock( in_chs=in_chs, out_chs=out_chs, kernel_size=3, stride=2, inference_mode=inference_mode, ), MobileOneBlock( in_chs=out_chs, out_chs=out_chs, kernel_size=3, stride=2, group_size=1, inference_mode=inference_mode, ), MobileOneBlock( in_chs=out_chs, out_chs=out_chs, kernel_size=1, stride=1, inference_mode=inference_mode, ), ) class Attention(nn.Module): """Multi-headed Self Attention module. Source modified from: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ fused_attn: torch.jit.Final[bool] def __init__( self, dim: int, head_dim: int = 32, qkv_bias: bool = False, attn_drop: float = 0.0, proj_drop: float = 0.0, ) -> None: """Build MHSA module that can handle 3D or 4D input tensors. Args: dim: Number of embedding dimensions. head_dim: Number of hidden dimensions per head. Default: ``32`` qkv_bias: Use bias or not. Default: ``False`` attn_drop: Dropout rate for attention tensor. proj_drop: Dropout rate for projection tensor. 
""" super().__init__() assert dim % head_dim == 0, "dim should be divisible by head_dim" self.head_dim = head_dim self.num_heads = dim // head_dim self.scale = head_dim ** -0.5 self.fused_attn = use_fused_attn() self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor) -> torch.Tensor: B, C, H, W = x.shape N = H * W x = x.flatten(2).transpose(-2, -1) # (B, N, C) qkv = ( self.qkv(x) .reshape(B, N, 3, self.num_heads, self.head_dim) .permute(2, 0, 3, 1, 4) ) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) if self.fused_attn: x = torch.nn.functional.scaled_dot_product_attention( q, k, v, dropout_p=self.attn_drop.p, ) else: q = q * self.scale attn = q @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = attn @ v x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) x = x.transpose(-2, -1).reshape(B, C, H, W) return x class PatchEmbed(nn.Module): """Convolutional patch embedding layer.""" def __init__( self, patch_size: int, stride: int, in_chs: int, embed_dim: int, inference_mode: bool = False, ) -> None: """Build patch embedding layer. Args: patch_size: Patch size for embedding computation. stride: Stride for convolutional embedding layer. in_chs: Number of channels of input tensor. embed_dim: Number of embedding dimensions. inference_mode: Flag to instantiate model in inference mode. Default: ``False`` """ super().__init__() self.proj = nn.Sequential( ReparamLargeKernelConv( in_chs=in_chs, out_chs=embed_dim, kernel_size=patch_size, stride=stride, group_size=1, small_kernel=3, inference_mode=inference_mode, act_layer=None, # activation was not used in original impl ), MobileOneBlock( in_chs=embed_dim, out_chs=embed_dim, kernel_size=1, stride=1, inference_mode=inference_mode, ) ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) return x class RepMixer(nn.Module): """Reparameterizable token mixer. For more details, please refer to our paper: `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization `_ """ def __init__( self, dim, kernel_size=3, use_layer_scale=True, layer_scale_init_value=1e-5, inference_mode: bool = False, ): """Build RepMixer Module. Args: dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`. kernel_size: Kernel size for spatial mixing. Default: 3 use_layer_scale: If True, learnable layer scale is used. Default: ``True`` layer_scale_init_value: Initial value for layer scale. Default: 1e-5 inference_mode: If True, instantiates model in inference mode. 
Default: ``False`` """ super().__init__() self.dim = dim self.kernel_size = kernel_size self.inference_mode = inference_mode if inference_mode: self.reparam_conv = nn.Conv2d( self.dim, self.dim, kernel_size=self.kernel_size, stride=1, padding=self.kernel_size // 2, groups=self.dim, bias=True, ) else: self.reparam_conv = None self.norm = MobileOneBlock( dim, dim, kernel_size, group_size=1, use_act=False, use_scale_branch=False, num_conv_branches=0, ) self.mixer = MobileOneBlock( dim, dim, kernel_size, group_size=1, use_act=False, ) self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1))) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.reparam_conv is not None: x = self.reparam_conv(x) return x else: if self.use_layer_scale: x = x + self.layer_scale * (self.mixer(x) - self.norm(x)) else: x = x + self.mixer(x) - self.norm(x) return x def reparameterize(self) -> None: """Reparameterize mixer and norm into a single convolutional layer for efficient inference. """ if self.inference_mode: return self.mixer.reparameterize() self.norm.reparameterize() if self.use_layer_scale: w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * ( self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight ) b = torch.squeeze(self.layer_scale) * ( self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias ) else: w = ( self.mixer.id_tensor + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight ) b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias self.reparam_conv = create_conv2d( self.dim, self.dim, kernel_size=self.kernel_size, stride=1, groups=self.dim, bias=True, ) self.reparam_conv.weight.data = w self.reparam_conv.bias.data = b for para in self.parameters(): para.detach_() self.__delattr__("mixer") self.__delattr__("norm") if self.use_layer_scale: self.__delattr__("layer_scale") class ConvMlp(nn.Module): """Convolutional FFN Module.""" def __init__( self, in_chs: int, hidden_channels: Optional[int] = None, out_chs: Optional[int] = None, act_layer: nn.Module = nn.GELU, drop: float = 0.0, ) -> None: """Build convolutional FFN module. Args: in_chs: Number of input channels. hidden_channels: Number of channels after expansion. Default: None out_chs: Number of output channels. Default: None act_layer: Activation layer. Default: ``GELU`` drop: Dropout rate. Default: ``0.0``. """ super().__init__() out_chs = out_chs or in_chs hidden_channels = hidden_channels or in_chs # self.conv = nn.Sequential() # self.conv.add_module( # "conv", # nn.Conv2d( # in_chs, # out_chs, # kernel_size=7, # padding=3, # groups=in_chs, # bias=False, # ), # ) # self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs)) self.conv = ConvNormAct( in_chs, out_chs, kernel_size=7, groups=in_chs, apply_act=False, ) self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1) self.act = act_layer() self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1) self.drop = nn.Dropout(drop) self.apply(self._init_weights) def _init_weights(self, m: nn.Module) -> None: if isinstance(m, nn.Conv2d): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class RepCPE(nn.Module): """Implementation of conditional positional encoding. 
For more details refer to paper: `Conditional Positional Encodings for Vision Transformers `_ In our implementation, we can reparameterize this module to eliminate a skip connection. """ def __init__( self, dim: int, dim_out: Optional[int] = None, spatial_shape: Union[int, Tuple[int, int]] = (7, 7), inference_mode=False, ) -> None: """Build reparameterizable conditional positional encoding Args: dim: Number of input channels. dim_out: Number of embedding dimensions. Default: 768 spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) inference_mode: Flag to instantiate block in inference mode. Default: ``False`` """ super(RepCPE, self).__init__() if isinstance(spatial_shape, int): spatial_shape = tuple([spatial_shape] * 2) assert isinstance(spatial_shape, Tuple), ( f'"spatial_shape" must by a sequence or int, ' f"get {type(spatial_shape)} instead." ) assert len(spatial_shape) == 2, ( f'Length of "spatial_shape" should be 2, ' f"got {len(spatial_shape)} instead." ) self.spatial_shape = spatial_shape self.dim = dim self.dim_out = dim_out or dim self.groups = dim if inference_mode: self.reparam_conv = nn.Conv2d( self.dim, self.dim_out, kernel_size=self.spatial_shape, stride=1, padding=spatial_shape[0] // 2, groups=self.groups, bias=True, ) else: self.reparam_conv = None self.pe = nn.Conv2d( self.dim, self.dim_out, spatial_shape, 1, int(spatial_shape[0] // 2), groups=self.groups, bias=True, ) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.reparam_conv is not None: x = self.reparam_conv(x) return x else: x = self.pe(x) + x return x def reparameterize(self) -> None: # Build equivalent Id tensor input_dim = self.dim // self.groups kernel_value = torch.zeros( ( self.dim, input_dim, self.spatial_shape[0], self.spatial_shape[1], ), dtype=self.pe.weight.dtype, device=self.pe.weight.device, ) for i in range(self.dim): kernel_value[ i, i % input_dim, self.spatial_shape[0] // 2, self.spatial_shape[1] // 2, ] = 1 id_tensor = kernel_value # Reparameterize Id tensor and conv w_final = id_tensor + self.pe.weight b_final = self.pe.bias # Introduce reparam conv self.reparam_conv = nn.Conv2d( self.dim, self.dim_out, kernel_size=self.spatial_shape, stride=1, padding=int(self.spatial_shape[0] // 2), groups=self.groups, bias=True, ) self.reparam_conv.weight.data = w_final self.reparam_conv.bias.data = b_final for para in self.parameters(): para.detach_() self.__delattr__("pe") class RepMixerBlock(nn.Module): """Implementation of Metaformer block with RepMixer as token mixer. For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision `_ """ def __init__( self, dim: int, kernel_size: int = 3, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, drop: float = 0.0, drop_path: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, inference_mode: bool = False, ): """Build RepMixer Block. Args: dim: Number of embedding dimensions. kernel_size: Kernel size for repmixer. Default: 3 mlp_ratio: MLP expansion ratio. Default: 4.0 act_layer: Activation layer. Default: ``nn.GELU`` drop: Dropout rate. Default: 0.0 drop_path: Drop path rate. Default: 0.0 use_layer_scale: Flag to turn on layer scale. Default: ``True`` layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 inference_mode: Flag to instantiate block in inference mode. 
Default: ``False`` """ super().__init__() self.token_mixer = RepMixer( dim, kernel_size=kernel_size, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, ) assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( mlp_ratio ) self.convffn = ConvMlp( in_chs=dim, hidden_channels=int(dim * mlp_ratio), act_layer=act_layer, drop=drop, ) # Drop Path self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() # Layer Scale self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1))) def forward(self, x): if self.use_layer_scale: x = self.token_mixer(x) x = x + self.drop_path(self.layer_scale * self.convffn(x)) else: x = self.token_mixer(x) x = x + self.drop_path(self.convffn(x)) return x class AttentionBlock(nn.Module): """Implementation of metaformer block with MHSA as token mixer. For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision `_ """ def __init__( self, dim: int, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.BatchNorm2d, drop: float = 0.0, drop_path: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, ): """Build Attention Block. Args: dim: Number of embedding dimensions. mlp_ratio: MLP expansion ratio. Default: 4.0 act_layer: Activation layer. Default: ``nn.GELU`` norm_layer: Normalization layer. Default: ``nn.BatchNorm2d`` drop: Dropout rate. Default: 0.0 drop_path: Drop path rate. Default: 0.0 use_layer_scale: Flag to turn on layer scale. Default: ``True`` layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 """ super().__init__() self.norm = norm_layer(dim) self.token_mixer = Attention(dim=dim) assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( mlp_ratio ) mlp_hidden_dim = int(dim * mlp_ratio) self.convffn = ConvMlp( in_chs=dim, hidden_channels=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) # Drop path self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() # Layer Scale self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1))) self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1))) def forward(self, x): if self.use_layer_scale: x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x))) x = x + self.drop_path(self.layer_scale_2 * self.convffn(x)) else: x = x + self.drop_path(self.token_mixer(self.norm(x))) x = x + self.drop_path(self.convffn(x)) return x class FastVitStage(nn.Module): def __init__( self, dim: int, dim_out: int, depth: int, token_mixer_type: str, downsample: bool = True, down_patch_size: int = 7, down_stride: int = 2, pos_emb_layer: Optional[nn.Module] = None, kernel_size: int = 3, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.BatchNorm2d, drop_rate: float = 0.0, drop_path_rate: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, inference_mode=False, ): """FastViT stage. Args: dim: Number of embedding dimensions. num_blocks: List containing number of blocks per stage. token_mixer_type: Token mixer type. kernel_size: Kernel size for repmixer. mlp_ratio: MLP expansion ratio. act_layer: Activation layer. norm_layer: Normalization layer. drop_rate: Dropout rate. drop_path_rate: Drop path rate. 
use_layer_scale: Flag to turn on layer scale regularization. layer_scale_init_value: Layer scale value at initialization. inference_mode: Flag to instantiate block in inference mode. """ super().__init__() if downsample: self.downsample = PatchEmbed( patch_size=down_patch_size, stride=down_stride, in_chs=dim, embed_dim=dim_out, inference_mode=inference_mode, ) else: assert dim == dim_out self.downsample = nn.Identity() if pos_emb_layer is not None: self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode) else: self.pos_emb = nn.Identity() blocks = [] for block_idx in range(depth): if token_mixer_type == "repmixer": blocks.append(RepMixerBlock( dim_out, kernel_size=kernel_size, mlp_ratio=mlp_ratio, act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate[block_idx], use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, )) elif token_mixer_type == "attention": blocks.append(AttentionBlock( dim_out, mlp_ratio=mlp_ratio, act_layer=act_layer, norm_layer=norm_layer, drop=drop_rate, drop_path=drop_path_rate[block_idx], use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, )) else: raise ValueError( "Token mixer type: {} not supported".format(token_mixer_type) ) self.blocks = nn.Sequential(*blocks) def forward(self, x): x = self.downsample(x) x = self.pos_emb(x) x = self.blocks(x) return x class FastVit(nn.Module): fork_feat: torch.jit.Final[bool] """ This class implements `FastViT architecture `_ """ def __init__( self, in_chans: int = 3, layers: Tuple[int, ...] = (2, 2, 6, 2), token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"), embed_dims: Tuple[int, ...] = (64, 128, 256, 512), mlp_ratios: Tuple[float, ...] = (4,) * 4, downsamples: Tuple[bool, ...] = (False, True, True, True), repmixer_kernel_size: int = 3, num_classes: int = 1000, pos_embs: Tuple[Optional[nn.Module], ...] 
= (None,) * 4, down_patch_size: int = 7, down_stride: int = 2, drop_rate: float = 0.0, drop_path_rate: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, fork_feat: bool = False, cls_ratio: float = 2.0, norm_layer: nn.Module = nn.BatchNorm2d, act_layer: nn.Module = nn.GELU, inference_mode: bool = False, ) -> None: super().__init__() self.num_classes = 0 if fork_feat else num_classes self.fork_feat = fork_feat # Convolutional stem self.patch_embed = convolutional_stem( in_chans, embed_dims[0], inference_mode, ) # Build the main stages of the network architecture prev_dim = embed_dims[0] dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)] network = [] for i in range(len(layers)): stage = FastVitStage( dim=prev_dim, dim_out=embed_dims[i], depth=layers[i], downsample=downsamples[i] or prev_dim != embed_dims[i], down_patch_size=down_patch_size, down_stride=down_stride, pos_emb_layer=pos_embs[i], token_mixer_type=token_mixers[i], kernel_size=repmixer_kernel_size, mlp_ratio=mlp_ratios[i], act_layer=act_layer, norm_layer=norm_layer, drop_rate=drop_rate, drop_path_rate=dpr[i], use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, ) network.append(stage) prev_dim = embed_dims[i] self.network = nn.Sequential(*network) # For segmentation and detection, extract intermediate output if self.fork_feat: # add a norm layer for each output self.out_indices = [0, 2, 4, 6] for i_emb, i_layer in enumerate(self.out_indices): if i_emb == 0 and os.environ.get("FORK_LAST3", None): """For RetinaNet, `start_level=1`. The first norm layer will not used. cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...` """ layer = nn.Identity() else: layer = norm_layer(embed_dims[i_emb]) layer_name = f"norm{i_layer}" self.add_module(layer_name, layer) else: # Classifier head self.gap = nn.AdaptiveAvgPool2d(output_size=1) self.conv_exp = MobileOneBlock( in_chs=embed_dims[-1], out_chs=int(embed_dims[-1] * cls_ratio), kernel_size=3, stride=1, group_size=1, inference_mode=inference_mode, use_se=True, num_conv_branches=1, ) self.head = ( nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes) if num_classes > 0 else nn.Identity() ) self.apply(self._init_weights) def _init_weights(self, m: nn.Module) -> None: """Init. 
for classification""" if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) @staticmethod def _scrub_checkpoint(checkpoint, model): sterile_dict = {} for k1, v1 in checkpoint.items(): if k1 not in model.state_dict(): continue if v1.shape == model.state_dict()[k1].shape: sterile_dict[k1] = v1 return sterile_dict def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) return x def forward_tokens(self, x: torch.Tensor) -> torch.Tensor: outs = [] for idx, block in enumerate(self.network): x = block(x) if self.fork_feat: if idx in self.out_indices: norm_layer = getattr(self, f"norm{idx}") x_out = norm_layer(x) outs.append(x_out) if self.fork_feat: # output the features of four stages for dense prediction return outs # output only the features of last layer for image classification return x def forward(self, x: torch.Tensor) -> torch.Tensor: # input embedding x = self.forward_embeddings(x) # through backbone x = self.forward_tokens(x) if self.fork_feat: # output features of four stages for dense prediction return x # for image classification x = self.conv_exp(x) x = self.gap(x) x = x.view(x.size(0), -1) cls_out = self.head(x) return cls_out def _cfg(url="", **kwargs): return { "url": url, "num_classes": 1000, "input_size": (3, 256, 256), "pool_size": None, "crop_pct": 0.9, "interpolation": "bicubic", "mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD, "classifier": "head", **kwargs, } default_cfgs = generate_default_cfgs({ "fastvit_t8.apple_in1k": _cfg( url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar' ), "fastvit_t12.apple_in1k": _cfg( url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t12.pth.tar' ), "fastvit_s12.apple_in1k": _cfg( url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_s12.pth.tar'), "fastvit_sa12.apple_in1k": _cfg( url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa12.pth.tar'), "fastvit_sa24.apple_in1k": _cfg( url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa24.pth.tar'), "fastvit_sa36.apple_in1k": _cfg( url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa36.pth.tar'), "fastvit_ma36.apple_in1k": _cfg( url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_ma36.pth.tar', crop_pct=0.95 ), # "fastvit_t8.apple_dist_in1k": _cfg( # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t8.pth.tar' # ), # "fastvit_t12.apple_dist_in1k": _cfg( # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t12.pth.tar' # ), # # "fastvit_s12.apple_dist_in1k": _cfg( # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_s12.pth.tar'), # "fastvit_sa12.apple_dist_in1k": _cfg( # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa12.pth.tar'), # "fastvit_sa24.apple_dist_in1k": _cfg( # 
url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa24.pth.tar'), # "fastvit_sa36.apple_dist_in1k": _cfg( # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa36.pth.tar'), # # "fastvit_ma36.apple_dist_in1k": _cfg( # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_ma36.pth.tar', # crop_pct=0.95 # ), }) def checkpoint_filter_fn(state_dict, model): # FIXME temporary for remapping state_dict = state_dict.get('state_dict', state_dict) msd = model.state_dict() out_dict = {} for ka, kb, va, vb in zip(msd.keys(), state_dict.keys(), msd.values(), state_dict.values()): if va.ndim == 4 and vb.ndim == 2: vb = vb[:, :, None, None] out_dict[ka] = vb return out_dict def _create_fastvit(variant, pretrained=False, **kwargs): out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) model = build_model_with_cfg( FastVit, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs ) return model @register_model def fastvit_t8(pretrained=False, **kwargs): """Instantiate FastViT-T8 model variant.""" model_args = dict( layers=(2, 2, 4, 2), embed_dims=(48, 96, 192, 384), mlp_ratios=(3, 3, 3, 3), token_mixers=("repmixer", "repmixer", "repmixer", "repmixer") ) return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def fastvit_t12(pretrained=False, **kwargs): """Instantiate FastViT-T12 model variant.""" model_args = dict( layers=(2, 2, 6, 2), embed_dims=(64, 128, 256, 512), mlp_ratios=(3, 3, 3, 3), token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"), ) return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def fastvit_s12(pretrained=False, **kwargs): """Instantiate FastViT-S12 model variant.""" model_args = dict( layers=(2, 2, 6, 2), embed_dims=(64, 128, 256, 512), mlp_ratios=(4, 4, 4, 4), token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"), ) return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def fastvit_sa12(pretrained=False, **kwargs): """Instantiate FastViT-SA12 model variant.""" model_args = dict( layers=(2, 2, 6, 2), embed_dims=(64, 128, 256, 512), mlp_ratios=(4, 4, 4, 4), pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))), token_mixers=("repmixer", "repmixer", "repmixer", "attention"), ) return _create_fastvit('fastvit_sa12', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def fastvit_sa24(pretrained=False, **kwargs): """Instantiate FastViT-SA24 model variant.""" model_args = dict( layers=(4, 4, 12, 4), embed_dims=(64, 128, 256, 512), mlp_ratios=(4, 4, 4, 4), pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))), token_mixers=("repmixer", "repmixer", "repmixer", "attention"), ) return _create_fastvit('fastvit_sa24', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def fastvit_sa36(pretrained=False, **kwargs): """Instantiate FastViT-SA36 model variant.""" model_args = dict( layers=(6, 6, 18, 6), embed_dims=(64, 128, 256, 512), mlp_ratios=(4, 4, 4, 4), pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))), token_mixers=("repmixer", "repmixer", "repmixer", "attention"), ) return _create_fastvit('fastvit_sa36', pretrained=pretrained, **dict(model_args, **kwargs)) 
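
# Usage sketch (illustrative): the fastvit_* functions in this module are ordinary timm
# registrations, so models are normally built through the factory rather than by calling
# them directly:
#
#   import timm
#   model = timm.create_model('fastvit_sa12', pretrained=False, num_classes=10)
#   backbone = timm.create_model('fastvit_sa12', features_only=True, out_indices=(0, 1, 2, 3))
#
# ``features_only=True`` is available because _create_fastvit passes ``flatten_sequential``
# and ``out_indices`` to build_model_with_cfg via ``feature_cfg``.
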
@register_model
def fastvit_ma36(pretrained=False, **kwargs):
    """Instantiate FastViT-MA36 model variant."""
    model_args = dict(
        layers=(6, 6, 18, 6),
        embed_dims=(76, 152, 304, 608),
        mlp_ratios=(4, 4, 4, 4),
        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
    )
    return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs))
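

# A minimal deployment sketch. `_fuse_for_deployment` is a local helper, not part of the
# original module: every re-parameterizable block above (MobileOneBlock, RepMixer, RepCPE,
# ReparamLargeKernelConv) exposes a `reparameterize()` method, so a trained FastViT can be
# collapsed into its plain inference-time form by walking the module tree. Recent timm
# releases ship a generic helper of the same shape (`reparameterize_model` in `timm.utils`).
def _fuse_for_deployment(model: nn.Module) -> nn.Module:
    """Return a deep copy of ``model`` with all multi-branch blocks fused for inference.

    The model should be in ``eval()`` mode so the fused weights use the final BN
    running statistics. Example (illustrative)::

        model = fastvit_t8(pretrained=False).eval()
        fused = _fuse_for_deployment(model)
        x = torch.randn(1, 3, 256, 256)
        with torch.no_grad():
            # difference should be down at BN-fusion rounding error
            print((model(x) - fused(x)).abs().max())
    """
    model = copy.deepcopy(model)
    for module in model.modules():
        if hasattr(module, "reparameterize"):
            module.reparameterize()
    return model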