mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
More work on FastViT, use own impl of MobileOne, validation working with remapped weight, more refactor TODO
This commit is contained in:
parent
c7a20cec13
commit
8474508d07
@ -12,9 +12,255 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import DropPath, trunc_normal_
|
from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn
|
||||||
from ._registry import register_model
|
from ._registry import register_model
|
||||||
from .byobnet import MobileOneBlock
|
|
||||||
|
|
||||||
|
def num_groups(group_size, channels):
|
||||||
|
if not group_size: # 0 or None
|
||||||
|
return 1 # normal conv with 1 group
|
||||||
|
else:
|
||||||
|
# NOTE group_size == 1 -> depthwise conv
|
||||||
|
assert channels % group_size == 0
|
||||||
|
return channels // group_size
|
||||||
|
|
||||||
|
|
||||||
|
class MobileOneBlock(nn.Module):
|
||||||
|
"""MobileOne building block.
|
||||||
|
|
||||||
|
This block has a multi-branched architecture at train-time
|
||||||
|
and plain-CNN style architecture at inference time
|
||||||
|
For more details, please refer to our paper:
|
||||||
|
`An Improved One millisecond Mobile Backbone` -
|
||||||
|
https://arxiv.org/pdf/2206.04040.pdf
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_chs: int,
|
||||||
|
out_chs: int,
|
||||||
|
kernel_size: int,
|
||||||
|
stride: int = 1,
|
||||||
|
dilation: int = 1,
|
||||||
|
group_size: int = 0,
|
||||||
|
inference_mode: bool = False,
|
||||||
|
use_se: bool = False,
|
||||||
|
use_act: bool = True,
|
||||||
|
use_scale_branch: bool = True,
|
||||||
|
num_conv_branches: int = 1,
|
||||||
|
act_layer: nn.Module = nn.GELU,
|
||||||
|
) -> None:
|
||||||
|
"""Construct a MobileOneBlock module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_chs: Number of channels in the input.
|
||||||
|
out_chs: Number of channels produced by the block.
|
||||||
|
kernel_size: Size of the convolution kernel.
|
||||||
|
stride: Stride size.
|
||||||
|
dilation: Kernel dilation factor.
|
||||||
|
groups: Group number.
|
||||||
|
inference_mode: If True, instantiates model in inference mode.
|
||||||
|
use_se: Whether to use SE-ReLU activations.
|
||||||
|
use_act: Whether to use activation. Default: ``True``
|
||||||
|
use_scale_branch: Whether to use scale branch. Default: ``True``
|
||||||
|
num_conv_branches: Number of linear conv branches.
|
||||||
|
"""
|
||||||
|
super(MobileOneBlock, self).__init__()
|
||||||
|
self.inference_mode = inference_mode
|
||||||
|
self.groups = num_groups(group_size, in_chs)
|
||||||
|
self.stride = stride
|
||||||
|
self.dilation = dilation
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.in_chs = in_chs
|
||||||
|
self.out_chs = out_chs
|
||||||
|
self.num_conv_branches = num_conv_branches
|
||||||
|
|
||||||
|
# Check if SE-ReLU is requested
|
||||||
|
self.se = SqueezeExcite(out_chs) if use_se else nn.Identity()
|
||||||
|
|
||||||
|
if inference_mode:
|
||||||
|
self.reparam_conv = create_conv2d(
|
||||||
|
in_chs,
|
||||||
|
out_chs,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=self.groups,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Re-parameterizable skip connection
|
||||||
|
self.reparam_conv = None
|
||||||
|
|
||||||
|
self.rbr_skip = (
|
||||||
|
nn.BatchNorm2d(num_features=in_chs)
|
||||||
|
if out_chs == in_chs and stride == 1
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Re-parameterizable conv branches
|
||||||
|
if num_conv_branches > 0:
|
||||||
|
rbr_conv = list()
|
||||||
|
for _ in range(self.num_conv_branches):
|
||||||
|
rbr_conv.append(ConvNormAct(
|
||||||
|
self.in_chs,
|
||||||
|
self.out_chs,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=self.stride,
|
||||||
|
groups=self.groups,
|
||||||
|
apply_act=False,
|
||||||
|
))
|
||||||
|
self.rbr_conv = nn.ModuleList(rbr_conv)
|
||||||
|
else:
|
||||||
|
self.rbr_conv = None
|
||||||
|
|
||||||
|
# Re-parameterizable scale branch
|
||||||
|
self.rbr_scale = None
|
||||||
|
if kernel_size > 1 and use_scale_branch:
|
||||||
|
self.rbr_scale = ConvNormAct(
|
||||||
|
self.in_chs,
|
||||||
|
self.out_chs,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=self.stride,
|
||||||
|
groups=self.groups,
|
||||||
|
apply_act=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.act = act_layer() if use_act else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Apply forward pass."""
|
||||||
|
# Inference mode forward pass.
|
||||||
|
if self.reparam_conv is not None:
|
||||||
|
return self.act(self.se(self.reparam_conv(x)))
|
||||||
|
|
||||||
|
# Multi-branched train-time forward pass.
|
||||||
|
# Skip branch output
|
||||||
|
identity_out = 0
|
||||||
|
if self.rbr_skip is not None:
|
||||||
|
identity_out = self.rbr_skip(x)
|
||||||
|
|
||||||
|
# Scale branch output
|
||||||
|
scale_out = 0
|
||||||
|
if self.rbr_scale is not None:
|
||||||
|
scale_out = self.rbr_scale(x)
|
||||||
|
|
||||||
|
# Other branches
|
||||||
|
out = scale_out + identity_out
|
||||||
|
if self.rbr_conv is not None:
|
||||||
|
for ix in range(self.num_conv_branches):
|
||||||
|
out += self.rbr_conv[ix](x)
|
||||||
|
|
||||||
|
return self.act(self.se(out))
|
||||||
|
|
||||||
|
def reparameterize(self):
|
||||||
|
"""Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
|
||||||
|
https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
|
||||||
|
architecture used at training time to obtain a plain CNN-like structure
|
||||||
|
for inference.
|
||||||
|
"""
|
||||||
|
if self.inference_mode:
|
||||||
|
return
|
||||||
|
kernel, bias = self._get_kernel_bias()
|
||||||
|
self.reparam_conv = create_conv2d(
|
||||||
|
in_channels=self.in_chs,
|
||||||
|
out_channels=self.out_chs,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
stride=self.stride,
|
||||||
|
dilation=self.dilation,
|
||||||
|
groups=self.groups,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
self.reparam_conv.weight.data = kernel
|
||||||
|
self.reparam_conv.bias.data = bias
|
||||||
|
|
||||||
|
# Delete un-used branches
|
||||||
|
for para in self.parameters():
|
||||||
|
para.detach_()
|
||||||
|
self.__delattr__("rbr_conv")
|
||||||
|
self.__delattr__("rbr_scale")
|
||||||
|
if hasattr(self, "rbr_skip"):
|
||||||
|
self.__delattr__("rbr_skip")
|
||||||
|
|
||||||
|
self.inference_mode = True
|
||||||
|
|
||||||
|
def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Method to obtain re-parameterized kernel and bias.
|
||||||
|
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (kernel, bias) after fusing branches.
|
||||||
|
"""
|
||||||
|
# get weights and bias of scale branch
|
||||||
|
kernel_scale = 0
|
||||||
|
bias_scale = 0
|
||||||
|
if self.rbr_scale is not None:
|
||||||
|
kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
|
||||||
|
# Pad scale branch kernel to match conv branch kernel size.
|
||||||
|
pad = self.kernel_size // 2
|
||||||
|
kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
|
||||||
|
|
||||||
|
# get weights and bias of skip branch
|
||||||
|
kernel_identity = 0
|
||||||
|
bias_identity = 0
|
||||||
|
if self.rbr_skip is not None:
|
||||||
|
kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
|
||||||
|
|
||||||
|
# get weights and bias of conv branches
|
||||||
|
kernel_conv = 0
|
||||||
|
bias_conv = 0
|
||||||
|
if self.rbr_conv is not None:
|
||||||
|
for ix in range(self.num_conv_branches):
|
||||||
|
_kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
|
||||||
|
kernel_conv += _kernel
|
||||||
|
bias_conv += _bias
|
||||||
|
|
||||||
|
kernel_final = kernel_conv + kernel_scale + kernel_identity
|
||||||
|
bias_final = bias_conv + bias_scale + bias_identity
|
||||||
|
return kernel_final, bias_final
|
||||||
|
|
||||||
|
def _fuse_bn_tensor(
|
||||||
|
self, branch: Union[nn.Sequential, nn.BatchNorm2d]
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Method to fuse batchnorm layer with preceeding conv layer.
|
||||||
|
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
|
||||||
|
|
||||||
|
Args:
|
||||||
|
branch: Sequence of ops to be fused.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (kernel, bias) after fusing batchnorm.
|
||||||
|
"""
|
||||||
|
if isinstance(branch, nn.Sequential):
|
||||||
|
kernel = branch.conv.weight
|
||||||
|
running_mean = branch.bn.running_mean
|
||||||
|
running_var = branch.bn.running_var
|
||||||
|
gamma = branch.bn.weight
|
||||||
|
beta = branch.bn.bias
|
||||||
|
eps = branch.bn.eps
|
||||||
|
else:
|
||||||
|
assert isinstance(branch, nn.BatchNorm2d)
|
||||||
|
if not hasattr(self, "id_tensor"):
|
||||||
|
input_dim = self.in_chs // self.groups
|
||||||
|
kernel_value = torch.zeros(
|
||||||
|
(self.in_chs, input_dim, self.kernel_size, self.kernel_size),
|
||||||
|
dtype=branch.weight.dtype,
|
||||||
|
device=branch.weight.device,
|
||||||
|
)
|
||||||
|
for i in range(self.in_chs):
|
||||||
|
kernel_value[
|
||||||
|
i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
|
||||||
|
] = 1
|
||||||
|
self.id_tensor = kernel_value
|
||||||
|
kernel = self.id_tensor
|
||||||
|
running_mean = branch.running_mean
|
||||||
|
running_var = branch.running_var
|
||||||
|
gamma = branch.weight
|
||||||
|
beta = branch.bias
|
||||||
|
eps = branch.eps
|
||||||
|
std = (running_var + eps).sqrt()
|
||||||
|
t = (gamma / std).reshape(-1, 1, 1, 1)
|
||||||
|
return kernel * t, beta - running_mean * gamma / std
|
||||||
|
|
||||||
|
|
||||||
class ReparamLargeKernelConv(nn.Module):
|
class ReparamLargeKernelConv(nn.Module):
|
||||||
@ -32,10 +278,10 @@ class ReparamLargeKernelConv(nn.Module):
|
|||||||
out_chs: int,
|
out_chs: int,
|
||||||
kernel_size: int,
|
kernel_size: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
groups: int,
|
group_size: int,
|
||||||
small_kernel: int,
|
small_kernel: Optional[int] = None,
|
||||||
inference_mode: bool = False,
|
inference_mode: bool = False,
|
||||||
act_layer: nn.Module = nn.GELU(),
|
act_layer: Optional[nn.Module] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Construct a ReparamLargeKernelConv module.
|
"""Construct a ReparamLargeKernelConv module.
|
||||||
|
|
||||||
@ -44,55 +290,63 @@ class ReparamLargeKernelConv(nn.Module):
|
|||||||
out_chs: Number of output channels.
|
out_chs: Number of output channels.
|
||||||
kernel_size: Kernel size of the large kernel conv branch.
|
kernel_size: Kernel size of the large kernel conv branch.
|
||||||
stride: Stride size. Default: 1
|
stride: Stride size. Default: 1
|
||||||
groups: Group number. Default: 1
|
group_size: Group size. Default: 1
|
||||||
small_kernel: Kernel size of small kernel conv branch.
|
small_kernel: Kernel size of small kernel conv branch.
|
||||||
inference_mode: If True, instantiates model in inference mode. Default: ``False``
|
inference_mode: If True, instantiates model in inference mode. Default: ``False``
|
||||||
act_layer: Activation module. Default: ``nn.GELU``
|
act_layer: Activation module. Default: ``nn.GELU``
|
||||||
"""
|
"""
|
||||||
super(ReparamLargeKernelConv, self).__init__()
|
super(ReparamLargeKernelConv, self).__init__()
|
||||||
|
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.groups = groups
|
self.groups = num_groups(group_size, in_chs)
|
||||||
self.in_chs = in_chs
|
self.in_chs = in_chs
|
||||||
self.out_chs = out_chs
|
self.out_chs = out_chs
|
||||||
self.act_layer = act_layer
|
|
||||||
|
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.small_kernel = small_kernel
|
self.small_kernel = small_kernel
|
||||||
self.padding = kernel_size // 2
|
|
||||||
if inference_mode:
|
if inference_mode:
|
||||||
self.lkb_reparam = nn.Conv2d(
|
self.lkb_reparam = create_conv2d(
|
||||||
in_chs=in_chs,
|
in_chs,
|
||||||
out_chs=out_chs,
|
out_chs,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
padding=self.padding,
|
|
||||||
dilation=1,
|
dilation=1,
|
||||||
groups=groups,
|
groups=self.groups,
|
||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.lkb_origin = self._conv_bn(
|
self.lkb_reparam = None
|
||||||
kernel_size=kernel_size, padding=self.padding
|
self.lkb_origin = ConvNormAct(
|
||||||
|
in_chs,
|
||||||
|
out_chs,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=self.stride,
|
||||||
|
groups=self.groups,
|
||||||
|
apply_act=False,
|
||||||
)
|
)
|
||||||
if small_kernel is not None:
|
if small_kernel is not None:
|
||||||
assert (
|
assert (
|
||||||
small_kernel <= kernel_size
|
small_kernel <= kernel_size
|
||||||
), "The kernel size for re-param cannot be larger than the large kernel!"
|
), "The kernel size for re-param cannot be larger than the large kernel!"
|
||||||
self.small_conv = self._conv_bn(
|
self.small_conv = ConvNormAct(
|
||||||
kernel_size=small_kernel, padding=small_kernel // 2
|
in_chs,
|
||||||
|
out_chs,
|
||||||
|
kernel_size=small_kernel,
|
||||||
|
stride=self.stride,
|
||||||
|
groups=self.groups,
|
||||||
|
apply_act=False,
|
||||||
)
|
)
|
||||||
|
# FIXME output of this act was not used in original impl, likely due to bug
|
||||||
|
self.act = act_layer() if act_layer is not None else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Apply forward pass."""
|
"""Apply forward pass."""
|
||||||
if hasattr(self, "lkb_reparam"):
|
if self.lkb_reparam is not None:
|
||||||
out = self.lkb_reparam(x)
|
out = self.lkb_reparam(x)
|
||||||
else:
|
else:
|
||||||
out = self.lkb_origin(x)
|
out = self.lkb_origin(x)
|
||||||
if hasattr(self, "small_conv"):
|
if self.small_conv is not None:
|
||||||
out += self.small_conv(x)
|
out = out + self.small_conv(x)
|
||||||
|
out = self.act(out)
|
||||||
self.act_layer(out)
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
@ -119,12 +373,11 @@ class ReparamLargeKernelConv(nn.Module):
|
|||||||
for inference.
|
for inference.
|
||||||
"""
|
"""
|
||||||
eq_k, eq_b = self.get_kernel_bias()
|
eq_k, eq_b = self.get_kernel_bias()
|
||||||
self.lkb_reparam = nn.Conv2d(
|
self.lkb_reparam = create_conv2d(
|
||||||
in_chs=self.in_chs,
|
self.in_chs,
|
||||||
out_chs=self.out_chs,
|
self.out_chs,
|
||||||
kernel_size=self.kernel_size,
|
kernel_size=self.kernel_size,
|
||||||
stride=self.stride,
|
stride=self.stride,
|
||||||
padding=self.padding,
|
|
||||||
dilation=self.lkb_origin.conv.dilation,
|
dilation=self.lkb_origin.conv.dilation,
|
||||||
groups=self.groups,
|
groups=self.groups,
|
||||||
bias=True,
|
bias=True,
|
||||||
@ -159,35 +412,11 @@ class ReparamLargeKernelConv(nn.Module):
|
|||||||
t = (gamma / std).reshape(-1, 1, 1, 1)
|
t = (gamma / std).reshape(-1, 1, 1, 1)
|
||||||
return kernel * t, beta - running_mean * gamma / std
|
return kernel * t, beta - running_mean * gamma / std
|
||||||
|
|
||||||
def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
|
|
||||||
"""Helper method to construct conv-batchnorm layers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
kernel_size: Size of the convolution kernel.
|
|
||||||
padding: Zero-padding size.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A nn.Sequential Conv-BN module.
|
|
||||||
"""
|
|
||||||
mod_list = nn.Sequential()
|
|
||||||
mod_list.add_module(
|
|
||||||
"conv",
|
|
||||||
nn.Conv2d(
|
|
||||||
in_chs=self.in_chs,
|
|
||||||
out_chs=self.out_chs,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
stride=self.stride,
|
|
||||||
padding=padding,
|
|
||||||
groups=self.groups,
|
|
||||||
bias=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_chs))
|
|
||||||
return mod_list
|
|
||||||
|
|
||||||
|
|
||||||
def convolutional_stem(
|
def convolutional_stem(
|
||||||
in_chs: int, out_chs: int, inference_mode: bool = False
|
in_chs: int,
|
||||||
|
out_chs: int,
|
||||||
|
inference_mode: bool = False
|
||||||
) -> nn.Sequential:
|
) -> nn.Sequential:
|
||||||
"""Build convolutional stem with MobileOne blocks.
|
"""Build convolutional stem with MobileOne blocks.
|
||||||
|
|
||||||
@ -206,8 +435,6 @@ def convolutional_stem(
|
|||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=2,
|
stride=2,
|
||||||
inference_mode=inference_mode,
|
inference_mode=inference_mode,
|
||||||
use_se=False,
|
|
||||||
num_conv_branches=1,
|
|
||||||
),
|
),
|
||||||
MobileOneBlock(
|
MobileOneBlock(
|
||||||
in_chs=out_chs,
|
in_chs=out_chs,
|
||||||
@ -216,8 +443,6 @@ def convolutional_stem(
|
|||||||
stride=2,
|
stride=2,
|
||||||
group_size=1,
|
group_size=1,
|
||||||
inference_mode=inference_mode,
|
inference_mode=inference_mode,
|
||||||
use_se=False,
|
|
||||||
num_conv_branches=1,
|
|
||||||
),
|
),
|
||||||
MobileOneBlock(
|
MobileOneBlock(
|
||||||
in_chs=out_chs,
|
in_chs=out_chs,
|
||||||
@ -225,8 +450,6 @@ def convolutional_stem(
|
|||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
inference_mode=inference_mode,
|
inference_mode=inference_mode,
|
||||||
use_se=False,
|
|
||||||
num_conv_branches=1,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -237,6 +460,7 @@ class Attention(nn.Module):
|
|||||||
Source modified from:
|
Source modified from:
|
||||||
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||||
"""
|
"""
|
||||||
|
fused_attn: torch.jit.Final[bool]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -259,7 +483,8 @@ class Attention(nn.Module):
|
|||||||
assert dim % head_dim == 0, "dim should be divisible by head_dim"
|
assert dim % head_dim == 0, "dim should be divisible by head_dim"
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.num_heads = dim // head_dim
|
self.num_heads = dim // head_dim
|
||||||
self.scale = head_dim**-0.5
|
self.scale = head_dim ** -0.5
|
||||||
|
self.fused_attn = use_fused_attn()
|
||||||
|
|
||||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
self.attn_drop = nn.Dropout(attn_drop)
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
@ -267,11 +492,9 @@ class Attention(nn.Module):
|
|||||||
self.proj_drop = nn.Dropout(proj_drop)
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
shape = x.shape
|
B, C, H, W = x.shape
|
||||||
B, C, H, W = shape
|
|
||||||
N = H * W
|
N = H * W
|
||||||
if len(shape) == 4:
|
x = x.flatten(2).transpose(-2, -1) # (B, N, C)
|
||||||
x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C)
|
|
||||||
qkv = (
|
qkv = (
|
||||||
self.qkv(x)
|
self.qkv(x)
|
||||||
.reshape(B, N, 3, self.num_heads, self.head_dim)
|
.reshape(B, N, 3, self.num_heads, self.head_dim)
|
||||||
@ -279,16 +502,22 @@ class Attention(nn.Module):
|
|||||||
)
|
)
|
||||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||||
|
|
||||||
# trick here to make q@k.t more stable
|
if self.fused_attn:
|
||||||
attn = (q * self.scale) @ k.transpose(-2, -1)
|
x = torch.nn.functional.scaled_dot_product_attention(
|
||||||
attn = attn.softmax(dim=-1)
|
q, k, v,
|
||||||
attn = self.attn_drop(attn)
|
dropout_p=self.attn_drop.p,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q = q * self.scale
|
||||||
|
attn = q @ k.transpose(-2, -1)
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
x = attn @ v
|
||||||
|
|
||||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
x = x.transpose(1, 2).reshape(B, N, C)
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
x = self.proj_drop(x)
|
x = self.proj_drop(x)
|
||||||
if len(shape) == 4:
|
x = x.transpose(-2, -1).reshape(B, C, H, W)
|
||||||
x = x.transpose(-2, -1).reshape(B, C, H, W)
|
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -314,32 +543,25 @@ class PatchEmbed(nn.Module):
|
|||||||
inference_mode: Flag to instantiate model in inference mode. Default: ``False``
|
inference_mode: Flag to instantiate model in inference mode. Default: ``False``
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
block = list()
|
self.proj = nn.Sequential(
|
||||||
block.append(
|
|
||||||
ReparamLargeKernelConv(
|
ReparamLargeKernelConv(
|
||||||
in_chs=in_chs,
|
in_chs=in_chs,
|
||||||
out_chs=embed_dim,
|
out_chs=embed_dim,
|
||||||
kernel_size=patch_size,
|
kernel_size=patch_size,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
groups=in_chs,
|
group_size=1,
|
||||||
small_kernel=3,
|
small_kernel=3,
|
||||||
inference_mode=inference_mode,
|
inference_mode=inference_mode,
|
||||||
)
|
act_layer=None, # activation was not used in original impl
|
||||||
)
|
),
|
||||||
block.append(
|
|
||||||
MobileOneBlock(
|
MobileOneBlock(
|
||||||
in_chs=embed_dim,
|
in_chs=embed_dim,
|
||||||
out_chs=embed_dim,
|
out_chs=embed_dim,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0,
|
|
||||||
groups=1,
|
|
||||||
inference_mode=inference_mode,
|
inference_mode=inference_mode,
|
||||||
use_se=False,
|
|
||||||
num_conv_branches=1,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.proj = nn.Sequential(*block)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
@ -377,8 +599,8 @@ class RepMixer(nn.Module):
|
|||||||
|
|
||||||
if inference_mode:
|
if inference_mode:
|
||||||
self.reparam_conv = nn.Conv2d(
|
self.reparam_conv = nn.Conv2d(
|
||||||
in_chs=self.dim,
|
self.dim,
|
||||||
out_chs=self.dim,
|
self.dim,
|
||||||
kernel_size=self.kernel_size,
|
kernel_size=self.kernel_size,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=self.kernel_size // 2,
|
padding=self.kernel_size // 2,
|
||||||
@ -386,6 +608,7 @@ class RepMixer(nn.Module):
|
|||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
self.reparam_conv = None
|
||||||
self.norm = MobileOneBlock(
|
self.norm = MobileOneBlock(
|
||||||
dim,
|
dim,
|
||||||
dim,
|
dim,
|
||||||
@ -404,12 +627,10 @@ class RepMixer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.use_layer_scale = use_layer_scale
|
self.use_layer_scale = use_layer_scale
|
||||||
if use_layer_scale:
|
if use_layer_scale:
|
||||||
self.layer_scale = nn.Parameter(
|
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
|
||||||
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
if hasattr(self, "reparam_conv"):
|
if self.reparam_conv is not None:
|
||||||
x = self.reparam_conv(x)
|
x = self.reparam_conv(x)
|
||||||
return x
|
return x
|
||||||
else:
|
else:
|
||||||
@ -444,12 +665,11 @@ class RepMixer(nn.Module):
|
|||||||
)
|
)
|
||||||
b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
|
b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
|
||||||
|
|
||||||
self.reparam_conv = nn.Conv2d(
|
self.reparam_conv = create_conv2d(
|
||||||
in_chs=self.dim,
|
self.dim,
|
||||||
out_chs=self.dim,
|
self.dim,
|
||||||
kernel_size=self.kernel_size,
|
kernel_size=self.kernel_size,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=self.kernel_size // 2,
|
|
||||||
groups=self.dim,
|
groups=self.dim,
|
||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
@ -487,19 +707,26 @@ class ConvMlp(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
out_chs = out_chs or in_chs
|
out_chs = out_chs or in_chs
|
||||||
hidden_channels = hidden_channels or in_chs
|
hidden_channels = hidden_channels or in_chs
|
||||||
self.conv = nn.Sequential()
|
# self.conv = nn.Sequential()
|
||||||
self.conv.add_module(
|
# self.conv.add_module(
|
||||||
"conv",
|
# "conv",
|
||||||
nn.Conv2d(
|
# nn.Conv2d(
|
||||||
in_chs=in_chs,
|
# in_chs,
|
||||||
out_chs=out_chs,
|
# out_chs,
|
||||||
kernel_size=7,
|
# kernel_size=7,
|
||||||
padding=3,
|
# padding=3,
|
||||||
groups=in_chs,
|
# groups=in_chs,
|
||||||
bias=False,
|
# bias=False,
|
||||||
),
|
# ),
|
||||||
|
# )
|
||||||
|
# self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs))
|
||||||
|
self.conv = ConvNormAct(
|
||||||
|
in_chs,
|
||||||
|
out_chs,
|
||||||
|
kernel_size=7,
|
||||||
|
groups=in_chs,
|
||||||
|
apply_act=False,
|
||||||
)
|
)
|
||||||
self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs))
|
|
||||||
self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1)
|
self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1)
|
||||||
self.act = act_layer()
|
self.act = act_layer()
|
||||||
self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1)
|
self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1)
|
||||||
@ -565,27 +792,28 @@ class RepCPE(nn.Module):
|
|||||||
|
|
||||||
if inference_mode:
|
if inference_mode:
|
||||||
self.reparam_conv = nn.Conv2d(
|
self.reparam_conv = nn.Conv2d(
|
||||||
in_chs=self.in_chs,
|
self.in_chs,
|
||||||
out_chs=self.embed_dim,
|
self.embed_dim,
|
||||||
kernel_size=self.spatial_shape,
|
kernel_size=self.spatial_shape,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=int(self.spatial_shape[0] // 2),
|
padding=spatial_shape[0] // 2,
|
||||||
groups=self.embed_dim,
|
groups=self.embed_dim,
|
||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
self.reparam_conv = None
|
||||||
self.pe = nn.Conv2d(
|
self.pe = nn.Conv2d(
|
||||||
in_chs,
|
in_chs,
|
||||||
embed_dim,
|
embed_dim,
|
||||||
spatial_shape,
|
spatial_shape,
|
||||||
1,
|
1,
|
||||||
int(spatial_shape[0] // 2),
|
int(spatial_shape[0] // 2),
|
||||||
bias=True,
|
|
||||||
groups=embed_dim,
|
groups=embed_dim,
|
||||||
|
bias=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
if hasattr(self, "reparam_conv"):
|
if self.reparam_conv is not None:
|
||||||
x = self.reparam_conv(x)
|
x = self.reparam_conv(x)
|
||||||
return x
|
return x
|
||||||
else:
|
else:
|
||||||
@ -620,8 +848,8 @@ class RepCPE(nn.Module):
|
|||||||
|
|
||||||
# Introduce reparam conv
|
# Introduce reparam conv
|
||||||
self.reparam_conv = nn.Conv2d(
|
self.reparam_conv = nn.Conv2d(
|
||||||
in_chs=self.in_chs,
|
self.in_chs,
|
||||||
out_chs=self.embed_dim,
|
self.embed_dim,
|
||||||
kernel_size=self.spatial_shape,
|
kernel_size=self.spatial_shape,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=int(self.spatial_shape[0] // 2),
|
padding=int(self.spatial_shape[0] // 2),
|
||||||
@ -682,10 +910,9 @@ class RepMixerBlock(nn.Module):
|
|||||||
assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
|
assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
|
||||||
mlp_ratio
|
mlp_ratio
|
||||||
)
|
)
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
||||||
self.convffn = ConvMlp(
|
self.convffn = ConvMlp(
|
||||||
in_chs=dim,
|
in_chs=dim,
|
||||||
hidden_channels=mlp_hidden_dim,
|
hidden_channels=int(dim * mlp_ratio),
|
||||||
act_layer=act_layer,
|
act_layer=act_layer,
|
||||||
drop=drop,
|
drop=drop,
|
||||||
)
|
)
|
||||||
@ -696,9 +923,7 @@ class RepMixerBlock(nn.Module):
|
|||||||
# Layer Scale
|
# Layer Scale
|
||||||
self.use_layer_scale = use_layer_scale
|
self.use_layer_scale = use_layer_scale
|
||||||
if use_layer_scale:
|
if use_layer_scale:
|
||||||
self.layer_scale = nn.Parameter(
|
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
|
||||||
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.use_layer_scale:
|
if self.use_layer_scale:
|
||||||
@ -763,12 +988,8 @@ class AttentionBlock(nn.Module):
|
|||||||
# Layer Scale
|
# Layer Scale
|
||||||
self.use_layer_scale = use_layer_scale
|
self.use_layer_scale = use_layer_scale
|
||||||
if use_layer_scale:
|
if use_layer_scale:
|
||||||
self.layer_scale_1 = nn.Parameter(
|
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
|
||||||
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
|
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
|
||||||
)
|
|
||||||
self.layer_scale_2 = nn.Parameter(
|
|
||||||
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.use_layer_scale:
|
if self.use_layer_scale:
|
||||||
@ -823,32 +1044,28 @@ def basic_blocks(
|
|||||||
/ (sum(num_blocks) - 1)
|
/ (sum(num_blocks) - 1)
|
||||||
)
|
)
|
||||||
if token_mixer_type == "repmixer":
|
if token_mixer_type == "repmixer":
|
||||||
blocks.append(
|
blocks.append(RepMixerBlock(
|
||||||
RepMixerBlock(
|
dim,
|
||||||
dim,
|
kernel_size=kernel_size,
|
||||||
kernel_size=kernel_size,
|
mlp_ratio=mlp_ratio,
|
||||||
mlp_ratio=mlp_ratio,
|
act_layer=act_layer,
|
||||||
act_layer=act_layer,
|
drop=drop_rate,
|
||||||
drop=drop_rate,
|
drop_path=block_dpr,
|
||||||
drop_path=block_dpr,
|
use_layer_scale=use_layer_scale,
|
||||||
use_layer_scale=use_layer_scale,
|
layer_scale_init_value=layer_scale_init_value,
|
||||||
layer_scale_init_value=layer_scale_init_value,
|
inference_mode=inference_mode,
|
||||||
inference_mode=inference_mode,
|
))
|
||||||
)
|
|
||||||
)
|
|
||||||
elif token_mixer_type == "attention":
|
elif token_mixer_type == "attention":
|
||||||
blocks.append(
|
blocks.append(AttentionBlock(
|
||||||
AttentionBlock(
|
dim,
|
||||||
dim,
|
mlp_ratio=mlp_ratio,
|
||||||
mlp_ratio=mlp_ratio,
|
act_layer=act_layer,
|
||||||
act_layer=act_layer,
|
norm_layer=norm_layer,
|
||||||
norm_layer=norm_layer,
|
drop=drop_rate,
|
||||||
drop=drop_rate,
|
drop_path=block_dpr,
|
||||||
drop_path=block_dpr,
|
use_layer_scale=use_layer_scale,
|
||||||
use_layer_scale=use_layer_scale,
|
layer_scale_init_value=layer_scale_init_value,
|
||||||
layer_scale_init_value=layer_scale_init_value,
|
))
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Token mixer type: {} not supported".format(token_mixer_type)
|
"Token mixer type: {} not supported".format(token_mixer_type)
|
||||||
@ -932,15 +1149,13 @@ class FastVit(nn.Module):
|
|||||||
|
|
||||||
# Patch merging/downsampling between stages.
|
# Patch merging/downsampling between stages.
|
||||||
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
|
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
|
||||||
network.append(
|
network += [PatchEmbed(
|
||||||
PatchEmbed(
|
patch_size=down_patch_size,
|
||||||
patch_size=down_patch_size,
|
stride=down_stride,
|
||||||
stride=down_stride,
|
in_chs=embed_dims[i],
|
||||||
in_chs=embed_dims[i],
|
embed_dim=embed_dims[i + 1],
|
||||||
embed_dim=embed_dims[i + 1],
|
inference_mode=inference_mode,
|
||||||
inference_mode=inference_mode,
|
)]
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.network = nn.ModuleList(network)
|
self.network = nn.ModuleList(network)
|
||||||
|
|
||||||
@ -1054,6 +1269,8 @@ default_cfgs = {
|
|||||||
"fastvit_t": _cfg(crop_pct=0.9),
|
"fastvit_t": _cfg(crop_pct=0.9),
|
||||||
"fastvit_s": _cfg(crop_pct=0.9),
|
"fastvit_s": _cfg(crop_pct=0.9),
|
||||||
"fastvit_m": _cfg(crop_pct=0.95),
|
"fastvit_m": _cfg(crop_pct=0.95),
|
||||||
|
'fastvit_t8': _cfg(
|
||||||
|
url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user