From c7a20cec1387b11c38653788df3816f90abe0315 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 21 Aug 2023 13:03:54 -0700 Subject: [PATCH] Begin adding FastViT --- timm/models/fastvit.py | 1221 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1221 insertions(+) create mode 100644 timm/models/fastvit.py diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py new file mode 100644 index 00000000..f35b8a86 --- /dev/null +++ b/timm/models/fastvit.py @@ -0,0 +1,1221 @@ +# +# For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main +# +# Original work is copyright (C) 2023 Apple Inc. All Rights Reserved. +# +import copy +import os +from functools import partial +from typing import List, Tuple, Optional, Union + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import DropPath, trunc_normal_ +from ._registry import register_model +from .byobnet import MobileOneBlock + + +class ReparamLargeKernelConv(nn.Module): + """Building Block of RepLKNet + + This class defines an over-parameterized large kernel conv block + introduced in `RepLKNet `_ + + Reference: https://github.com/DingXiaoH/RepLKNet-pytorch + """ + + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int, + groups: int, + small_kernel: int, + inference_mode: bool = False, + act_layer: nn.Module = nn.GELU(), + ) -> None: + """Construct a ReparamLargeKernelConv module. + + Args: + in_chs: Number of input channels. + out_chs: Number of output channels. + kernel_size: Kernel size of the large kernel conv branch. + stride: Stride size. Default: 1 + groups: Group number. Default: 1 + small_kernel: Kernel size of small kernel conv branch. + inference_mode: If True, instantiates model in inference mode. Default: ``False`` + act_layer: Activation module. Default: ``nn.GELU`` + """ + super(ReparamLargeKernelConv, self).__init__() + + self.stride = stride + self.groups = groups + self.in_chs = in_chs + self.out_chs = out_chs + self.act_layer = act_layer + + self.kernel_size = kernel_size + self.small_kernel = small_kernel + self.padding = kernel_size // 2 + if inference_mode: + self.lkb_reparam = nn.Conv2d( + in_channels=in_chs, + out_channels=out_chs, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + dilation=1, + groups=groups, + bias=True, + ) + else: + self.lkb_origin = self._conv_bn( + kernel_size=kernel_size, padding=self.padding + ) + if small_kernel is not None: + assert ( + small_kernel <= kernel_size + ), "The kernel size for re-param cannot be larger than the large kernel!" + self.small_conv = self._conv_bn( + kernel_size=small_kernel, padding=small_kernel // 2 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply forward pass.""" + if hasattr(self, "lkb_reparam"): + out = self.lkb_reparam(x) + else: + out = self.lkb_origin(x) + if hasattr(self, "small_conv"): + out += self.small_conv(x) + + out = self.act_layer(out) + return out + + def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepLKNet-pytorch + + Returns: + Tuple of (kernel, bias) after fusing branches.
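+            Each branch's BatchNorm is folded into its conv weights first; for example, with kernel_size=7 and small_kernel=3 the folded 3x3 kernel is zero-padded by (7 - 3) // 2 = 2 on each side before being added to the 7x7 kernel.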
+ """ + eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) + if hasattr(self, "small_conv"): + small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn) + eq_b += small_b + eq_k += nn.functional.pad( + small_k, [(self.kernel_size - self.small_kernel) // 2] * 4 + ) + return eq_k, eq_b + + def reparameterize(self) -> None: + """ + Following works like `RepVGG: Making VGG-style ConvNets Great Again` - + https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched + architecture used at training time to obtain a plain CNN-like structure + for inference. + """ + eq_k, eq_b = self.get_kernel_bias() + self.lkb_reparam = nn.Conv2d( + in_chs=self.in_chs, + out_chs=self.out_chs, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.lkb_origin.conv.dilation, + groups=self.groups, + bias=True, + ) + + self.lkb_reparam.weight.data = eq_k + self.lkb_reparam.bias.data = eq_b + self.__delattr__("lkb_origin") + if hasattr(self, "small_conv"): + self.__delattr__("small_conv") + + @staticmethod + def _fuse_bn( + conv: torch.Tensor, bn: nn.BatchNorm2d + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to fuse batchnorm layer with conv layer. + + Args: + conv: Convolutional kernel weights. + bn: Batchnorm 2d layer. + + Returns: + Tuple of (kernel, bias) after fusing batchnorm. + """ + kernel = conv.weight + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential: + """Helper method to construct conv-batchnorm layers. + + Args: + kernel_size: Size of the convolution kernel. + padding: Zero-padding size. + + Returns: + A nn.Sequential Conv-BN module. + """ + mod_list = nn.Sequential() + mod_list.add_module( + "conv", + nn.Conv2d( + in_chs=self.in_chs, + out_chs=self.out_chs, + kernel_size=kernel_size, + stride=self.stride, + padding=padding, + groups=self.groups, + bias=False, + ), + ) + mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_chs)) + return mod_list + + +def convolutional_stem( + in_chs: int, out_chs: int, inference_mode: bool = False +) -> nn.Sequential: + """Build convolutional stem with MobileOne blocks. + + Args: + in_chs: Number of input channels. + out_chs: Number of output channels. + inference_mode: Flag to instantiate model in inference mode. Default: ``False`` + + Returns: + nn.Sequential object with stem elements. + """ + return nn.Sequential( + MobileOneBlock( + in_chs=in_chs, + out_chs=out_chs, + kernel_size=3, + stride=2, + inference_mode=inference_mode, + use_se=False, + num_conv_branches=1, + ), + MobileOneBlock( + in_chs=out_chs, + out_chs=out_chs, + kernel_size=3, + stride=2, + group_size=1, + inference_mode=inference_mode, + use_se=False, + num_conv_branches=1, + ), + MobileOneBlock( + in_chs=out_chs, + out_chs=out_chs, + kernel_size=1, + stride=1, + inference_mode=inference_mode, + use_se=False, + num_conv_branches=1, + ), + ) + + +class Attention(nn.Module): + """Multi-headed Self Attention module. 
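+    Operates on 4D (B, C, H, W) feature maps: spatial dims are flattened to (B, N, C) with N = H * W, standard scaled dot-product attention is applied, and the result is reshaped back to (B, C, H, W).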
+ + Source modified from: + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + """ + + def __init__( + self, + dim: int, + head_dim: int = 32, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """Build MHSA module that can handle 3D or 4D input tensors. + + Args: + dim: Number of embedding dimensions. + head_dim: Number of hidden dimensions per head. Default: ``32`` + qkv_bias: Use bias or not. Default: ``False`` + attn_drop: Dropout rate for attention tensor. + proj_drop: Dropout rate for projection tensor. + """ + super().__init__() + assert dim % head_dim == 0, "dim should be divisible by head_dim" + self.head_dim = head_dim + self.num_heads = dim // head_dim + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shape = x.shape + B, C, H, W = shape + N = H * W + if len(shape) == 4: + x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C) + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + # trick here to make q@k.t more stable + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + if len(shape) == 4: + x = x.transpose(-2, -1).reshape(B, C, H, W) + + return x + + +class PatchEmbed(nn.Module): + """Convolutional patch embedding layer.""" + + def __init__( + self, + patch_size: int, + stride: int, + in_chs: int, + embed_dim: int, + inference_mode: bool = False, + ) -> None: + """Build patch embedding layer. + + Args: + patch_size: Patch size for embedding computation. + stride: Stride for convolutional embedding layer. + in_chs: Number of channels of input tensor. + embed_dim: Number of embedding dimensions. + inference_mode: Flag to instantiate model in inference mode. Default: ``False`` + """ + super().__init__() + block = list() + block.append( + ReparamLargeKernelConv( + in_chs=in_chs, + out_chs=embed_dim, + kernel_size=patch_size, + stride=stride, + groups=in_chs, + small_kernel=3, + inference_mode=inference_mode, + ) + ) + block.append( + MobileOneBlock( + in_chs=embed_dim, + out_chs=embed_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1, + inference_mode=inference_mode, + use_se=False, + num_conv_branches=1, + ) + ) + self.proj = nn.Sequential(*block) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x + + +class RepMixer(nn.Module): + """Reparameterizable token mixer. + + For more details, please refer to our paper: + `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization `_ + """ + + def __init__( + self, + dim, + kernel_size=3, + use_layer_scale=True, + layer_scale_init_value=1e-5, + inference_mode: bool = False, + ): + """Build RepMixer Module. + + Args: + dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`. + kernel_size: Kernel size for spatial mixing. Default: 3 + use_layer_scale: If True, learnable layer scale is used. Default: ``True`` + layer_scale_init_value: Initial value for layer scale. Default: 1e-5 + inference_mode: If True, instantiates model in inference mode. 
Default: ``False`` + """ + super().__init__() + self.dim = dim + self.kernel_size = kernel_size + self.inference_mode = inference_mode + + if inference_mode: + self.reparam_conv = nn.Conv2d( + in_chs=self.dim, + out_chs=self.dim, + kernel_size=self.kernel_size, + stride=1, + padding=self.kernel_size // 2, + groups=self.dim, + bias=True, + ) + else: + self.norm = MobileOneBlock( + dim, + dim, + kernel_size, + group_size=1, + use_act=False, + use_scale_branch=False, + num_conv_branches=0, + ) + self.mixer = MobileOneBlock( + dim, + dim, + kernel_size, + group_size=1, + use_act=False, + ) + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale = nn.Parameter( + layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "reparam_conv"): + x = self.reparam_conv(x) + return x + else: + if self.use_layer_scale: + x = x + self.layer_scale * (self.mixer(x) - self.norm(x)) + else: + x = x + self.mixer(x) - self.norm(x) + return x + + def reparameterize(self) -> None: + """Reparameterize mixer and norm into a single + convolutional layer for efficient inference. + """ + if self.inference_mode: + return + + self.mixer.reparameterize() + self.norm.reparameterize() + + if self.use_layer_scale: + w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * ( + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight + ) + b = torch.squeeze(self.layer_scale) * ( + self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias + ) + else: + w = ( + self.mixer.id_tensor + + self.mixer.reparam_conv.weight + - self.norm.reparam_conv.weight + ) + b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias + + self.reparam_conv = nn.Conv2d( + in_chs=self.dim, + out_chs=self.dim, + kernel_size=self.kernel_size, + stride=1, + padding=self.kernel_size // 2, + groups=self.dim, + bias=True, + ) + self.reparam_conv.weight.data = w + self.reparam_conv.bias.data = b + + for para in self.parameters(): + para.detach_() + self.__delattr__("mixer") + self.__delattr__("norm") + if self.use_layer_scale: + self.__delattr__("layer_scale") + + +class ConvMlp(nn.Module): + """Convolutional FFN Module.""" + + def __init__( + self, + in_chs: int, + hidden_channels: Optional[int] = None, + out_chs: Optional[int] = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + ) -> None: + """Build convolutional FFN module. + + Args: + in_chs: Number of input channels. + hidden_channels: Number of channels after expansion. Default: None + out_chs: Number of output channels. Default: None + act_layer: Activation layer. Default: ``GELU`` + drop: Dropout rate. Default: ``0.0``. 
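+            The block applies a 7x7 depthwise conv + BatchNorm for local mixing, then a 1x1 conv expanding to hidden_channels, the activation, and a 1x1 conv projecting back to out_chs, with dropout after each pointwise conv.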
+ """ + super().__init__() + out_chs = out_chs or in_chs + hidden_channels = hidden_channels or in_chs + self.conv = nn.Sequential() + self.conv.add_module( + "conv", + nn.Conv2d( + in_channels=in_chs, + out_channels=out_chs, + kernel_size=7, + padding=3, + groups=in_chs, + bias=False, + ), + ) + self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs)) + self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1) + self.drop = nn.Dropout(drop) + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class RepCPE(nn.Module): + """Implementation of conditional positional encoding. + + For more details refer to paper: + `Conditional Positional Encodings for Vision Transformers `_ + + In our implementation, we can reparameterize this module to eliminate a skip connection. + """ + + def __init__( + self, + in_chs: int, + embed_dim: int = 768, + spatial_shape: Union[int, Tuple[int, int]] = (7, 7), + inference_mode=False, + ) -> None: + """Build reparameterizable conditional positional encoding. + + Args: + in_chs: Number of input channels. + embed_dim: Number of embedding dimensions. Default: 768 + spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) + inference_mode: Flag to instantiate block in inference mode. Default: ``False`` + """ + super(RepCPE, self).__init__() + if isinstance(spatial_shape, int): + spatial_shape = tuple([spatial_shape] * 2) + assert isinstance(spatial_shape, Tuple), ( + f'"spatial_shape" must be a sequence or int, ' + f"got {type(spatial_shape)} instead." + ) + assert len(spatial_shape) == 2, ( + f'Length of "spatial_shape" should be 2, ' + f"got {len(spatial_shape)} instead."
+ ) + + self.spatial_shape = spatial_shape + self.embed_dim = embed_dim + self.in_chs = in_chs + self.groups = embed_dim + + if inference_mode: + self.reparam_conv = nn.Conv2d( + in_chs=self.in_chs, + out_chs=self.embed_dim, + kernel_size=self.spatial_shape, + stride=1, + padding=int(self.spatial_shape[0] // 2), + groups=self.embed_dim, + bias=True, + ) + else: + self.pe = nn.Conv2d( + in_chs, + embed_dim, + spatial_shape, + 1, + int(spatial_shape[0] // 2), + bias=True, + groups=embed_dim, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "reparam_conv"): + x = self.reparam_conv(x) + return x + else: + x = self.pe(x) + x + return x + + def reparameterize(self) -> None: + # Build equivalent Id tensor + input_dim = self.in_chs // self.groups + kernel_value = torch.zeros( + ( + self.in_chs, + input_dim, + self.spatial_shape[0], + self.spatial_shape[1], + ), + dtype=self.pe.weight.dtype, + device=self.pe.weight.device, + ) + for i in range(self.in_chs): + kernel_value[ + i, + i % input_dim, + self.spatial_shape[0] // 2, + self.spatial_shape[1] // 2, + ] = 1 + id_tensor = kernel_value + + # Reparameterize Id tensor and conv + w_final = id_tensor + self.pe.weight + b_final = self.pe.bias + + # Introduce reparam conv + self.reparam_conv = nn.Conv2d( + in_chs=self.in_chs, + out_chs=self.embed_dim, + kernel_size=self.spatial_shape, + stride=1, + padding=int(self.spatial_shape[0] // 2), + groups=self.embed_dim, + bias=True, + ) + self.reparam_conv.weight.data = w_final + self.reparam_conv.bias.data = b_final + + for para in self.parameters(): + para.detach_() + self.__delattr__("pe") + + +class RepMixerBlock(nn.Module): + """Implementation of Metaformer block with RepMixer as token mixer. + + For more details on Metaformer structure, please refer to: + `MetaFormer Is Actually What You Need for Vision `_ + """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + drop_path: float = 0.0, + use_layer_scale: bool = True, + layer_scale_init_value: float = 1e-5, + inference_mode: bool = False, + ): + """Build RepMixer Block. + + Args: + dim: Number of embedding dimensions. + kernel_size: Kernel size for repmixer. Default: 3 + mlp_ratio: MLP expansion ratio. Default: 4.0 + act_layer: Activation layer. Default: ``nn.GELU`` + drop: Dropout rate. Default: 0.0 + drop_path: Drop path rate. Default: 0.0 + use_layer_scale: Flag to turn on layer scale. Default: ``True`` + layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 + inference_mode: Flag to instantiate block in inference mode. 
Default: ``False`` + """ + + super().__init__() + + self.token_mixer = RepMixer( + dim, + kernel_size=kernel_size, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + ) + + assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( + mlp_ratio + ) + mlp_hidden_dim = int(dim * mlp_ratio) + self.convffn = ConvMlp( + in_chs=dim, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # Drop Path + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # Layer Scale + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale = nn.Parameter( + layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True + ) + + def forward(self, x): + if self.use_layer_scale: + x = self.token_mixer(x) + x = x + self.drop_path(self.layer_scale * self.convffn(x)) + else: + x = self.token_mixer(x) + x = x + self.drop_path(self.convffn(x)) + return x + + +class AttentionBlock(nn.Module): + """Implementation of metaformer block with MHSA as token mixer. + + For more details on Metaformer structure, please refer to: + `MetaFormer Is Actually What You Need for Vision `_ + """ + + def __init__( + self, + dim: int, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.BatchNorm2d, + drop: float = 0.0, + drop_path: float = 0.0, + use_layer_scale: bool = True, + layer_scale_init_value: float = 1e-5, + ): + """Build Attention Block. + + Args: + dim: Number of embedding dimensions. + mlp_ratio: MLP expansion ratio. Default: 4.0 + act_layer: Activation layer. Default: ``nn.GELU`` + norm_layer: Normalization layer. Default: ``nn.BatchNorm2d`` + drop: Dropout rate. Default: 0.0 + drop_path: Drop path rate. Default: 0.0 + use_layer_scale: Flag to turn on layer scale. Default: ``True`` + layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 + """ + + super().__init__() + + self.norm = norm_layer(dim) + self.token_mixer = Attention(dim=dim) + + assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( + mlp_ratio + ) + mlp_hidden_dim = int(dim * mlp_ratio) + self.convffn = ConvMlp( + in_chs=dim, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # Drop path + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # Layer Scale + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True + ) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True + ) + + def forward(self, x): + if self.use_layer_scale: + x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x))) + x = x + self.drop_path(self.layer_scale_2 * self.convffn(x)) + else: + x = x + self.drop_path(self.token_mixer(self.norm(x))) + x = x + self.drop_path(self.convffn(x)) + return x + + +def basic_blocks( + dim: int, + block_index: int, + num_blocks: List[int], + token_mixer_type: str, + kernel_size: int = 3, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.BatchNorm2d, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + use_layer_scale: bool = True, + layer_scale_init_value: float = 1e-5, + inference_mode=False, +) -> nn.Sequential: + """Build FastViT blocks within a stage. + + Args: + dim: Number of embedding dimensions. + block_index: block index. 
+ num_blocks: List containing number of blocks per stage. + token_mixer_type: Token mixer type. + kernel_size: Kernel size for repmixer. + mlp_ratio: MLP expansion ratio. + act_layer: Activation layer. + norm_layer: Normalization layer. + drop_rate: Dropout rate. + drop_path_rate: Drop path rate. + use_layer_scale: Flag to turn on layer scale regularization. + layer_scale_init_value: Layer scale value at initialization. + inference_mode: Flag to instantiate block in inference mode. + + Returns: + nn.Sequential object of all the blocks within the stage. + """ + blocks = [] + for block_idx in range(num_blocks[block_index]): + block_dpr = ( + drop_path_rate + * (block_idx + sum(num_blocks[:block_index])) + / (sum(num_blocks) - 1) + ) + if token_mixer_type == "repmixer": + blocks.append( + RepMixerBlock( + dim, + kernel_size=kernel_size, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + drop=drop_rate, + drop_path=block_dpr, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + ) + ) + elif token_mixer_type == "attention": + blocks.append( + AttentionBlock( + dim, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + drop=drop_rate, + drop_path=block_dpr, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + ) + ) + else: + raise ValueError( + "Token mixer type: {} not supported".format(token_mixer_type) + ) + blocks = nn.Sequential(*blocks) + + return blocks + + +class FastVit(nn.Module): + """ + This class implements `FastViT architecture `_ + """ + + def __init__( + self, + layers, + token_mixers: Tuple[str, ...], + embed_dims=None, + mlp_ratios=None, + downsamples=None, + repmixer_kernel_size=3, + norm_layer: nn.Module = nn.BatchNorm2d, + act_layer: nn.Module = nn.GELU, + num_classes=1000, + pos_embs=None, + down_patch_size=7, + down_stride=2, + drop_rate=0.0, + drop_path_rate=0.0, + use_layer_scale=True, + layer_scale_init_value=1e-5, + fork_feat=False, + init_cfg=None, + pretrained=None, + cls_ratio=2.0, + inference_mode=False, + **kwargs, + ) -> None: + + super().__init__() + + if not fork_feat: + self.num_classes = num_classes + self.fork_feat = fork_feat + + if pos_embs is None: + pos_embs = [None] * len(layers) + + # Convolutional stem + self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode) + + # Build the main stages of the network architecture + network = [] + for i in range(len(layers)): + # Add position embeddings if requested + if pos_embs[i] is not None: + network.append(pos_embs[i]( + embed_dims[i], + embed_dims[i], + inference_mode=inference_mode, + )) + stage = basic_blocks( + embed_dims[i], + i, + layers, + token_mixer_type=token_mixers[i], + kernel_size=repmixer_kernel_size, + mlp_ratio=mlp_ratios[i], + act_layer=act_layer, + norm_layer=norm_layer, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + ) + network.append(stage) + if i >= len(layers) - 1: + break + + # Patch merging/downsampling between stages. 
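+            # A PatchEmbed (stride-2 ReparamLargeKernelConv followed by a 1x1 MobileOneBlock) halves spatial resolution when the stage requests downsampling or the embedding width changes.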
+ if downsamples[i] or embed_dims[i] != embed_dims[i + 1]: + network.append( + PatchEmbed( + patch_size=down_patch_size, + stride=down_stride, + in_chs=embed_dims[i], + embed_dim=embed_dims[i + 1], + inference_mode=inference_mode, + ) + ) + + self.network = nn.ModuleList(network) + + # For segmentation and detection, extract intermediate output + if self.fork_feat: + # add a norm layer for each output + self.out_indices = [0, 2, 4, 6] + for i_emb, i_layer in enumerate(self.out_indices): + if i_emb == 0 and os.environ.get("FORK_LAST3", None): + """For RetinaNet, `start_level=1`. The first norm layer will not used. + cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...` + """ + layer = nn.Identity() + else: + layer = norm_layer(embed_dims[i_emb]) + layer_name = f"norm{i_layer}" + self.add_module(layer_name, layer) + else: + # Classifier head + self.gap = nn.AdaptiveAvgPool2d(output_size=1) + self.conv_exp = MobileOneBlock( + in_chs=embed_dims[-1], + out_chs=int(embed_dims[-1] * cls_ratio), + kernel_size=3, + stride=1, + group_size=1, + inference_mode=inference_mode, + use_se=True, + num_conv_branches=1, + ) + self.head = ( + nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes) + if num_classes > 0 + else nn.Identity() + ) + + self.apply(self.cls_init_weights) + self.init_cfg = copy.deepcopy(init_cfg) + + # load pre-trained model + if self.fork_feat and (self.init_cfg is not None or pretrained is not None): + self.init_weights() + + def cls_init_weights(self, m: nn.Module) -> None: + """Init. for classification""" + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @staticmethod + def _scrub_checkpoint(checkpoint, model): + sterile_dict = {} + for k1, v1 in checkpoint.items(): + if k1 not in model.state_dict(): + continue + if v1.shape == model.state_dict()[k1].shape: + sterile_dict[k1] = v1 + return sterile_dict + + def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + return x + + def forward_tokens(self, x: torch.Tensor) -> torch.Tensor: + outs = [] + for idx, block in enumerate(self.network): + x = block(x) + if self.fork_feat and idx in self.out_indices: + norm_layer = getattr(self, f"norm{idx}") + x_out = norm_layer(x) + outs.append(x_out) + if self.fork_feat: + # output the features of four stages for dense prediction + return outs + # output only the features of last layer for image classification + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # input embedding + x = self.forward_embeddings(x) + # through backbone + x = self.forward_tokens(x) + if self.fork_feat: + # output features of four stages for dense prediction + return x + # for image classification + x = self.conv_exp(x) + x = self.gap(x) + x = x.view(x.size(0), -1) + cls_out = self.head(x) + return cls_out + + +def _cfg(url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 256, 256), + "pool_size": None, + "crop_pct": 0.95, + "interpolation": "bicubic", + "mean": IMAGENET_DEFAULT_MEAN, + "std": IMAGENET_DEFAULT_STD, + "classifier": "head", + **kwargs, + } + + +default_cfgs = { + "fastvit_t": _cfg(crop_pct=0.9), + "fastvit_s": _cfg(crop_pct=0.9), + "fastvit_m": _cfg(crop_pct=0.95), +} + + +@register_model +def fastvit_t8(pretrained=False, **kwargs): + """Instantiate FastViT-T8 model variant.""" + layers = [2, 2, 4, 2] + embed_dims = [48, 96, 192, 384] + mlp_ratios = [3, 3, 3, 3] + downsamples = [True, True, True, True] + 
token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer") + model = FastVit( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_t"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def fastvit_t12(pretrained=False, **kwargs): + """Instantiate FastViT-T12 model variant.""" + layers = [2, 2, 6, 2] + embed_dims = [64, 128, 256, 512] + mlp_ratios = [3, 3, 3, 3] + downsamples = [True, True, True, True] + token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer") + model = FastVit( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_t"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def fastvit_s12(pretrained=False, **kwargs): + """Instantiate FastViT-S12 model variant.""" + layers = [2, 2, 6, 2] + embed_dims = [64, 128, 256, 512] + mlp_ratios = [4, 4, 4, 4] + downsamples = [True, True, True, True] + token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer") + model = FastVit( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_s"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def fastvit_sa12(pretrained=False, **kwargs): + """Instantiate FastViT-SA12 model variant.""" + layers = [2, 2, 6, 2] + embed_dims = [64, 128, 256, 512] + mlp_ratios = [4, 4, 4, 4] + downsamples = [True, True, True, True] + pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] + token_mixers = ("repmixer", "repmixer", "repmixer", "attention") + model = FastVit( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + pos_embs=pos_embs, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_s"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def fastvit_sa24(pretrained=False, **kwargs): + """Instantiate FastViT-SA24 model variant.""" + layers = [4, 4, 12, 4] + embed_dims = [64, 128, 256, 512] + mlp_ratios = [4, 4, 4, 4] + downsamples = [True, True, True, True] + pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] + token_mixers = ("repmixer", "repmixer", "repmixer", "attention") + model = FastVit( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + pos_embs=pos_embs, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_s"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def fastvit_sa36(pretrained=False, **kwargs): + """Instantiate FastViT-SA36 model variant.""" + layers = [6, 6, 18, 6] + embed_dims = [64, 128, 256, 512] + mlp_ratios = [4, 4, 4, 4] + downsamples = [True, True, True, True] + pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] + token_mixers = ("repmixer", "repmixer", "repmixer", "attention") + model = FastVit( + layers, + embed_dims=embed_dims, + token_mixers=token_mixers, + pos_embs=pos_embs, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + layer_scale_init_value=1e-6, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_m"] + if 
pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def fastvit_ma36(pretrained=False, **kwargs): + """Instantiate FastViT-MA36 model variant.""" + layers = [6, 6, 18, 6] + embed_dims = [76, 152, 304, 608] + mlp_ratios = [4, 4, 4, 4] + downsamples = [True, True, True, True] + pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] + token_mixers = ("repmixer", "repmixer", "repmixer", "attention") + model = FastVit( + layers, + embed_dims=embed_dims, + token_mixers=token_mixers, + pos_embs=pos_embs, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + layer_scale_init_value=1e-6, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_m"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model \ No newline at end of file
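
Below is a minimal usage sketch (not part of the patch) for the variants registered above. It assumes the new module is importable as timm.models.fastvit; wiring it into timm/models/__init__.py so that timm.create_model("fastvit_t8") resolves is not shown in this diff. The inference-time folding loop follows the pattern the file itself relies on: every module that defines reparameterize() (ReparamLargeKernelConv, RepMixer, RepCPE, MobileOneBlock) merges its training-time branches into a single conv.

    import copy
    import torch
    from timm.models.fastvit import fastvit_t8

    model = fastvit_t8(num_classes=1000).eval()
    x = torch.randn(1, 3, 256, 256)  # default_cfg input_size is (3, 256, 256)
    with torch.no_grad():
        out_multi_branch = model(x)

    # Fold train-time branches into plain convs for deployment; deepcopy keeps
    # the original multi-branch model untouched.
    deploy_model = copy.deepcopy(model)
    for m in deploy_model.modules():
        if hasattr(m, "reparameterize"):
            m.reparameterize()

    with torch.no_grad():
        out_reparam = deploy_model(x)  # expected to closely match out_multi_branch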