#
# For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main
#
# Original work is copyright (C) 2023 Apple Inc. All Rights Reserved.
#
|
||
|
import copy
|
||
|
import os
|
||
|
from functools import partial
|
||
|
from typing import List, Tuple, Optional, Union
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||
|
from timm.layers import DropPath, trunc_normal_
|
||
|
from ._registry import register_model
|
||
|
from .byobnet import MobileOneBlock
|
||
|
|
||
|
|
||
|
class ReparamLargeKernelConv(nn.Module):
    """Building Block of RepLKNet

    This class defines overparameterized large kernel conv block
    introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_

    In its training form the block sums a large-kernel conv+BN branch
    (``lkb_origin``) and an optional small-kernel conv+BN branch
    (``small_conv``). :meth:`reparameterize` folds both branches into one
    ``nn.Conv2d`` stored as ``lkb_reparam``.

    Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
    """

    def __init__(
        self,
        in_chs: int,
        out_chs: int,
        kernel_size: int,
        stride: int,
        groups: int,
        small_kernel: Optional[int],
        inference_mode: bool = False,
        act_layer: nn.Module = nn.GELU(),
    ) -> None:
        """Construct a ReparamLargeKernelConv module.

        Args:
            in_chs: Number of input channels.
            out_chs: Number of output channels.
            kernel_size: Kernel size of the large kernel conv branch.
            stride: Stride size.
            groups: Group number.
            small_kernel: Kernel size of small kernel conv branch, or ``None``
                to build the large-kernel branch only.
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
            act_layer: Activation module instance. Default: ``nn.GELU()``
        """
        super().__init__()

        self.stride = stride
        self.groups = groups
        self.in_chs = in_chs
        self.out_chs = out_chs
        self.act_layer = act_layer

        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        self.padding = kernel_size // 2
        if inference_mode:
            # nn.Conv2d names its channel arguments in_channels/out_channels.
            self.lkb_reparam = nn.Conv2d(
                in_channels=in_chs,
                out_channels=out_chs,
                kernel_size=kernel_size,
                stride=stride,
                padding=self.padding,
                dilation=1,
                groups=groups,
                bias=True,
            )
        else:
            self.lkb_origin = self._conv_bn(
                kernel_size=kernel_size, padding=self.padding
            )
            if small_kernel is not None:
                assert (
                    small_kernel <= kernel_size
                ), "The kernel size for re-param cannot be larger than the large kernel!"
                self.small_conv = self._conv_bn(
                    kernel_size=small_kernel, padding=small_kernel // 2
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        if hasattr(self, "lkb_reparam"):
            out = self.lkb_reparam(x)
        else:
            out = self.lkb_origin(x)
            if hasattr(self, "small_conv"):
                out += self.small_conv(x)

        # NOTE(review): the activation's return value is discarded, so unless
        # ``act_layer`` works in-place the returned tensor is NOT activated.
        # Left unchanged on purpose: assigning the result would alter the
        # numerical output of every model built on this block. Confirm the
        # intended behavior before changing it.
        self.act_layer(out)
        return out

    def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepLKNet-pytorch

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, "small_conv"):
            small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
            eq_b += small_b
            # Zero-pad the small kernel on all four sides so it is centered
            # inside the large kernel before the two are summed.
            eq_k += nn.functional.pad(
                small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
            )
        return eq_k, eq_b

    def reparameterize(self) -> None:
        """
        Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.

        Calling this on a block that is already fused (built with
        ``inference_mode=True`` or reparameterized earlier) is a no-op.
        """
        if hasattr(self, "lkb_reparam"):
            # Nothing left to fuse; the training branches no longer exist.
            return

        eq_k, eq_b = self.get_kernel_bias()
        self.lkb_reparam = nn.Conv2d(
            in_channels=self.in_chs,
            out_channels=self.out_chs,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.lkb_origin.conv.dilation,
            groups=self.groups,
            bias=True,
        )

        self.lkb_reparam.weight.data = eq_k
        self.lkb_reparam.bias.data = eq_b
        self.__delattr__("lkb_origin")
        if hasattr(self, "small_conv"):
            self.__delattr__("small_conv")

    @staticmethod
    def _fuse_bn(
        conv: nn.Conv2d, bn: nn.BatchNorm2d
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with conv layer.

        Args:
            conv: Convolution layer whose ``weight`` is fused (it is built
                without bias by :meth:`_conv_bn`).
            bn: Batchnorm 2d layer.

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.
        """
        kernel = conv.weight
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        # Per-output-channel scale, reshaped to broadcast over the kernel.
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
        """Helper method to construct conv-batchnorm layers.

        Args:
            kernel_size: Size of the convolution kernel.
            padding: Zero-padding size.

        Returns:
            A nn.Sequential Conv-BN module with sub-modules ``conv`` and ``bn``.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_chs,
                out_channels=self.out_chs,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_chs))
        return mod_list
|
||
|
|
||
|
|
||
|
def convolutional_stem(
    in_chs: int, out_chs: int, inference_mode: bool = False
) -> nn.Sequential:
    """Assemble the three-block MobileOne stem.

    The stem is, in order: a stride-2 3x3 block mapping ``in_chs`` to
    ``out_chs``, a stride-2 3x3 block with ``group_size=1``, and a stride-1
    1x1 block. None of the blocks uses SE and each has one conv branch.

    Args:
        in_chs: Number of input channels.
        out_chs: Number of output channels.
        inference_mode: Flag to instantiate model in inference mode. Default: ``False``

    Returns:
        nn.Sequential object with stem elements.
    """
    # Settings common to all three stem blocks.
    shared = dict(inference_mode=inference_mode, use_se=False, num_conv_branches=1)

    stem_blocks = [
        MobileOneBlock(
            in_chs=in_chs, out_chs=out_chs, kernel_size=3, stride=2, **shared
        ),
        MobileOneBlock(
            in_chs=out_chs,
            out_chs=out_chs,
            kernel_size=3,
            stride=2,
            group_size=1,
            **shared,
        ),
        MobileOneBlock(
            in_chs=out_chs, out_chs=out_chs, kernel_size=1, stride=1, **shared
        ),
    ]
    return nn.Sequential(*stem_blocks)
|
||
|
|
||
|
|
||
|
class Attention(nn.Module):
    """Multi-headed Self Attention module.

    Source modified from:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """

    def __init__(
        self,
        dim: int,
        head_dim: int = 32,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """Build MHSA module that can handle 3D or 4D input tensors.

        Args:
            dim: Number of embedding dimensions.
            head_dim: Number of hidden dimensions per head. Default: ``32``
            qkv_bias: Use bias or not. Default: ``False``
            attn_drop: Dropout rate for attention tensor.
            proj_drop: Dropout rate for projection tensor.
        """
        super().__init__()
        assert dim % head_dim == 0, "dim should be divisible by head_dim"
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention.

        Args:
            x: Either a 4D tensor ``(B, C, H, W)`` or a 3D token tensor
                ``(B, N, C)``, where ``C`` equals ``dim``.

        Returns:
            Tensor with the same shape as the input.
        """
        shape = x.shape
        is_4d = len(shape) == 4
        if is_4d:
            x = torch.flatten(x, start_dim=2).transpose(-2, -1)  # (B, N, C)
        # Read the token layout only after the optional flatten, so 3D input
        # is supported as the constructor docstring promises.
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # trick here to make q@k.t more stable
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if is_4d:
            # Restore the original (B, C, H, W) layout.
            x = x.transpose(-2, -1).reshape(shape)

        return x
|
||
|
|
||
|
|
||
|
class PatchEmbed(nn.Module):
    """Convolutional patch embedding layer.

    The projection is a ``ReparamLargeKernelConv`` (grouped by ``in_chs``,
    with a 3x3 small-kernel branch) followed by a 1x1 ``MobileOneBlock``.
    """

    def __init__(
        self,
        patch_size: int,
        stride: int,
        in_chs: int,
        embed_dim: int,
        inference_mode: bool = False,
    ) -> None:
        """Build patch embedding layer.

        Args:
            patch_size: Patch size for embedding computation; used as the
                kernel size of the large-kernel conv.
            stride: Stride for convolutional embedding layer.
            in_chs: Number of channels of input tensor.
            embed_dim: Number of embedding dimensions.
            inference_mode: Flag to instantiate model in inference mode. Default: ``False``
        """
        super().__init__()

        large_kernel_conv = ReparamLargeKernelConv(
            in_chs=in_chs,
            out_chs=embed_dim,
            kernel_size=patch_size,
            stride=stride,
            groups=in_chs,
            small_kernel=3,
            inference_mode=inference_mode,
        )
        # NOTE(review): other call sites in this file pass ``group_size`` to
        # MobileOneBlock, while this one passes ``padding``/``groups``.
        # Confirm against MobileOneBlock's signature.
        pointwise = MobileOneBlock(
            in_chs=embed_dim,
            out_chs=embed_dim,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            inference_mode=inference_mode,
            use_se=False,
            num_conv_branches=1,
        )
        self.proj = nn.Sequential(large_kernel_conv, pointwise)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the input through the embedding stack."""
        return self.proj(x)
|
||
|
|
||
|
|
||
|
class RepMixer(nn.Module):
    """Reparameterizable token mixer.

    For more details, please refer to our paper:
    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
    """

    def __init__(
        self,
        dim,
        kernel_size=3,
        use_layer_scale=True,
        layer_scale_init_value=1e-5,
        inference_mode: bool = False,
    ):
        """Build RepMixer Module.

        Args:
            dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
            kernel_size: Kernel size for spatial mixing. Default: 3
            use_layer_scale: If True, learnable layer scale is used. Default: ``True``
            layer_scale_init_value: Initial value for layer scale. Default: 1e-5
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
        """
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.inference_mode = inference_mode
        self.use_layer_scale = use_layer_scale

        if inference_mode:
            # nn.Conv2d names its channel arguments in_channels/out_channels.
            self.reparam_conv = nn.Conv2d(
                in_channels=self.dim,
                out_channels=self.dim,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                groups=self.dim,
                bias=True,
            )
        else:
            self.norm = MobileOneBlock(
                dim,
                dim,
                kernel_size,
                group_size=1,
                use_act=False,
                use_scale_branch=False,
                num_conv_branches=0,
            )
            self.mixer = MobileOneBlock(
                dim,
                dim,
                kernel_size,
                group_size=1,
                use_act=False,
            )
            if use_layer_scale:
                self.layer_scale = nn.Parameter(
                    layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix tokens.

        Uses the fused convolution when present; otherwise computes
        ``x + scale * (mixer(x) - norm(x))`` (``scale`` omitted when layer
        scale is disabled).
        """
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            if self.use_layer_scale:
                x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
            else:
                x = x + self.mixer(x) - self.norm(x)
            return x

    def reparameterize(self) -> None:
        """Reparameterize mixer and norm into a single
        convolutional layer for efficient inference.

        A no-op when the module was built in inference mode or has already
        been reparameterized.
        """
        if self.inference_mode or hasattr(self, "reparam_conv"):
            return

        self.mixer.reparameterize()
        self.norm.reparameterize()

        if self.use_layer_scale:
            # layer_scale is (dim, 1, 1); unsqueeze so it broadcasts over the
            # 4D conv weight.
            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
                self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
            )
            b = torch.squeeze(self.layer_scale) * (
                self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
            )
        else:
            w = (
                self.mixer.id_tensor
                + self.mixer.reparam_conv.weight
                - self.norm.reparam_conv.weight
            )
            b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias

        self.reparam_conv = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim,
            kernel_size=self.kernel_size,
            stride=1,
            padding=self.kernel_size // 2,
            groups=self.dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w
        self.reparam_conv.bias.data = b

        for para in self.parameters():
            para.detach_()
        # Drop the training-time branches; only the fused conv remains.
        self.__delattr__("mixer")
        self.__delattr__("norm")
        if self.use_layer_scale:
            self.__delattr__("layer_scale")
|
||
|
|
||
|
|
||
|
class ConvMlp(nn.Module):
    """Convolutional FFN Module.

    A 7x7 grouped conv + BatchNorm, followed by two 1x1 convolutions with an
    activation between them and dropout after each.
    """

    def __init__(
        self,
        in_chs: int,
        hidden_channels: Optional[int] = None,
        out_chs: Optional[int] = None,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
    ) -> None:
        """Build convolutional FFN module.

        Args:
            in_chs: Number of input channels.
            hidden_channels: Number of channels after expansion; falls back to
                ``in_chs`` when falsy. Default: None
            out_chs: Number of output channels; falls back to ``in_chs`` when
                falsy. Default: None
            act_layer: Activation layer class (instantiated here). Default: ``GELU``
            drop: Dropout rate. Default: ``0.0``.
        """
        super().__init__()
        out_chs = out_chs or in_chs
        hidden_channels = hidden_channels or in_chs
        # NOTE(review): the 7x7 conv emits ``out_chs`` channels while ``fc1``
        # expects ``in_chs``, so forward only works when the two are equal.
        # Every caller visible in this file leaves ``out_chs`` unset; confirm
        # the intended layout before supporting ``out_chs != in_chs``.
        self.conv = nn.Sequential()
        self.conv.add_module(
            "conv",
            # nn.Conv2d names its channel arguments in_channels/out_channels.
            nn.Conv2d(
                in_channels=in_chs,
                out_channels=out_chs,
                kernel_size=7,
                padding=3,
                groups=in_chs,
                bias=False,
            ),
        )
        self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs))
        self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Truncated-normal init for conv weights and zero init for conv biases."""
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv-BN, then fc1 -> act -> drop -> fc2 -> drop."""
        x = self.conv(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
|
||
|
|
||
|
|
||
|
class RepCPE(nn.Module):
    """Implementation of conditional positional encoding.

    For more details refer to paper:
    `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_

    In our implementation, we can reparameterize this module to eliminate a skip connection.
    """

    def __init__(
        self,
        in_chs: int,
        embed_dim: int = 768,
        spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
        inference_mode=False,
    ) -> None:
        """Build reparameterizable conditional positional encoding

        Args:
            in_chs: Number of input channels.
            embed_dim: Number of embedding dimensions. Default: 768
            spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
        """
        super().__init__()
        if isinstance(spatial_shape, int):
            spatial_shape = tuple([spatial_shape] * 2)
        assert isinstance(spatial_shape, tuple), (
            f'"spatial_shape" must by a sequence or int, '
            f"get {type(spatial_shape)} instead."
        )
        assert len(spatial_shape) == 2, (
            f'Length of "spatial_shape" should be 2, '
            f"got {len(spatial_shape)} instead."
        )

        self.spatial_shape = spatial_shape
        self.embed_dim = embed_dim
        self.in_chs = in_chs
        self.groups = embed_dim

        if inference_mode:
            # nn.Conv2d names its channel arguments in_channels/out_channels.
            self.reparam_conv = nn.Conv2d(
                in_channels=self.in_chs,
                out_channels=self.embed_dim,
                kernel_size=self.spatial_shape,
                stride=1,
                padding=int(self.spatial_shape[0] // 2),
                groups=self.embed_dim,
                bias=True,
            )
        else:
            self.pe = nn.Conv2d(
                in_chs,
                embed_dim,
                spatial_shape,
                1,
                int(spatial_shape[0] // 2),
                bias=True,
                groups=embed_dim,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``pe(x) + x``, or the fused convolution's output once reparameterized."""
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            x = self.pe(x) + x
            return x

    def reparameterize(self) -> None:
        """Fold the skip connection into the positional-encoding conv.

        Builds an identity kernel, adds it to ``pe.weight`` and stores the
        result in a new ``reparam_conv``; ``pe`` is then removed. A no-op when
        the module is already in its fused form.
        """
        if hasattr(self, "reparam_conv"):
            return

        # Build equivalent Id tensor
        input_dim = self.in_chs // self.groups
        kernel_value = torch.zeros(
            (
                self.in_chs,
                input_dim,
                self.spatial_shape[0],
                self.spatial_shape[1],
            ),
            dtype=self.pe.weight.dtype,
            device=self.pe.weight.device,
        )
        # A single 1 at the kernel center makes each filter pass its input
        # channel through unchanged.
        for i in range(self.in_chs):
            kernel_value[
                i,
                i % input_dim,
                self.spatial_shape[0] // 2,
                self.spatial_shape[1] // 2,
            ] = 1
        id_tensor = kernel_value

        # Reparameterize Id tensor and conv
        w_final = id_tensor + self.pe.weight
        b_final = self.pe.bias

        # Introduce reparam conv
        self.reparam_conv = nn.Conv2d(
            in_channels=self.in_chs,
            out_channels=self.embed_dim,
            kernel_size=self.spatial_shape,
            stride=1,
            padding=int(self.spatial_shape[0] // 2),
            groups=self.embed_dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w_final
        self.reparam_conv.bias.data = b_final

        for para in self.parameters():
            para.detach_()
        self.__delattr__("pe")
|
||
|
|
||
|
|
||
|
class RepMixerBlock(nn.Module):
    """Metaformer-style block whose token mixer is a RepMixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
        inference_mode: bool = False,
    ):
        """Build RepMixer Block.

        Args:
            dim: Number of embedding dimensions.
            kernel_size: Kernel size for repmixer. Default: 3
            mlp_ratio: MLP expansion ratio; must be positive. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate; an identity is used when not positive. Default: 0.0
            use_layer_scale: Flag to turn on layer scale. Default: ``True``
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
        """
        super().__init__()

        # Token mixing.
        self.token_mixer = RepMixer(
            dim,
            kernel_size=kernel_size,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
            inference_mode=inference_mode,
        )

        # Channel mixing.
        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        self.convffn = ConvMlp(
            in_chs=dim,
            hidden_channels=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        # Stochastic depth on the FFN residual branch.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # Optional learnable per-channel scale on the FFN branch.
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            scale_init = layer_scale_init_value * torch.ones((dim, 1, 1))
            self.layer_scale = nn.Parameter(scale_init, requires_grad=True)

    def forward(self, x):
        """Apply the token mixer, then add the (optionally scaled) FFN branch."""
        x = self.token_mixer(x)
        ffn_branch = self.convffn(x)
        if self.use_layer_scale:
            ffn_branch = self.layer_scale * ffn_branch
        return x + self.drop_path(ffn_branch)
|
||
|
|
||
|
|
||
|
class AttentionBlock(nn.Module):
    """Metaformer-style block whose token mixer is multi-head self-attention.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
        self,
        dim: int,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.BatchNorm2d,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
    ):
        """Build Attention Block.

        Args:
            dim: Number of embedding dimensions.
            mlp_ratio: MLP expansion ratio; must be positive. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate; an identity is used when not positive. Default: 0.0
            use_layer_scale: Flag to turn on layer scale. Default: ``True``
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
        """
        super().__init__()

        # Pre-norm followed by attention-based token mixing.
        self.norm = norm_layer(dim)
        self.token_mixer = Attention(dim=dim)

        # Channel mixing.
        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        self.convffn = ConvMlp(
            in_chs=dim,
            hidden_channels=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        # Stochastic depth shared by both residual branches.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # Optional learnable per-channel scales, one per residual branch.
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )

    def forward(self, x):
        """Add the attention branch, then the FFN branch, each optionally scaled."""
        attn_branch = self.token_mixer(self.norm(x))
        if self.use_layer_scale:
            attn_branch = self.layer_scale_1 * attn_branch
        x = x + self.drop_path(attn_branch)

        ffn_branch = self.convffn(x)
        if self.use_layer_scale:
            ffn_branch = self.layer_scale_2 * ffn_branch
        x = x + self.drop_path(ffn_branch)
        return x
|
||
|
|
||
|
|
||
|
def basic_blocks(
    dim: int,
    block_index: int,
    num_blocks: List[int],
    token_mixer_type: str,
    kernel_size: int = 3,
    mlp_ratio: float = 4.0,
    act_layer: nn.Module = nn.GELU,
    norm_layer: nn.Module = nn.BatchNorm2d,
    drop_rate: float = 0.0,
    drop_path_rate: float = 0.0,
    use_layer_scale: bool = True,
    layer_scale_init_value: float = 1e-5,
    inference_mode=False,
) -> nn.Sequential:
    """Build FastViT blocks within a stage.

    The drop-path rate of each block grows linearly with the block's position
    in the whole network, from 0 at the first block to ``drop_path_rate`` at
    the last one.

    Args:
        dim: Number of embedding dimensions.
        block_index: block index.
        num_blocks: List containing number of blocks per stage.
        token_mixer_type: Token mixer type, ``"repmixer"`` or ``"attention"``.
        kernel_size: Kernel size for repmixer.
        mlp_ratio: MLP expansion ratio.
        act_layer: Activation layer.
        norm_layer: Normalization layer.
        drop_rate: Dropout rate.
        drop_path_rate: Drop path rate.
        use_layer_scale: Flag to turn on layer scale regularization.
        layer_scale_init_value: Layer scale value at initialization.
        inference_mode: Flag to instantiate block in inference mode.

    Returns:
        nn.Sequential object of all the blocks within the stage.

    Raises:
        ValueError: If ``token_mixer_type`` is not supported.
    """
    # Loop-invariant terms of the linear drop-path schedule.
    blocks_before_stage = sum(num_blocks[:block_index])
    # Guard the single-block network, where ``sum(num_blocks) - 1`` is zero
    # and the schedule would otherwise divide by zero.
    schedule_span = max(sum(num_blocks) - 1, 1)

    blocks = []
    for block_idx in range(num_blocks[block_index]):
        block_dpr = drop_path_rate * (block_idx + blocks_before_stage) / schedule_span
        if token_mixer_type == "repmixer":
            blocks.append(
                RepMixerBlock(
                    dim,
                    kernel_size=kernel_size,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale,
                    layer_scale_init_value=layer_scale_init_value,
                    inference_mode=inference_mode,
                )
            )
        elif token_mixer_type == "attention":
            blocks.append(
                AttentionBlock(
                    dim,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale,
                    layer_scale_init_value=layer_scale_init_value,
                )
            )
        else:
            raise ValueError(
                "Token mixer type: {} not supported".format(token_mixer_type)
            )
    blocks = nn.Sequential(*blocks)

    return blocks
|
||
|
|
||
|
|
||
|
class FastVit(nn.Module):
    """
    This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
    """

    def __init__(
        self,
        layers,
        token_mixers: Tuple[str, ...],
        embed_dims=None,
        mlp_ratios=None,
        downsamples=None,
        repmixer_kernel_size=3,
        norm_layer: nn.Module = nn.BatchNorm2d,
        act_layer: nn.Module = nn.GELU,
        num_classes=1000,
        pos_embs=None,
        down_patch_size=7,
        down_stride=2,
        drop_rate=0.0,
        drop_path_rate=0.0,
        use_layer_scale=True,
        layer_scale_init_value=1e-5,
        fork_feat=False,
        init_cfg=None,
        pretrained=None,
        cls_ratio=2.0,
        inference_mode=False,
        **kwargs,
    ) -> None:
        """Build the FastViT network.

        Args:
            layers: Number of blocks in each stage.
            token_mixers: Token mixer type of each stage (``"repmixer"`` or
                ``"attention"``).
            embed_dims: Embedding dimension of each stage. Indexed per stage,
                so it must be provided despite the ``None`` default.
            mlp_ratios: MLP expansion ratio of each stage. Must be provided.
            downsamples: Per-stage flags; a ``PatchEmbed`` is inserted after
                stage ``i`` when ``downsamples[i]`` is truthy or the embedding
                dimension changes. Must be provided.
            repmixer_kernel_size: Kernel size for repmixer blocks.
            norm_layer: Normalization layer.
            act_layer: Activation layer.
            num_classes: Number of classes; a non-positive value replaces the
                classifier with an identity.
            pos_embs: Optional per-stage callables building a positional
                encoding module, called as
                ``pos_emb(dim, dim, inference_mode=...)``.
            down_patch_size: Patch (kernel) size of the downsampling layers.
            down_stride: Stride of the downsampling layers.
            drop_rate: Dropout rate.
            drop_path_rate: Maximum drop path rate.
            use_layer_scale: Flag to turn on layer scale.
            layer_scale_init_value: Layer scale value at initialization.
            fork_feat: If True, return intermediate feature maps and skip
                building the classifier head.
            init_cfg: Stored (deep-copied) on the instance; triggers weight
                loading in ``fork_feat`` mode.
            pretrained: Triggers weight loading in ``fork_feat`` mode when not
                ``None``.
            cls_ratio: Channel expansion ratio of the classifier's conv block.
            inference_mode: Flag to instantiate model in inference mode.
            **kwargs: Accepted and ignored.

        Raises:
            NotImplementedError: If weight loading is requested in
                ``fork_feat`` mode but no ``init_weights`` method exists.
        """
        super().__init__()

        if not fork_feat:
            self.num_classes = num_classes
        self.fork_feat = fork_feat

        if pos_embs is None:
            pos_embs = [None] * len(layers)

        # Convolutional stem
        self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode)

        # Build the main stages of the network architecture
        network = []
        for i in range(len(layers)):
            # Add position embeddings if requested
            if pos_embs[i] is not None:
                network.append(
                    pos_embs[i](
                        embed_dims[i],
                        embed_dims[i],
                        inference_mode=inference_mode,
                    )
                )
            stage = basic_blocks(
                embed_dims[i],
                i,
                layers,
                token_mixer_type=token_mixers[i],
                kernel_size=repmixer_kernel_size,
                mlp_ratio=mlp_ratios[i],
                act_layer=act_layer,
                norm_layer=norm_layer,
                drop_rate=drop_rate,
                drop_path_rate=drop_path_rate,
                use_layer_scale=use_layer_scale,
                layer_scale_init_value=layer_scale_init_value,
                inference_mode=inference_mode,
            )
            network.append(stage)
            if i >= len(layers) - 1:
                break

            # Patch merging/downsampling between stages.
            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
                network.append(
                    PatchEmbed(
                        patch_size=down_patch_size,
                        stride=down_stride,
                        in_chs=embed_dims[i],
                        embed_dim=embed_dims[i + 1],
                        inference_mode=inference_mode,
                    )
                )

        self.network = nn.ModuleList(network)

        # For segmentation and detection, extract intermediate output
        if self.fork_feat:
            # add a norm layer for each output
            # NOTE(review): these indices assume ``self.network`` alternates
            # stage / PatchEmbed with no positional-encoding modules in
            # between; confirm before combining ``fork_feat`` with ``pos_embs``.
            self.out_indices = [0, 2, 4, 6]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get("FORK_LAST3", None):
                    # For RetinaNet, `start_level=1`. The first norm layer will not used.
                    # cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...`
                    layer = nn.Identity()
                else:
                    layer = norm_layer(embed_dims[i_emb])
                layer_name = f"norm{i_layer}"
                self.add_module(layer_name, layer)
        else:
            # Classifier head
            self.gap = nn.AdaptiveAvgPool2d(output_size=1)
            self.conv_exp = MobileOneBlock(
                in_chs=embed_dims[-1],
                out_chs=int(embed_dims[-1] * cls_ratio),
                kernel_size=3,
                stride=1,
                group_size=1,
                inference_mode=inference_mode,
                use_se=True,
                num_conv_branches=1,
            )
            self.head = (
                nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes)
                if num_classes > 0
                else nn.Identity()
            )

        self.apply(self.cls_init_weights)
        self.init_cfg = copy.deepcopy(init_cfg)

        # load pre-trained model
        if self.fork_feat and (self.init_cfg is not None or pretrained is not None):
            # This class does not define ``init_weights``; calling it blindly
            # would fail with an opaque AttributeError. Use one supplied by a
            # subclass if present, otherwise fail with a clear message.
            init_weights = getattr(self, "init_weights", None)
            if init_weights is None:
                raise NotImplementedError(
                    "Loading weights for fork_feat models requires an "
                    "`init_weights` method, which FastVit does not define."
                )
            init_weights()

    def cls_init_weights(self, m: nn.Module) -> None:
        """Init. for classification"""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _scrub_checkpoint(checkpoint, model):
        """Keep only checkpoint entries that match ``model`` by key and shape.

        Args:
            checkpoint: Mapping from parameter name to tensor.
            model: Module whose ``state_dict()`` defines the accepted keys and
                shapes.

        Returns:
            New dict with the matching entries, in checkpoint order.
        """
        # state_dict() rebuilds the full mapping on every call; take it once.
        model_state = model.state_dict()
        sterile_dict = {}
        for k1, v1 in checkpoint.items():
            if k1 not in model_state:
                continue
            if v1.shape == model_state[k1].shape:
                sterile_dict[k1] = v1
        return sterile_dict

    def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolutional stem."""
        x = self.patch_embed(x)
        return x

    def forward_tokens(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Run all network modules.

        Returns:
            The list of normalized intermediate features when ``fork_feat`` is
            set, otherwise the output of the last module.
        """
        outs = []
        for idx, block in enumerate(self.network):
            x = block(x)
            if self.fork_feat and idx in self.out_indices:
                norm_layer = getattr(self, f"norm{idx}")
                x_out = norm_layer(x)
                outs.append(x_out)
        if self.fork_feat:
            # output the features of four stages for dense prediction
            return outs
        # output only the features of last layer for image classification
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full forward pass.

        Returns:
            Classifier output, or in ``fork_feat`` mode the list of
            intermediate features from :meth:`forward_tokens`.
        """
        # input embedding
        x = self.forward_embeddings(x)
        # through backbone
        x = self.forward_tokens(x)
        if self.fork_feat:
            # output features of four stages for dense prediction
            return x
        # for image classification
        x = self.conv_exp(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1)
        cls_out = self.head(x)
        return cls_out
|
||
|
|
||
|
|
||
|
def _cfg(url="", **kwargs):
    """Build a default model config dict.

    Starts from the defaults below and then applies ``kwargs`` on top, so
    any keyword argument overrides the default of the same name or adds a
    new entry.

    Args:
        url: Value stored under the ``"url"`` key.
        **kwargs: Overrides / extra entries merged into the result.

    Returns:
        The assembled config dict.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 256, 256),
        pool_size=None,
        crop_pct=0.95,
        interpolation="bicubic",
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        classifier="head",
    )
    cfg.update(kwargs)
    return cfg
|
||
|
|
||
|
|
||
|
# Default configs keyed by model family; the factory functions below attach
# one of these to each model as ``default_cfg``.
default_cfgs = dict(
    fastvit_t=_cfg(crop_pct=0.9),
    fastvit_s=_cfg(crop_pct=0.9),
    fastvit_m=_cfg(crop_pct=0.95),
)
|
||
|
|
||
|
|
||
|
@register_model
def fastvit_t8(pretrained=False, **kwargs):
    """Instantiate FastViT-T8 model variant.

    Args:
        pretrained: Must be falsy; loading pre-trained weights is not
            implemented by this factory.
        **kwargs: Extra keyword arguments forwarded to ``FastVit``.

    Returns:
        The constructed ``FastVit`` model, with ``default_cfg`` set to the
        ``"fastvit_t"`` entry of ``default_cfgs``.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Fail fast: reject the unsupported option before building a model
    # that would only be thrown away.
    if pretrained:
        raise ValueError("Functionality not implemented.")

    layers = [2, 2, 4, 2]
    embed_dims = [48, 96, 192, 384]
    mlp_ratios = [3, 3, 3, 3]
    downsamples = [True, True, True, True]
    token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer")
    model = FastVit(
        layers,
        token_mixers=token_mixers,
        embed_dims=embed_dims,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        **kwargs,
    )
    model.default_cfg = default_cfgs["fastvit_t"]
    return model
|
||
|
|
||
|
|
||
|
@register_model
def fastvit_t12(pretrained=False, **kwargs):
    """Instantiate FastViT-T12 model variant.

    Args:
        pretrained: Must be falsy; loading pre-trained weights is not
            implemented by this factory.
        **kwargs: Extra keyword arguments forwarded to ``FastVit``.

    Returns:
        The constructed ``FastVit`` model, with ``default_cfg`` set to the
        ``"fastvit_t"`` entry of ``default_cfgs``.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Fail fast: reject the unsupported option before building a model
    # that would only be thrown away.
    if pretrained:
        raise ValueError("Functionality not implemented.")

    layers = [2, 2, 6, 2]
    embed_dims = [64, 128, 256, 512]
    mlp_ratios = [3, 3, 3, 3]
    downsamples = [True, True, True, True]
    token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer")
    model = FastVit(
        layers,
        token_mixers=token_mixers,
        embed_dims=embed_dims,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        **kwargs,
    )
    model.default_cfg = default_cfgs["fastvit_t"]
    return model
|
||
|
|
||
|
|
||
|
@register_model
def fastvit_s12(pretrained=False, **kwargs):
    """Instantiate FastViT-S12 model variant.

    Args:
        pretrained: Must be falsy; loading pre-trained weights is not
            implemented by this factory.
        **kwargs: Extra keyword arguments forwarded to ``FastVit``.

    Returns:
        The constructed ``FastVit`` model, with ``default_cfg`` set to the
        ``"fastvit_s"`` entry of ``default_cfgs``.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Fail fast: reject the unsupported option before building a model
    # that would only be thrown away.
    if pretrained:
        raise ValueError("Functionality not implemented.")

    layers = [2, 2, 6, 2]
    embed_dims = [64, 128, 256, 512]
    mlp_ratios = [4, 4, 4, 4]
    downsamples = [True, True, True, True]
    token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer")
    model = FastVit(
        layers,
        token_mixers=token_mixers,
        embed_dims=embed_dims,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        **kwargs,
    )
    model.default_cfg = default_cfgs["fastvit_s"]
    return model
|
||
|
|
||
|
|
||
|
@register_model
def fastvit_sa12(pretrained=False, **kwargs):
    """Instantiate FastViT-SA12 model variant.

    The last stage uses the ``"attention"`` token mixer together with a
    ``RepCPE`` positional embedding; the first three stages use
    ``"repmixer"`` and no positional embedding.

    Args:
        pretrained: Must be falsy; loading pre-trained weights is not
            implemented by this factory.
        **kwargs: Extra keyword arguments forwarded to ``FastVit``.

    Returns:
        The constructed ``FastVit`` model, with ``default_cfg`` set to the
        ``"fastvit_s"`` entry of ``default_cfgs``.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Fail fast: reject the unsupported option before building a model
    # that would only be thrown away.
    if pretrained:
        raise ValueError("Functionality not implemented.")

    layers = [2, 2, 6, 2]
    embed_dims = [64, 128, 256, 512]
    mlp_ratios = [4, 4, 4, 4]
    downsamples = [True, True, True, True]
    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
    model = FastVit(
        layers,
        token_mixers=token_mixers,
        embed_dims=embed_dims,
        pos_embs=pos_embs,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        **kwargs,
    )
    model.default_cfg = default_cfgs["fastvit_s"]
    return model
|
||
|
|
||
|
|
||
|
@register_model
def fastvit_sa24(pretrained=False, **kwargs):
    """Instantiate FastViT-SA24 model variant.

    The last stage uses the ``"attention"`` token mixer together with a
    ``RepCPE`` positional embedding; the first three stages use
    ``"repmixer"`` and no positional embedding.

    Args:
        pretrained: Must be falsy; loading pre-trained weights is not
            implemented by this factory.
        **kwargs: Extra keyword arguments forwarded to ``FastVit``.

    Returns:
        The constructed ``FastVit`` model, with ``default_cfg`` set to the
        ``"fastvit_s"`` entry of ``default_cfgs``.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Fail fast: reject the unsupported option before building a model
    # that would only be thrown away.
    if pretrained:
        raise ValueError("Functionality not implemented.")

    layers = [4, 4, 12, 4]
    embed_dims = [64, 128, 256, 512]
    mlp_ratios = [4, 4, 4, 4]
    downsamples = [True, True, True, True]
    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
    model = FastVit(
        layers,
        token_mixers=token_mixers,
        embed_dims=embed_dims,
        pos_embs=pos_embs,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        **kwargs,
    )
    model.default_cfg = default_cfgs["fastvit_s"]
    return model
|
||
|
|
||
|
|
||
|
@register_model
def fastvit_sa36(pretrained=False, **kwargs):
    """Instantiate FastViT-SA36 model variant.

    The last stage uses the ``"attention"`` token mixer together with a
    ``RepCPE`` positional embedding; the first three stages use
    ``"repmixer"`` and no positional embedding. This variant also passes
    ``layer_scale_init_value=1e-6`` to ``FastVit``.

    Args:
        pretrained: Must be falsy; loading pre-trained weights is not
            implemented by this factory.
        **kwargs: Extra keyword arguments forwarded to ``FastVit``.

    Returns:
        The constructed ``FastVit`` model, with ``default_cfg`` set to the
        ``"fastvit_m"`` entry of ``default_cfgs``.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Fail fast: reject the unsupported option before building a model
    # that would only be thrown away.
    if pretrained:
        raise ValueError("Functionality not implemented.")

    layers = [6, 6, 18, 6]
    embed_dims = [64, 128, 256, 512]
    mlp_ratios = [4, 4, 4, 4]
    downsamples = [True, True, True, True]
    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
    model = FastVit(
        layers,
        embed_dims=embed_dims,
        token_mixers=token_mixers,
        pos_embs=pos_embs,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        layer_scale_init_value=1e-6,
        **kwargs,
    )
    model.default_cfg = default_cfgs["fastvit_m"]
    return model
|
||
|
|
||
|
|
||
|
@register_model
def fastvit_ma36(pretrained=False, **kwargs):
    """Instantiate FastViT-MA36 model variant.

    The last stage uses the ``"attention"`` token mixer together with a
    ``RepCPE`` positional embedding; the first three stages use
    ``"repmixer"`` and no positional embedding. This variant also passes
    ``layer_scale_init_value=1e-6`` to ``FastVit``.

    Args:
        pretrained: Must be falsy; loading pre-trained weights is not
            implemented by this factory.
        **kwargs: Extra keyword arguments forwarded to ``FastVit``.

    Returns:
        The constructed ``FastVit`` model, with ``default_cfg`` set to the
        ``"fastvit_m"`` entry of ``default_cfgs``.

    Raises:
        ValueError: If ``pretrained`` is truthy.
    """
    # Fail fast: reject the unsupported option before building a model
    # that would only be thrown away.
    if pretrained:
        raise ValueError("Functionality not implemented.")

    layers = [6, 6, 18, 6]
    embed_dims = [76, 152, 304, 608]
    mlp_ratios = [4, 4, 4, 4]
    downsamples = [True, True, True, True]
    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
    model = FastVit(
        layers,
        embed_dims=embed_dims,
        token_mixers=token_mixers,
        pos_embs=pos_embs,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        layer_scale_init_value=1e-6,
        **kwargs,
    )
    model.default_cfg = default_cfgs["fastvit_m"]
    return model
|