mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
More work on FastViT, use own impl of MobileOne, validation working with remapped weight, more refactor TODO
This commit is contained in:
parent
c7a20cec13
commit
8474508d07
@ -12,9 +12,255 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, trunc_normal_
|
||||
from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn
|
||||
from ._registry import register_model
|
||||
from .byobnet import MobileOneBlock
|
||||
|
||||
|
||||
def num_groups(group_size, channels):
|
||||
if not group_size: # 0 or None
|
||||
return 1 # normal conv with 1 group
|
||||
else:
|
||||
# NOTE group_size == 1 -> depthwise conv
|
||||
assert channels % group_size == 0
|
||||
return channels // group_size
|
||||
|
||||
|
||||
class MobileOneBlock(nn.Module):
|
||||
"""MobileOne building block.
|
||||
|
||||
This block has a multi-branched architecture at train-time
|
||||
and plain-CNN style architecture at inference time
|
||||
For more details, please refer to our paper:
|
||||
`An Improved One millisecond Mobile Backbone` -
|
||||
https://arxiv.org/pdf/2206.04040.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_chs: int,
|
||||
out_chs: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
group_size: int = 0,
|
||||
inference_mode: bool = False,
|
||||
use_se: bool = False,
|
||||
use_act: bool = True,
|
||||
use_scale_branch: bool = True,
|
||||
num_conv_branches: int = 1,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
) -> None:
|
||||
"""Construct a MobileOneBlock module.
|
||||
|
||||
Args:
|
||||
in_chs: Number of channels in the input.
|
||||
out_chs: Number of channels produced by the block.
|
||||
kernel_size: Size of the convolution kernel.
|
||||
stride: Stride size.
|
||||
dilation: Kernel dilation factor.
|
||||
groups: Group number.
|
||||
inference_mode: If True, instantiates model in inference mode.
|
||||
use_se: Whether to use SE-ReLU activations.
|
||||
use_act: Whether to use activation. Default: ``True``
|
||||
use_scale_branch: Whether to use scale branch. Default: ``True``
|
||||
num_conv_branches: Number of linear conv branches.
|
||||
"""
|
||||
super(MobileOneBlock, self).__init__()
|
||||
self.inference_mode = inference_mode
|
||||
self.groups = num_groups(group_size, in_chs)
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
self.kernel_size = kernel_size
|
||||
self.in_chs = in_chs
|
||||
self.out_chs = out_chs
|
||||
self.num_conv_branches = num_conv_branches
|
||||
|
||||
# Check if SE-ReLU is requested
|
||||
self.se = SqueezeExcite(out_chs) if use_se else nn.Identity()
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = create_conv2d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
groups=self.groups,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
# Re-parameterizable skip connection
|
||||
self.reparam_conv = None
|
||||
|
||||
self.rbr_skip = (
|
||||
nn.BatchNorm2d(num_features=in_chs)
|
||||
if out_chs == in_chs and stride == 1
|
||||
else None
|
||||
)
|
||||
|
||||
# Re-parameterizable conv branches
|
||||
if num_conv_branches > 0:
|
||||
rbr_conv = list()
|
||||
for _ in range(self.num_conv_branches):
|
||||
rbr_conv.append(ConvNormAct(
|
||||
self.in_chs,
|
||||
self.out_chs,
|
||||
kernel_size=kernel_size,
|
||||
stride=self.stride,
|
||||
groups=self.groups,
|
||||
apply_act=False,
|
||||
))
|
||||
self.rbr_conv = nn.ModuleList(rbr_conv)
|
||||
else:
|
||||
self.rbr_conv = None
|
||||
|
||||
# Re-parameterizable scale branch
|
||||
self.rbr_scale = None
|
||||
if kernel_size > 1 and use_scale_branch:
|
||||
self.rbr_scale = ConvNormAct(
|
||||
self.in_chs,
|
||||
self.out_chs,
|
||||
kernel_size=1,
|
||||
stride=self.stride,
|
||||
groups=self.groups,
|
||||
apply_act=False
|
||||
)
|
||||
|
||||
self.act = act_layer() if use_act else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply forward pass."""
|
||||
# Inference mode forward pass.
|
||||
if self.reparam_conv is not None:
|
||||
return self.act(self.se(self.reparam_conv(x)))
|
||||
|
||||
# Multi-branched train-time forward pass.
|
||||
# Skip branch output
|
||||
identity_out = 0
|
||||
if self.rbr_skip is not None:
|
||||
identity_out = self.rbr_skip(x)
|
||||
|
||||
# Scale branch output
|
||||
scale_out = 0
|
||||
if self.rbr_scale is not None:
|
||||
scale_out = self.rbr_scale(x)
|
||||
|
||||
# Other branches
|
||||
out = scale_out + identity_out
|
||||
if self.rbr_conv is not None:
|
||||
for ix in range(self.num_conv_branches):
|
||||
out += self.rbr_conv[ix](x)
|
||||
|
||||
return self.act(self.se(out))
|
||||
|
||||
def reparameterize(self):
|
||||
"""Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
|
||||
https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
|
||||
architecture used at training time to obtain a plain CNN-like structure
|
||||
for inference.
|
||||
"""
|
||||
if self.inference_mode:
|
||||
return
|
||||
kernel, bias = self._get_kernel_bias()
|
||||
self.reparam_conv = create_conv2d(
|
||||
in_channels=self.in_chs,
|
||||
out_channels=self.out_chs,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=self.stride,
|
||||
dilation=self.dilation,
|
||||
groups=self.groups,
|
||||
bias=True,
|
||||
)
|
||||
self.reparam_conv.weight.data = kernel
|
||||
self.reparam_conv.bias.data = bias
|
||||
|
||||
# Delete un-used branches
|
||||
for para in self.parameters():
|
||||
para.detach_()
|
||||
self.__delattr__("rbr_conv")
|
||||
self.__delattr__("rbr_scale")
|
||||
if hasattr(self, "rbr_skip"):
|
||||
self.__delattr__("rbr_skip")
|
||||
|
||||
self.inference_mode = True
|
||||
|
||||
def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Method to obtain re-parameterized kernel and bias.
|
||||
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
|
||||
|
||||
Returns:
|
||||
Tuple of (kernel, bias) after fusing branches.
|
||||
"""
|
||||
# get weights and bias of scale branch
|
||||
kernel_scale = 0
|
||||
bias_scale = 0
|
||||
if self.rbr_scale is not None:
|
||||
kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
|
||||
# Pad scale branch kernel to match conv branch kernel size.
|
||||
pad = self.kernel_size // 2
|
||||
kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
|
||||
|
||||
# get weights and bias of skip branch
|
||||
kernel_identity = 0
|
||||
bias_identity = 0
|
||||
if self.rbr_skip is not None:
|
||||
kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
|
||||
|
||||
# get weights and bias of conv branches
|
||||
kernel_conv = 0
|
||||
bias_conv = 0
|
||||
if self.rbr_conv is not None:
|
||||
for ix in range(self.num_conv_branches):
|
||||
_kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
|
||||
kernel_conv += _kernel
|
||||
bias_conv += _bias
|
||||
|
||||
kernel_final = kernel_conv + kernel_scale + kernel_identity
|
||||
bias_final = bias_conv + bias_scale + bias_identity
|
||||
return kernel_final, bias_final
|
||||
|
||||
def _fuse_bn_tensor(
|
||||
self, branch: Union[nn.Sequential, nn.BatchNorm2d]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Method to fuse batchnorm layer with preceeding conv layer.
|
||||
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
|
||||
|
||||
Args:
|
||||
branch: Sequence of ops to be fused.
|
||||
|
||||
Returns:
|
||||
Tuple of (kernel, bias) after fusing batchnorm.
|
||||
"""
|
||||
if isinstance(branch, nn.Sequential):
|
||||
kernel = branch.conv.weight
|
||||
running_mean = branch.bn.running_mean
|
||||
running_var = branch.bn.running_var
|
||||
gamma = branch.bn.weight
|
||||
beta = branch.bn.bias
|
||||
eps = branch.bn.eps
|
||||
else:
|
||||
assert isinstance(branch, nn.BatchNorm2d)
|
||||
if not hasattr(self, "id_tensor"):
|
||||
input_dim = self.in_chs // self.groups
|
||||
kernel_value = torch.zeros(
|
||||
(self.in_chs, input_dim, self.kernel_size, self.kernel_size),
|
||||
dtype=branch.weight.dtype,
|
||||
device=branch.weight.device,
|
||||
)
|
||||
for i in range(self.in_chs):
|
||||
kernel_value[
|
||||
i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
|
||||
] = 1
|
||||
self.id_tensor = kernel_value
|
||||
kernel = self.id_tensor
|
||||
running_mean = branch.running_mean
|
||||
running_var = branch.running_var
|
||||
gamma = branch.weight
|
||||
beta = branch.bias
|
||||
eps = branch.eps
|
||||
std = (running_var + eps).sqrt()
|
||||
t = (gamma / std).reshape(-1, 1, 1, 1)
|
||||
return kernel * t, beta - running_mean * gamma / std
|
||||
|
||||
|
||||
class ReparamLargeKernelConv(nn.Module):
|
||||
@ -32,10 +278,10 @@ class ReparamLargeKernelConv(nn.Module):
|
||||
out_chs: int,
|
||||
kernel_size: int,
|
||||
stride: int,
|
||||
groups: int,
|
||||
small_kernel: int,
|
||||
group_size: int,
|
||||
small_kernel: Optional[int] = None,
|
||||
inference_mode: bool = False,
|
||||
act_layer: nn.Module = nn.GELU(),
|
||||
act_layer: Optional[nn.Module] = None,
|
||||
) -> None:
|
||||
"""Construct a ReparamLargeKernelConv module.
|
||||
|
||||
@ -44,55 +290,63 @@ class ReparamLargeKernelConv(nn.Module):
|
||||
out_chs: Number of output channels.
|
||||
kernel_size: Kernel size of the large kernel conv branch.
|
||||
stride: Stride size. Default: 1
|
||||
groups: Group number. Default: 1
|
||||
group_size: Group size. Default: 1
|
||||
small_kernel: Kernel size of small kernel conv branch.
|
||||
inference_mode: If True, instantiates model in inference mode. Default: ``False``
|
||||
act_layer: Activation module. Default: ``nn.GELU``
|
||||
"""
|
||||
super(ReparamLargeKernelConv, self).__init__()
|
||||
|
||||
self.stride = stride
|
||||
self.groups = groups
|
||||
self.groups = num_groups(group_size, in_chs)
|
||||
self.in_chs = in_chs
|
||||
self.out_chs = out_chs
|
||||
self.act_layer = act_layer
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.small_kernel = small_kernel
|
||||
self.padding = kernel_size // 2
|
||||
if inference_mode:
|
||||
self.lkb_reparam = nn.Conv2d(
|
||||
in_chs=in_chs,
|
||||
out_chs=out_chs,
|
||||
self.lkb_reparam = create_conv2d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=self.padding,
|
||||
dilation=1,
|
||||
groups=groups,
|
||||
groups=self.groups,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.lkb_origin = self._conv_bn(
|
||||
kernel_size=kernel_size, padding=self.padding
|
||||
self.lkb_reparam = None
|
||||
self.lkb_origin = ConvNormAct(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_size,
|
||||
stride=self.stride,
|
||||
groups=self.groups,
|
||||
apply_act=False,
|
||||
)
|
||||
if small_kernel is not None:
|
||||
assert (
|
||||
small_kernel <= kernel_size
|
||||
), "The kernel size for re-param cannot be larger than the large kernel!"
|
||||
self.small_conv = self._conv_bn(
|
||||
kernel_size=small_kernel, padding=small_kernel // 2
|
||||
self.small_conv = ConvNormAct(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=small_kernel,
|
||||
stride=self.stride,
|
||||
groups=self.groups,
|
||||
apply_act=False,
|
||||
)
|
||||
# FIXME output of this act was not used in original impl, likely due to bug
|
||||
self.act = act_layer() if act_layer is not None else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply forward pass."""
|
||||
if hasattr(self, "lkb_reparam"):
|
||||
if self.lkb_reparam is not None:
|
||||
out = self.lkb_reparam(x)
|
||||
else:
|
||||
out = self.lkb_origin(x)
|
||||
if hasattr(self, "small_conv"):
|
||||
out += self.small_conv(x)
|
||||
|
||||
self.act_layer(out)
|
||||
if self.small_conv is not None:
|
||||
out = out + self.small_conv(x)
|
||||
out = self.act(out)
|
||||
return out
|
||||
|
||||
def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -119,12 +373,11 @@ class ReparamLargeKernelConv(nn.Module):
|
||||
for inference.
|
||||
"""
|
||||
eq_k, eq_b = self.get_kernel_bias()
|
||||
self.lkb_reparam = nn.Conv2d(
|
||||
in_chs=self.in_chs,
|
||||
out_chs=self.out_chs,
|
||||
self.lkb_reparam = create_conv2d(
|
||||
self.in_chs,
|
||||
self.out_chs,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.lkb_origin.conv.dilation,
|
||||
groups=self.groups,
|
||||
bias=True,
|
||||
@ -159,35 +412,11 @@ class ReparamLargeKernelConv(nn.Module):
|
||||
t = (gamma / std).reshape(-1, 1, 1, 1)
|
||||
return kernel * t, beta - running_mean * gamma / std
|
||||
|
||||
def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
|
||||
"""Helper method to construct conv-batchnorm layers.
|
||||
|
||||
Args:
|
||||
kernel_size: Size of the convolution kernel.
|
||||
padding: Zero-padding size.
|
||||
|
||||
Returns:
|
||||
A nn.Sequential Conv-BN module.
|
||||
"""
|
||||
mod_list = nn.Sequential()
|
||||
mod_list.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_chs=self.in_chs,
|
||||
out_chs=self.out_chs,
|
||||
kernel_size=kernel_size,
|
||||
stride=self.stride,
|
||||
padding=padding,
|
||||
groups=self.groups,
|
||||
bias=False,
|
||||
),
|
||||
)
|
||||
mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_chs))
|
||||
return mod_list
|
||||
|
||||
|
||||
def convolutional_stem(
|
||||
in_chs: int, out_chs: int, inference_mode: bool = False
|
||||
in_chs: int,
|
||||
out_chs: int,
|
||||
inference_mode: bool = False
|
||||
) -> nn.Sequential:
|
||||
"""Build convolutional stem with MobileOne blocks.
|
||||
|
||||
@ -206,8 +435,6 @@ def convolutional_stem(
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
),
|
||||
MobileOneBlock(
|
||||
in_chs=out_chs,
|
||||
@ -216,8 +443,6 @@ def convolutional_stem(
|
||||
stride=2,
|
||||
group_size=1,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
),
|
||||
MobileOneBlock(
|
||||
in_chs=out_chs,
|
||||
@ -225,8 +450,6 @@ def convolutional_stem(
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
),
|
||||
)
|
||||
|
||||
@ -237,6 +460,7 @@ class Attention(nn.Module):
|
||||
Source modified from:
|
||||
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
"""
|
||||
fused_attn: torch.jit.Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -259,7 +483,8 @@ class Attention(nn.Module):
|
||||
assert dim % head_dim == 0, "dim should be divisible by head_dim"
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = dim // head_dim
|
||||
self.scale = head_dim**-0.5
|
||||
self.scale = head_dim ** -0.5
|
||||
self.fused_attn = use_fused_attn()
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
@ -267,11 +492,9 @@ class Attention(nn.Module):
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
shape = x.shape
|
||||
B, C, H, W = shape
|
||||
B, C, H, W = x.shape
|
||||
N = H * W
|
||||
if len(shape) == 4:
|
||||
x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C)
|
||||
x = x.flatten(2).transpose(-2, -1) # (B, N, C)
|
||||
qkv = (
|
||||
self.qkv(x)
|
||||
.reshape(B, N, 3, self.num_heads, self.head_dim)
|
||||
@ -279,16 +502,22 @@ class Attention(nn.Module):
|
||||
)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
# trick here to make q@k.t more stable
|
||||
attn = (q * self.scale) @ k.transpose(-2, -1)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
if self.fused_attn:
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
if len(shape) == 4:
|
||||
x = x.transpose(-2, -1).reshape(B, C, H, W)
|
||||
x = x.transpose(-2, -1).reshape(B, C, H, W)
|
||||
|
||||
return x
|
||||
|
||||
@ -314,32 +543,25 @@ class PatchEmbed(nn.Module):
|
||||
inference_mode: Flag to instantiate model in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
block = list()
|
||||
block.append(
|
||||
self.proj = nn.Sequential(
|
||||
ReparamLargeKernelConv(
|
||||
in_chs=in_chs,
|
||||
out_chs=embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=stride,
|
||||
groups=in_chs,
|
||||
group_size=1,
|
||||
small_kernel=3,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
)
|
||||
block.append(
|
||||
act_layer=None, # activation was not used in original impl
|
||||
),
|
||||
MobileOneBlock(
|
||||
in_chs=embed_dim,
|
||||
out_chs=embed_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=1,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
)
|
||||
)
|
||||
self.proj = nn.Sequential(*block)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
@ -377,8 +599,8 @@ class RepMixer(nn.Module):
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_chs=self.dim,
|
||||
out_chs=self.dim,
|
||||
self.dim,
|
||||
self.dim,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=1,
|
||||
padding=self.kernel_size // 2,
|
||||
@ -386,6 +608,7 @@ class RepMixer(nn.Module):
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.reparam_conv = None
|
||||
self.norm = MobileOneBlock(
|
||||
dim,
|
||||
dim,
|
||||
@ -404,12 +627,10 @@ class RepMixer(nn.Module):
|
||||
)
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
|
||||
)
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self, "reparam_conv"):
|
||||
if self.reparam_conv is not None:
|
||||
x = self.reparam_conv(x)
|
||||
return x
|
||||
else:
|
||||
@ -444,12 +665,11 @@ class RepMixer(nn.Module):
|
||||
)
|
||||
b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
|
||||
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_chs=self.dim,
|
||||
out_chs=self.dim,
|
||||
self.reparam_conv = create_conv2d(
|
||||
self.dim,
|
||||
self.dim,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=1,
|
||||
padding=self.kernel_size // 2,
|
||||
groups=self.dim,
|
||||
bias=True,
|
||||
)
|
||||
@ -487,19 +707,26 @@ class ConvMlp(nn.Module):
|
||||
super().__init__()
|
||||
out_chs = out_chs or in_chs
|
||||
hidden_channels = hidden_channels or in_chs
|
||||
self.conv = nn.Sequential()
|
||||
self.conv.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_chs=in_chs,
|
||||
out_chs=out_chs,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
groups=in_chs,
|
||||
bias=False,
|
||||
),
|
||||
# self.conv = nn.Sequential()
|
||||
# self.conv.add_module(
|
||||
# "conv",
|
||||
# nn.Conv2d(
|
||||
# in_chs,
|
||||
# out_chs,
|
||||
# kernel_size=7,
|
||||
# padding=3,
|
||||
# groups=in_chs,
|
||||
# bias=False,
|
||||
# ),
|
||||
# )
|
||||
# self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs))
|
||||
self.conv = ConvNormAct(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=7,
|
||||
groups=in_chs,
|
||||
apply_act=False,
|
||||
)
|
||||
self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs))
|
||||
self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1)
|
||||
@ -565,27 +792,28 @@ class RepCPE(nn.Module):
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_chs=self.in_chs,
|
||||
out_chs=self.embed_dim,
|
||||
self.in_chs,
|
||||
self.embed_dim,
|
||||
kernel_size=self.spatial_shape,
|
||||
stride=1,
|
||||
padding=int(self.spatial_shape[0] // 2),
|
||||
padding=spatial_shape[0] // 2,
|
||||
groups=self.embed_dim,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.reparam_conv = None
|
||||
self.pe = nn.Conv2d(
|
||||
in_chs,
|
||||
embed_dim,
|
||||
spatial_shape,
|
||||
1,
|
||||
int(spatial_shape[0] // 2),
|
||||
bias=True,
|
||||
groups=embed_dim,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self, "reparam_conv"):
|
||||
if self.reparam_conv is not None:
|
||||
x = self.reparam_conv(x)
|
||||
return x
|
||||
else:
|
||||
@ -620,8 +848,8 @@ class RepCPE(nn.Module):
|
||||
|
||||
# Introduce reparam conv
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_chs=self.in_chs,
|
||||
out_chs=self.embed_dim,
|
||||
self.in_chs,
|
||||
self.embed_dim,
|
||||
kernel_size=self.spatial_shape,
|
||||
stride=1,
|
||||
padding=int(self.spatial_shape[0] // 2),
|
||||
@ -682,10 +910,9 @@ class RepMixerBlock(nn.Module):
|
||||
assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
|
||||
mlp_ratio
|
||||
)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.convffn = ConvMlp(
|
||||
in_chs=dim,
|
||||
hidden_channels=mlp_hidden_dim,
|
||||
hidden_channels=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
@ -696,9 +923,7 @@ class RepMixerBlock(nn.Module):
|
||||
# Layer Scale
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
|
||||
)
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_layer_scale:
|
||||
@ -763,12 +988,8 @@ class AttentionBlock(nn.Module):
|
||||
# Layer Scale
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale_1 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
|
||||
)
|
||||
self.layer_scale_2 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
|
||||
)
|
||||
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
|
||||
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_layer_scale:
|
||||
@ -823,32 +1044,28 @@ def basic_blocks(
|
||||
/ (sum(num_blocks) - 1)
|
||||
)
|
||||
if token_mixer_type == "repmixer":
|
||||
blocks.append(
|
||||
RepMixerBlock(
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
drop=drop_rate,
|
||||
drop_path=block_dpr,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
)
|
||||
blocks.append(RepMixerBlock(
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
drop=drop_rate,
|
||||
drop_path=block_dpr,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
))
|
||||
elif token_mixer_type == "attention":
|
||||
blocks.append(
|
||||
AttentionBlock(
|
||||
dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop=drop_rate,
|
||||
drop_path=block_dpr,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
)
|
||||
)
|
||||
blocks.append(AttentionBlock(
|
||||
dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop=drop_rate,
|
||||
drop_path=block_dpr,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
))
|
||||
else:
|
||||
raise ValueError(
|
||||
"Token mixer type: {} not supported".format(token_mixer_type)
|
||||
@ -932,15 +1149,13 @@ class FastVit(nn.Module):
|
||||
|
||||
# Patch merging/downsampling between stages.
|
||||
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
|
||||
network.append(
|
||||
PatchEmbed(
|
||||
patch_size=down_patch_size,
|
||||
stride=down_stride,
|
||||
in_chs=embed_dims[i],
|
||||
embed_dim=embed_dims[i + 1],
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
)
|
||||
network += [PatchEmbed(
|
||||
patch_size=down_patch_size,
|
||||
stride=down_stride,
|
||||
in_chs=embed_dims[i],
|
||||
embed_dim=embed_dims[i + 1],
|
||||
inference_mode=inference_mode,
|
||||
)]
|
||||
|
||||
self.network = nn.ModuleList(network)
|
||||
|
||||
@ -1054,6 +1269,8 @@ default_cfgs = {
|
||||
"fastvit_t": _cfg(crop_pct=0.9),
|
||||
"fastvit_s": _cfg(crop_pct=0.9),
|
||||
"fastvit_m": _cfg(crop_pct=0.95),
|
||||
'fastvit_t8': _cfg(
|
||||
url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar')
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user