More work on FastViT, use own impl of MobileOne, validation working with remapped weights, more refactoring TODO

Ross Wightman 2023-08-22 11:19:57 -07:00 committed by Ross Wightman
parent c7a20cec13
commit 8474508d07


@@ -12,9 +12,255 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_
from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn
from ._registry import register_model
from .byobnet import MobileOneBlock
def num_groups(group_size, channels):
if not group_size: # 0 or None
return 1 # normal conv with 1 group
else:
# NOTE group_size == 1 -> depthwise conv
assert channels % group_size == 0
return channels // group_size
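For reference, a quick sketch of the group_size convention used throughout this refactor (illustrative values, not from the commit):

assert num_groups(0, 64) == 1    # 0/None -> standard conv, one group
assert num_groups(1, 64) == 64   # 1 -> depthwise, one channel per group
assert num_groups(16, 64) == 4   # grouped conv, 16 channels per group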
class MobileOneBlock(nn.Module):
"""MobileOne building block.
This block has a multi-branched architecture at train-time
and plain-CNN style architecture at inference time
For more details, please refer to our paper:
`MobileOne: An Improved One millisecond Mobile Backbone` -
https://arxiv.org/pdf/2206.04040.pdf
"""
def __init__(
self,
in_chs: int,
out_chs: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
group_size: int = 0,
inference_mode: bool = False,
use_se: bool = False,
use_act: bool = True,
use_scale_branch: bool = True,
num_conv_branches: int = 1,
act_layer: nn.Module = nn.GELU,
) -> None:
"""Construct a MobileOneBlock module.
Args:
in_chs: Number of channels in the input.
out_chs: Number of channels produced by the block.
kernel_size: Size of the convolution kernel.
stride: Stride size.
dilation: Kernel dilation factor.
group_size: Group size; 1 for depthwise conv, 0 or None for a standard conv.
inference_mode: If True, instantiates model in inference mode.
use_se: Whether to use SE-ReLU activations.
use_act: Whether to use activation. Default: ``True``
use_scale_branch: Whether to use scale branch. Default: ``True``
num_conv_branches: Number of linear conv branches.
"""
super(MobileOneBlock, self).__init__()
self.inference_mode = inference_mode
self.groups = num_groups(group_size, in_chs)
self.stride = stride
self.dilation = dilation
self.kernel_size = kernel_size
self.in_chs = in_chs
self.out_chs = out_chs
self.num_conv_branches = num_conv_branches
# Check if SE-ReLU is requested
self.se = SqueezeExcite(out_chs) if use_se else nn.Identity()
if inference_mode:
self.reparam_conv = create_conv2d(
in_chs,
out_chs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
groups=self.groups,
bias=True,
)
else:
# Re-parameterizable skip connection
self.reparam_conv = None
self.rbr_skip = (
nn.BatchNorm2d(num_features=in_chs)
if out_chs == in_chs and stride == 1
else None
)
# Re-parameterizable conv branches
if num_conv_branches > 0:
rbr_conv = list()
for _ in range(self.num_conv_branches):
rbr_conv.append(ConvNormAct(
self.in_chs,
self.out_chs,
kernel_size=kernel_size,
stride=self.stride,
groups=self.groups,
apply_act=False,
))
self.rbr_conv = nn.ModuleList(rbr_conv)
else:
self.rbr_conv = None
# Re-parameterizable scale branch
self.rbr_scale = None
if kernel_size > 1 and use_scale_branch:
self.rbr_scale = ConvNormAct(
self.in_chs,
self.out_chs,
kernel_size=1,
stride=self.stride,
groups=self.groups,
apply_act=False
)
self.act = act_layer() if use_act else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply forward pass."""
# Inference mode forward pass.
if self.reparam_conv is not None:
return self.act(self.se(self.reparam_conv(x)))
# Multi-branched train-time forward pass.
# Skip branch output
identity_out = 0
if self.rbr_skip is not None:
identity_out = self.rbr_skip(x)
# Scale branch output
scale_out = 0
if self.rbr_scale is not None:
scale_out = self.rbr_scale(x)
# Other branches
out = scale_out + identity_out
if self.rbr_conv is not None:
for ix in range(self.num_conv_branches):
out += self.rbr_conv[ix](x)
return self.act(self.se(out))
def reparameterize(self):
"""Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
architecture used at training time to obtain a plain CNN-like structure
for inference.
"""
if self.inference_mode:
return
kernel, bias = self._get_kernel_bias()
self.reparam_conv = create_conv2d(
in_channels=self.in_chs,
out_channels=self.out_chs,
kernel_size=self.kernel_size,
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
bias=True,
)
self.reparam_conv.weight.data = kernel
self.reparam_conv.bias.data = bias
# Delete unused branches
for para in self.parameters():
para.detach_()
self.__delattr__("rbr_conv")
self.__delattr__("rbr_scale")
if hasattr(self, "rbr_skip"):
self.__delattr__("rbr_skip")
self.inference_mode = True
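A minimal usage sketch of the train/inference equivalence this method provides (shapes and tolerances assumed; relies on the ConvNormAct fusion below):

import torch
block = MobileOneBlock(32, 32, kernel_size=3, group_size=1, num_conv_branches=2)
block.eval()  # compare in eval mode so BN uses its running stats
x = torch.randn(1, 32, 56, 56)
y_branched = block(x)
block.reparameterize()  # folds skip + scale + conv branches into one conv
torch.testing.assert_close(y_branched, block(x), rtol=1e-4, atol=1e-4)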
def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Method to obtain re-parameterized kernel and bias.
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
Returns:
Tuple of (kernel, bias) after fusing branches.
"""
# get weights and bias of scale branch
kernel_scale = 0
bias_scale = 0
if self.rbr_scale is not None:
kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
# Pad scale branch kernel to match conv branch kernel size.
pad = self.kernel_size // 2
kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
# get weights and bias of skip branch
kernel_identity = 0
bias_identity = 0
if self.rbr_skip is not None:
kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
# get weights and bias of conv branches
kernel_conv = 0
bias_conv = 0
if self.rbr_conv is not None:
for ix in range(self.num_conv_branches):
_kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
kernel_conv += _kernel
bias_conv += _bias
kernel_final = kernel_conv + kernel_scale + kernel_identity
bias_final = bias_conv + bias_scale + bias_identity
return kernel_final, bias_final
def _fuse_bn_tensor(
self, branch: Union[ConvNormAct, nn.BatchNorm2d]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Method to fuse batchnorm layer with preceding conv layer.
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
Args:
branch: ConvNormAct (conv + BN) branch or a plain BatchNorm2d to be fused.
Returns:
Tuple of (kernel, bias) after fusing batchnorm.
"""
if isinstance(branch, ConvNormAct):
kernel = branch.conv.weight
running_mean = branch.bn.running_mean
running_var = branch.bn.running_var
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn.eps
else:
assert isinstance(branch, nn.BatchNorm2d)
if not hasattr(self, "id_tensor"):
input_dim = self.in_chs // self.groups
kernel_value = torch.zeros(
(self.in_chs, input_dim, self.kernel_size, self.kernel_size),
dtype=branch.weight.dtype,
device=branch.weight.device,
)
for i in range(self.in_chs):
kernel_value[
i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
] = 1
self.id_tensor = kernel_value
kernel = self.id_tensor
running_mean = branch.running_mean
running_var = branch.running_var
gamma = branch.weight
beta = branch.bias
eps = branch.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std
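For reference, the standard conv+BN folding identity implemented above, checked standalone (a sketch, not part of the commit):

import torch
import torch.nn as nn
conv = nn.Conv2d(8, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()  # eval mode: fold against running stats
x = torch.randn(2, 8, 16, 16)
std = (bn.running_var + bn.eps).sqrt()
w = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
b = bn.bias - bn.running_mean * bn.weight / std
fused = nn.functional.conv2d(x, w, b, padding=1)
torch.testing.assert_close(bn(conv(x)), fused, rtol=1e-4, atol=1e-4)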
class ReparamLargeKernelConv(nn.Module):
@@ -32,10 +278,10 @@ class ReparamLargeKernelConv(nn.Module):
out_chs: int,
kernel_size: int,
stride: int,
groups: int,
small_kernel: int,
group_size: int,
small_kernel: Optional[int] = None,
inference_mode: bool = False,
act_layer: nn.Module = nn.GELU(),
act_layer: Optional[nn.Module] = None,
) -> None:
"""Construct a ReparamLargeKernelConv module.
@@ -44,55 +290,63 @@ class ReparamLargeKernelConv(nn.Module):
out_chs: Number of output channels.
kernel_size: Kernel size of the large kernel conv branch.
stride: Stride size. Default: 1
groups: Group number. Default: 1
group_size: Group size; 1 for depthwise conv.
small_kernel: Kernel size of small kernel conv branch.
inference_mode: If True, instantiates model in inference mode. Default: ``False``
act_layer: Optional activation module. Default: ``None``
"""
super(ReparamLargeKernelConv, self).__init__()
self.stride = stride
self.groups = groups
self.groups = num_groups(group_size, in_chs)
self.in_chs = in_chs
self.out_chs = out_chs
self.act_layer = act_layer
self.kernel_size = kernel_size
self.small_kernel = small_kernel
self.padding = kernel_size // 2
if inference_mode:
self.lkb_reparam = nn.Conv2d(
in_chs=in_chs,
out_chs=out_chs,
self.lkb_reparam = create_conv2d(
in_chs,
out_chs,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
dilation=1,
groups=groups,
groups=self.groups,
bias=True,
)
else:
self.lkb_origin = self._conv_bn(
kernel_size=kernel_size, padding=self.padding
self.lkb_reparam = None
self.lkb_origin = ConvNormAct(
in_chs,
out_chs,
kernel_size=kernel_size,
stride=self.stride,
groups=self.groups,
apply_act=False,
)
if small_kernel is not None:
assert (
small_kernel <= kernel_size
), "The kernel size for re-param cannot be larger than the large kernel!"
self.small_conv = self._conv_bn(
kernel_size=small_kernel, padding=small_kernel // 2
self.small_conv = ConvNormAct(
in_chs,
out_chs,
kernel_size=small_kernel,
stride=self.stride,
groups=self.groups,
apply_act=False,
)
# FIXME output of this act was not used in original impl, likely due to bug
self.act = act_layer() if act_layer is not None else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply forward pass."""
if hasattr(self, "lkb_reparam"):
if self.lkb_reparam is not None:
out = self.lkb_reparam(x)
else:
out = self.lkb_origin(x)
if hasattr(self, "small_conv"):
out += self.small_conv(x)
self.act_layer(out)
if self.small_conv is not None:
out = out + self.small_conv(x)
out = self.act(out)
return out
def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -119,12 +373,11 @@ class ReparamLargeKernelConv(nn.Module):
for inference.
"""
eq_k, eq_b = self.get_kernel_bias()
self.lkb_reparam = nn.Conv2d(
in_chs=self.in_chs,
out_chs=self.out_chs,
self.lkb_reparam = create_conv2d(
self.in_chs,
self.out_chs,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.lkb_origin.conv.dilation,
groups=self.groups,
bias=True,
@@ -159,35 +412,11 @@ class ReparamLargeKernelConv(nn.Module):
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std
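When the small-kernel branch is fused, its kernel is zero-padded out to the large kernel size before the two are summed; the padding step, sketched with assumed shapes:

import torch
large_k, small_k = 7, 3
w_small = torch.randn(64, 1, small_k, small_k)  # fused depthwise small kernel
pad = (large_k - small_k) // 2
w_small = torch.nn.functional.pad(w_small, [pad] * 4)  # -> (64, 1, 7, 7), addable to the large kernel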
def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
"""Helper method to construct conv-batchnorm layers.
Args:
kernel_size: Size of the convolution kernel.
padding: Zero-padding size.
Returns:
A nn.Sequential Conv-BN module.
"""
mod_list = nn.Sequential()
mod_list.add_module(
"conv",
nn.Conv2d(
in_chs=self.in_chs,
out_chs=self.out_chs,
kernel_size=kernel_size,
stride=self.stride,
padding=padding,
groups=self.groups,
bias=False,
),
)
mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_chs))
return mod_list
def convolutional_stem(
in_chs: int, out_chs: int, inference_mode: bool = False
in_chs: int,
out_chs: int,
inference_mode: bool = False
) -> nn.Sequential:
"""Build convolutional stem with MobileOne blocks.
@@ -206,8 +435,6 @@ def convolutional_stem(
kernel_size=3,
stride=2,
inference_mode=inference_mode,
use_se=False,
num_conv_branches=1,
),
MobileOneBlock(
in_chs=out_chs,
@@ -216,8 +443,6 @@
stride=2,
group_size=1,
inference_mode=inference_mode,
use_se=False,
num_conv_branches=1,
),
MobileOneBlock(
in_chs=out_chs,
@@ -225,8 +450,6 @@
kernel_size=1,
stride=1,
inference_mode=inference_mode,
use_se=False,
num_conv_branches=1,
),
)
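A quick shape check of the stem (a sketch; 224x224 input assumed):

import torch
stem = convolutional_stem(3, 64)
print(stem(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 56, 56]): two stride-2 blocks -> /4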
@@ -237,6 +460,7 @@ class Attention(nn.Module):
Source modified from:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
fused_attn: torch.jit.Final[bool]
def __init__(
self,
@@ -259,7 +483,8 @@
assert dim % head_dim == 0, "dim should be divisible by head_dim"
self.head_dim = head_dim
self.num_heads = dim // head_dim
self.scale = head_dim**-0.5
self.scale = head_dim ** -0.5
self.fused_attn = use_fused_attn()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
@@ -267,11 +492,9 @@
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
shape = x.shape
B, C, H, W = shape
B, C, H, W = x.shape
N = H * W
if len(shape) == 4:
x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C)
x = x.flatten(2).transpose(-2, -1) # (B, N, C)
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, self.head_dim)
@@ -279,16 +502,22 @@
)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
# trick here to make q@k.t more stable
attn = (q * self.scale) @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
if self.fused_attn:
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
if len(shape) == 4:
x = x.transpose(-2, -1).reshape(B, C, H, W)
x = x.transpose(-2, -1).reshape(B, C, H, W)
return x
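The fused branch relies on F.scaled_dot_product_attention matching the manual path under its default 1/sqrt(head_dim) scaling; a standalone check, sketched:

import torch
q = torch.randn(1, 4, 49, 32)  # (B, num_heads, N, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
manual = ((q * 32 ** -0.5) @ k.transpose(-2, -1)).softmax(dim=-1) @ v
fused = torch.nn.functional.scaled_dot_product_attention(q, k, v)
torch.testing.assert_close(manual, fused, rtol=1e-4, atol=1e-4)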
@@ -314,32 +543,25 @@ class PatchEmbed(nn.Module):
inference_mode: Flag to instantiate model in inference mode. Default: ``False``
"""
super().__init__()
block = list()
block.append(
self.proj = nn.Sequential(
ReparamLargeKernelConv(
in_chs=in_chs,
out_chs=embed_dim,
kernel_size=patch_size,
stride=stride,
groups=in_chs,
group_size=1,
small_kernel=3,
inference_mode=inference_mode,
)
)
block.append(
act_layer=None, # activation was not used in original impl
),
MobileOneBlock(
in_chs=embed_dim,
out_chs=embed_dim,
kernel_size=1,
stride=1,
padding=0,
groups=1,
inference_mode=inference_mode,
use_se=False,
num_conv_branches=1,
)
)
self.proj = nn.Sequential(*block)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
@@ -377,8 +599,8 @@ class RepMixer(nn.Module):
if inference_mode:
self.reparam_conv = nn.Conv2d(
in_chs=self.dim,
out_chs=self.dim,
self.dim,
self.dim,
kernel_size=self.kernel_size,
stride=1,
padding=self.kernel_size // 2,
@@ -386,6 +608,7 @@ class RepMixer(nn.Module):
bias=True,
)
else:
self.reparam_conv = None
self.norm = MobileOneBlock(
dim,
dim,
@@ -404,12 +627,10 @@ class RepMixer(nn.Module):
)
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale = nn.Parameter(
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
)
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(self, "reparam_conv"):
if self.reparam_conv is not None:
x = self.reparam_conv(x)
return x
else:
@@ -444,12 +665,11 @@ class RepMixer(nn.Module):
)
b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
self.reparam_conv = nn.Conv2d(
in_chs=self.dim,
out_chs=self.dim,
self.reparam_conv = create_conv2d(
self.dim,
self.dim,
kernel_size=self.kernel_size,
stride=1,
padding=self.kernel_size // 2,
groups=self.dim,
bias=True,
)
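For reference, the identity being folded here (no-layer-scale case) is y = x + mixer(x) - norm(x); for two depthwise convs this collapses into a single conv whose kernel is the dirac (identity) kernel plus the kernel difference. A standalone sketch with stand-in convs:

import torch
import torch.nn as nn
C = 8
f = nn.Conv2d(C, C, 3, padding=1, groups=C)  # stands in for the fused mixer
g = nn.Conv2d(C, C, 3, padding=1, groups=C)  # stands in for the fused norm
w_id = torch.zeros(C, 1, 3, 3)
w_id[:, 0, 1, 1] = 1.0  # dirac kernel: conv(x, w_id) == x
w, b = w_id + f.weight - g.weight, f.bias - g.bias
x = torch.randn(1, C, 16, 16)
fused = nn.functional.conv2d(x, w, b, padding=1, groups=C)
torch.testing.assert_close(x + f(x) - g(x), fused, rtol=1e-4, atol=1e-4)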
@@ -487,19 +707,26 @@ class ConvMlp(nn.Module):
super().__init__()
out_chs = out_chs or in_chs
hidden_channels = hidden_channels or in_chs
self.conv = nn.Sequential()
self.conv.add_module(
"conv",
nn.Conv2d(
in_chs=in_chs,
out_chs=out_chs,
kernel_size=7,
padding=3,
groups=in_chs,
bias=False,
),
# self.conv = nn.Sequential()
# self.conv.add_module(
# "conv",
# nn.Conv2d(
# in_chs,
# out_chs,
# kernel_size=7,
# padding=3,
# groups=in_chs,
# bias=False,
# ),
# )
# self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs))
self.conv = ConvNormAct(
in_chs,
out_chs,
kernel_size=7,
groups=in_chs,
apply_act=False,
)
self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs))
self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1)
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1)
@@ -565,27 +792,28 @@ class RepCPE(nn.Module):
if inference_mode:
self.reparam_conv = nn.Conv2d(
in_chs=self.in_chs,
out_chs=self.embed_dim,
self.in_chs,
self.embed_dim,
kernel_size=self.spatial_shape,
stride=1,
padding=int(self.spatial_shape[0] // 2),
padding=spatial_shape[0] // 2,
groups=self.embed_dim,
bias=True,
)
else:
self.reparam_conv = None
self.pe = nn.Conv2d(
in_chs,
embed_dim,
spatial_shape,
1,
int(spatial_shape[0] // 2),
bias=True,
groups=embed_dim,
bias=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(self, "reparam_conv"):
if self.reparam_conv is not None:
x = self.reparam_conv(x)
return x
else:
@@ -620,8 +848,8 @@ class RepCPE(nn.Module):
# Introduce reparam conv
self.reparam_conv = nn.Conv2d(
in_chs=self.in_chs,
out_chs=self.embed_dim,
self.in_chs,
self.embed_dim,
kernel_size=self.spatial_shape,
stride=1,
padding=int(self.spatial_shape[0] // 2),
@@ -682,10 +910,9 @@ class RepMixerBlock(nn.Module):
assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
mlp_ratio
)
mlp_hidden_dim = int(dim * mlp_ratio)
self.convffn = ConvMlp(
in_chs=dim,
hidden_channels=mlp_hidden_dim,
hidden_channels=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
@@ -696,9 +923,7 @@ class RepMixerBlock(nn.Module):
# Layer Scale
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale = nn.Parameter(
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
)
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
def forward(self, x):
if self.use_layer_scale:
@@ -763,12 +988,8 @@ class AttentionBlock(nn.Module):
# Layer Scale
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
)
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
def forward(self, x):
if self.use_layer_scale:
@@ -823,32 +1044,28 @@ def basic_blocks(
/ (sum(num_blocks) - 1)
)
if token_mixer_type == "repmixer":
blocks.append(
RepMixerBlock(
dim,
kernel_size=kernel_size,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
drop=drop_rate,
drop_path=block_dpr,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
inference_mode=inference_mode,
)
)
blocks.append(RepMixerBlock(
dim,
kernel_size=kernel_size,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
drop=drop_rate,
drop_path=block_dpr,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
inference_mode=inference_mode,
))
elif token_mixer_type == "attention":
blocks.append(
AttentionBlock(
dim,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
drop=drop_rate,
drop_path=block_dpr,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
)
)
blocks.append(AttentionBlock(
dim,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
drop=drop_rate,
drop_path=block_dpr,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
))
else:
raise ValueError(
"Token mixer type: {} not supported".format(token_mixer_type)
@@ -932,15 +1149,13 @@ class FastVit(nn.Module):
# Patch merging/downsampling between stages.
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
network.append(
PatchEmbed(
patch_size=down_patch_size,
stride=down_stride,
in_chs=embed_dims[i],
embed_dim=embed_dims[i + 1],
inference_mode=inference_mode,
)
)
network += [PatchEmbed(
patch_size=down_patch_size,
stride=down_stride,
in_chs=embed_dims[i],
embed_dim=embed_dims[i + 1],
inference_mode=inference_mode,
)]
self.network = nn.ModuleList(network)
@@ -1054,6 +1269,8 @@ default_cfgs = {
"fastvit_t": _cfg(crop_pct=0.9),
"fastvit_s": _cfg(crop_pct=0.9),
"fastvit_m": _cfg(crop_pct=0.95),
'fastvit_t8': _cfg(
url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar')
}
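With the remapped weights wired into default_cfgs above, validation can go through the usual timm entrypoint (a sketch; assumes the fastvit_t8 entrypoint is registered via @register_model further down):

import timm
model = timm.create_model('fastvit_t8', pretrained=True).eval()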