MobileOne and FastViT weights on HF hub, more code cleanup and tweaks, features_only working. Add reparam flag to validate and benchmark, support reparameterization of all models with fuse(), reparameterize(), or switch_to_deploy() methods on modules

This commit is contained in:
Ross Wightman 2023-08-23 14:16:43 -07:00 committed by Ross Wightman
parent 40dbaafef5
commit 5242ba6edc
8 changed files with 447 additions and 304 deletions

View File

@ -22,7 +22,8 @@ from timm.data import resolve_data_config
from timm.layers import set_fast_norm
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs,\
reparameterize_model
has_apex = False
try:
@ -116,6 +117,8 @@ parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
parser.add_argument('--reparam', default=False, action='store_true',
help='Reparameterize model')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
# codegen (model compilation) options
@ -222,6 +225,7 @@ class BenchmarkRunner:
torchscript=False,
torchcompile=None,
aot_autograd=False,
reparam=False,
precision='float32',
fuser='',
num_warm_iter=10,
@ -252,10 +256,13 @@ class BenchmarkRunner:
drop_block_rate=kwargs.pop('drop_block', None),
**kwargs.pop('model_kwargs', {}),
)
if reparam:
self.model = reparameterize_model(self.model)
self.model.to(
device=self.device,
dtype=self.model_dtype,
memory_format=torch.channels_last if self.channels_last else None)
memory_format=torch.channels_last if self.channels_last else None,
)
self.num_classes = self.model.num_classes
self.param_count = count_params(self.model)
_logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
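For reference, a minimal sketch of what the new --reparam path does (model name and device are illustrative; assumes a timm build containing this commit):

import torch
import timm
from timm.utils import reparameterize_model

# mirrors `benchmark.py --model mobileone_s0 --reparam`
model = timm.create_model('mobileone_s0', pretrained=False)
model = reparameterize_model(model)  # collapse multi-branch blocks into single convs
model = model.to(device='cpu', dtype=torch.float32)
model.eval()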

View File

@ -12,6 +12,10 @@ RepVGG - repvgg_*
Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
Code and weights: https://github.com/DingXiaoH/RepVGG, licensed MIT
MobileOne - mobileone_*
Paper: `MobileOne: An Improved One millisecond Mobile Backbone` - https://arxiv.org/abs/2206.04040
Code and weights: https://github.com/apple/ml-mobileone, licensed MIT
In all cases the models have been modified to fit within the design of ByobNet. I've remapped
the original weights and verified accuracies.
@ -468,8 +472,6 @@ class RepVggBlock(nn.Module):
""" RepVGG Block.
Adapted from impl at https://github.com/DingXiaoH/RepVGG
This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
"""
def __init__(
@ -485,11 +487,24 @@ class RepVggBlock(nn.Module):
layers: LayerFn = None,
drop_block: Callable = None,
drop_path_rate: float = 0.,
inference_mode: bool = False
):
super(RepVggBlock, self).__init__()
self.groups = groups = num_groups(group_size, in_chs)
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
#self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) # FIXME temp for remapping
if inference_mode:
self.reparam_conv = nn.Conv2d(
in_channels=in_chs,
out_channels=out_chs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
groups=groups,
bias=True,
)
else:
self.reparam_conv = None
use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
self.conv_kxk = layers.conv_norm_act(
@ -497,8 +512,9 @@ class RepVggBlock(nn.Module):
stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False,
)
self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.act = layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
@ -511,16 +527,109 @@ class RepVggBlock(nn.Module):
self.attn.reset_parameters()
def forward(self, x):
if self.reparam_conv is not None:
return self.act(self.attn(self.reparam_conv(x)))
if self.identity is None:
x = self.conv_1x1(x) + self.conv_kxk(x)
else:
identity = self.identity(x)
x = self.conv_1x1(x) + self.conv_kxk(x)
x = self.drop_path(x) # not in the paper / official impl, experimental
x = x + identity
x += identity
x = self.attn(x) # no attn in the paper / official impl, experimental
return self.act(x)
def reparameterize(self):
""" Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
architecture used at training time to obtain a plain CNN-like structure
for inference.
"""
if self.reparam_conv is not None:
return
kernel, bias = self._get_kernel_bias()
self.reparam_conv = nn.Conv2d(
in_channels=self.conv_kxk.conv.in_channels,
out_channels=self.conv_kxk.conv.out_channels,
kernel_size=self.conv_kxk.conv.kernel_size,
stride=self.conv_kxk.conv.stride,
padding=self.conv_kxk.conv.padding,
dilation=self.conv_kxk.conv.dilation,
groups=self.conv_kxk.conv.groups,
bias=True,
)
self.reparam_conv.weight.data = kernel
self.reparam_conv.bias.data = bias
# Delete unused branches
for name, para in self.named_parameters():
if 'reparam_conv' in name:
continue
para.detach_()
self.__delattr__('conv_kxk')
self.__delattr__('conv_1x1')
self.__delattr__('identity')
self.__delattr__('drop_path')
def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
""" Method to obtain re-parameterized kernel and bias.
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
"""
# get weights and bias of scale branch
kernel_1x1 = 0
bias_1x1 = 0
if self.conv_1x1 is not None:
kernel_1x1, bias_1x1 = self._fuse_bn_tensor(self.conv_1x1)
# Pad scale branch kernel to match conv branch kernel size.
pad = self.conv_kxk.conv.kernel_size[0] // 2
kernel_1x1 = torch.nn.functional.pad(kernel_1x1, [pad, pad, pad, pad])
# get weights and bias of skip branch
kernel_identity = 0
bias_identity = 0
if self.identity is not None:
kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
# get weights and bias of conv branches
kernel_conv, bias_conv = self._fuse_bn_tensor(self.conv_kxk)
kernel_final = kernel_conv + kernel_1x1 + kernel_identity
bias_final = bias_conv + bias_1x1 + bias_identity
return kernel_final, bias_final
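As a side note, the F.pad call above embeds the 1x1 scale-branch kernel in the center of a kxk kernel; a quick self-contained check (shapes illustrative):

import torch
import torch.nn.functional as F

k1 = torch.randn(8, 8, 1, 1)         # (out_chs, in_chs, 1, 1) scale-branch weight
pad = 3 // 2                          # kernel_size // 2 for a 3x3 conv branch
k3 = F.pad(k1, [pad, pad, pad, pad])  # -> (8, 8, 3, 3), original value at the center
assert torch.equal(k3[:, :, 1, 1], k1[:, :, 0, 0])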
def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
""" Method to fuse batchnorm layer with preceeding conv layer.
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
"""
if isinstance(branch, ConvNormAct):
kernel = branch.conv.weight
running_mean = branch.bn.running_mean
running_var = branch.bn.running_var
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn.eps
else:
assert isinstance(branch, nn.BatchNorm2d)
if not hasattr(self, 'id_tensor'):
in_chs = self.conv_kxk.conv.in_channels
input_dim = in_chs // self.groups
kernel_size = self.conv_kxk.conv.kernel_size
kernel_value = torch.zeros_like(self.conv_kxk.conv.weight)
for i in range(in_chs):
kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
self.id_tensor = kernel_value
kernel = self.id_tensor
running_mean = branch.running_mean
running_var = branch.running_var
gamma = branch.weight
beta = branch.bias
eps = branch.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std
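To see why the fusion above is correct: in eval mode, BN applies gamma * (y - mean) / std + beta per channel, so scaling the conv weight by gamma / std and setting bias = beta - mean * gamma / std reproduces Conv then BN exactly. A small numeric spot-check (standalone, not part of the diff):

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(4, 4, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(4).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)

std = (bn.running_var + bn.eps).sqrt()
t = (bn.weight.data / std).reshape(-1, 1, 1, 1)
fused = nn.Conv2d(4, 4, 3, padding=1, bias=True)
fused.weight.data = conv.weight.data * t
fused.bias.data = bn.bias.data - bn.running_mean * bn.weight.data / std

x = torch.randn(2, 4, 8, 8)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)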
class MobileOneBlock(nn.Module):
""" MobileOne building block.
@ -549,28 +658,11 @@ class MobileOneBlock(nn.Module):
drop_path_rate: float = 0.,
) -> None:
""" Construct a MobileOneBlock module.
:param in_chs: Number of channels in the input.
:param out_chs: Number of channels produced by the block.
:param kernel_size: Size of the convolution kernel.
:param stride: Stride size.
:param dilation: Kernel dilation factor.
:param groups: Group number.
:param inference_mode: If True, instantiates model in inference mode.
:param use_se: Whether to use SE-ReLU activations.
:param num_conv_branches: Number of linear conv branches.
"""
super(MobileOneBlock, self).__init__()
self.stride = stride
self.kernel_size = kernel_size
self.in_channels = in_chs
self.out_channels = out_chs
self.num_conv_branches = num_conv_branches
self.groups = groups = num_groups(group_size, in_chs)
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
# Check if SE-ReLU is requested
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) # FIXME move after remap
if inference_mode:
self.reparam_conv = nn.Conv2d(
@ -602,7 +694,9 @@ class MobileOneBlock(nn.Module):
self.conv_scale = layers.conv_norm_act(
in_chs, out_chs, kernel_size=1,
stride=stride, groups=groups, apply_act=False)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.act = layers.act(inplace=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -623,9 +717,11 @@ class MobileOneBlock(nn.Module):
scale_out = self.conv_scale(x)
# Other branches
out = scale_out + identity_out
out = scale_out
for ck in self.conv_kxk:
out += ck(x)
out = self.drop_path(out)
out += identity_out
return self.act(self.attn(out))
@ -652,18 +748,18 @@ class MobileOneBlock(nn.Module):
self.reparam_conv.bias.data = bias
# Delete unused branches
for para in self.parameters():
for name, para in self.named_parameters():
if 'reparam_conv' in name:
continue
para.detach_()
self.__delattr__('conv_kxk')
self.__delattr__('conv_scale')
if hasattr(self, 'identity'):
self.__delattr__('identity')
self.__delattr__('drop_path')
def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
""" Method to obtain re-parameterized kernel and bias.
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
:return: Tuple of (kernel, bias) after fusing branches.
"""
# get weights and bias of scale branch
kernel_scale = 0
@ -671,7 +767,7 @@ class MobileOneBlock(nn.Module):
if self.conv_scale is not None:
kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
# Pad scale branch kernel to match conv branch kernel size.
pad = self.kernel_size // 2
pad = self.conv_kxk[0].conv.kernel_size[0] // 2
kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
# get weights and bias of skip branch
@ -695,9 +791,6 @@ class MobileOneBlock(nn.Module):
def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
""" Method to fuse batchnorm layer with preceeding conv layer.
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
:param branch:
:return: Tuple of (kernel, bias) after fusing batchnorm.
"""
if isinstance(branch, ConvNormAct):
kernel = branch.conv.weight
@ -709,16 +802,12 @@ class MobileOneBlock(nn.Module):
else:
assert isinstance(branch, nn.BatchNorm2d)
if not hasattr(self, 'id_tensor'):
input_dim = self.in_channels // self.groups
kernel_value = torch.zeros(
(self.in_channels,
input_dim,
self.kernel_size,
self.kernel_size),
dtype=branch.weight.dtype,
device=branch.weight.device)
for i in range(self.in_channels):
kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1
in_chs = self.conv_kxk[0].conv.in_channels
input_dim = in_chs // self.groups
kernel_size = self.conv_kxk[0].conv.kernel_size
kernel_value = torch.zeros_like(self.conv_kxk[0].conv.weight)
for i in range(in_chs):
kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
self.id_tensor = kernel_value
kernel = self.id_tensor
running_mean = branch.running_mean
@ -1226,6 +1315,16 @@ model_cfgs = dict(
num_features=1920,
),
repvgg_a0=ByoModelCfg(
blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(0.75, 0.75, 0.75, 2.5)),
stem_type='rep',
stem_chs=48,
),
repvgg_a1=ByoModelCfg(
blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1, 1, 1, 2.5)),
stem_type='rep',
stem_chs=64,
),
repvgg_a2=ByoModelCfg(
blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)),
stem_type='rep',
@ -1643,15 +1742,11 @@ model_cfgs = dict(
),
)
# FIXME temporary for mobileone remap
from .fastvit import checkpoint_filter_fn
def _create_byobnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ByobNet, variant, pretrained,
model_cfg=model_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True),
**kwargs)
@ -1683,6 +1778,12 @@ default_cfgs = generate_default_cfgs({
'gernet_l.idstcv_in1k': _cfg(hf_hub_id='timm/', input_size=(3, 256, 256), pool_size=(8, 8)),
# RepVGG weights
'repvgg_a0.rvgg_in1k': _cfg(
hf_hub_id='timm/',
first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
'repvgg_a1.rvgg_in1k': _cfg(
hf_hub_id='timm/',
first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
'repvgg_a2.rvgg_in1k': _cfg(
hf_hub_id='timm/',
first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
@ -1707,6 +1808,11 @@ default_cfgs = generate_default_cfgs({
'repvgg_b3g4.rvgg_in1k': _cfg(
hf_hub_id='timm/',
first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
'repvgg_d2se.rvgg_in1k': _cfg(
hf_hub_id='timm/',
first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit',
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0,
),
# experimental ResNet configs
'resnet51q.ra2_in1k': _cfg(
@ -1810,24 +1916,24 @@ default_cfgs = generate_default_cfgs({
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),
'mobileone_s0': _cfg(
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar',
'mobileone_s0.apple_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.875,
),
'mobileone_s1': _cfg(
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar',
'mobileone_s1.apple_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.9,
),
'mobileone_s2': _cfg(
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar',
'mobileone_s2.apple_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.9,
),
'mobileone_s3': _cfg(
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar',
'mobileone_s3.apple_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.9,
),
'mobileone_s4': _cfg(
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar',
'mobileone_s4.apple_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.9,
),
})
@ -1857,6 +1963,22 @@ def gernet_s(pretrained=False, **kwargs) -> ByobNet:
return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)
@register_model
def repvgg_a0(pretrained=False, **kwargs) -> ByobNet:
""" RepVGG-A0
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_a0', pretrained=pretrained, **kwargs)
@register_model
def repvgg_a1(pretrained=False, **kwargs) -> ByobNet:
""" RepVGG-A1
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_a1', pretrained=pretrained, **kwargs)
@register_model
def repvgg_a2(pretrained=False, **kwargs) -> ByobNet:
""" RepVGG-A2

View File

@ -37,8 +37,8 @@ class ConvNorm(torch.nn.Sequential):
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps)**0.5
m = torch.nn.Conv2d(
w.size(1) * self.c.groups, w.size(0), w.shape[2:],
stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
w.size(1) * self.conv.groups, w.size(0), w.shape[2:],
stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m

View File

@ -1,9 +1,8 @@
# FastViT for PyTorch
#
# For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main
#
# Original work is copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import copy
import os
from functools import partial
from typing import List, Tuple, Optional, Union
@ -12,8 +11,10 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn
from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
ClassifierHead
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@ -93,7 +94,7 @@ class MobileOneBlock(nn.Module):
# Re-parameterizable skip connection
self.reparam_conv = None
self.rbr_skip = (
self.identity = (
nn.BatchNorm2d(num_features=in_chs)
if out_chs == in_chs and stride == 1
else None
@ -101,7 +102,7 @@ class MobileOneBlock(nn.Module):
# Re-parameterizable conv branches
if num_conv_branches > 0:
self.rbr_conv = nn.ModuleList([
self.conv_kxk = nn.ModuleList([
ConvNormAct(
self.in_chs,
self.out_chs,
@ -112,12 +113,12 @@ class MobileOneBlock(nn.Module):
) for _ in range(self.num_conv_branches)
])
else:
self.rbr_conv = None
self.conv_kxk = None
# Re-parameterizable scale branch
self.rbr_scale = None
self.conv_scale = None
if kernel_size > 1 and use_scale_branch:
self.rbr_scale = ConvNormAct(
self.conv_scale = ConvNormAct(
self.in_chs,
self.out_chs,
kernel_size=1,
@ -135,20 +136,20 @@ class MobileOneBlock(nn.Module):
return self.act(self.se(self.reparam_conv(x)))
# Multi-branched train-time forward pass.
# Skip branch output
# Identity branch output
identity_out = 0
if self.rbr_skip is not None:
identity_out = self.rbr_skip(x)
if self.identity is not None:
identity_out = self.identity(x)
# Scale branch output
scale_out = 0
if self.rbr_scale is not None:
scale_out = self.rbr_scale(x)
if self.conv_scale is not None:
scale_out = self.conv_scale(x)
# Other branches
# Other kxk conv branches
out = scale_out + identity_out
if self.rbr_conv is not None:
for rc in self.rbr_conv:
if self.conv_kxk is not None:
for rc in self.conv_kxk:
out += rc(x)
return self.act(self.se(out))
@ -176,13 +177,15 @@ class MobileOneBlock(nn.Module):
self.reparam_conv.bias.data = bias
# Delete unused branches
for para in self.parameters():
for name, para in self.named_parameters():
if 'reparam_conv' in name:
continue
para.detach_()
self.__delattr__("rbr_conv")
self.__delattr__("rbr_scale")
if hasattr(self, "rbr_skip"):
self.__delattr__("rbr_skip")
self.__delattr__("conv_kxk")
self.__delattr__("conv_scale")
if hasattr(self, "identity"):
self.__delattr__("identity")
self.inference_mode = True
@ -196,8 +199,8 @@ class MobileOneBlock(nn.Module):
# get weights and bias of scale branch
kernel_scale = 0
bias_scale = 0
if self.rbr_scale is not None:
kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
if self.conv_scale is not None:
kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
# Pad scale branch kernel to match conv branch kernel size.
pad = self.kernel_size // 2
kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
@ -205,15 +208,15 @@ class MobileOneBlock(nn.Module):
# get weights and bias of skip branch
kernel_identity = 0
bias_identity = 0
if self.rbr_skip is not None:
kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
if self.identity is not None:
kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
# get weights and bias of conv branches
kernel_conv = 0
bias_conv = 0
if self.rbr_conv is not None:
if self.conv_kxk is not None:
for ix in range(self.num_conv_branches):
_kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
_kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix])
kernel_conv += _kernel
bias_conv += _bias
@ -233,7 +236,7 @@ class MobileOneBlock(nn.Module):
Returns:
Tuple of (kernel, bias) after fusing batchnorm.
"""
if isinstance(branch, nn.Sequential):
if isinstance(branch, ConvNormAct):
kernel = branch.conv.weight
running_mean = branch.bn.running_mean
running_var = branch.bn.running_var
@ -306,7 +309,7 @@ class ReparamLargeKernelConv(nn.Module):
self.kernel_size = kernel_size
self.small_kernel = small_kernel
if inference_mode:
self.lkb_reparam = create_conv2d(
self.reparam_conv = create_conv2d(
in_chs,
out_chs,
kernel_size=kernel_size,
@ -316,8 +319,8 @@ class ReparamLargeKernelConv(nn.Module):
bias=True,
)
else:
self.lkb_reparam = None
self.lkb_origin = ConvNormAct(
self.reparam_conv = None
self.large_conv = ConvNormAct(
in_chs,
out_chs,
kernel_size=kernel_size,
@ -341,10 +344,10 @@ class ReparamLargeKernelConv(nn.Module):
self.act = act_layer() if act_layer is not None else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.lkb_reparam is not None:
out = self.lkb_reparam(x)
if self.reparam_conv is not None:
out = self.reparam_conv(x)
else:
out = self.lkb_origin(x)
out = self.large_conv(x)
if self.small_conv is not None:
out = out + self.small_conv(x)
out = self.act(out)
@ -357,7 +360,7 @@ class ReparamLargeKernelConv(nn.Module):
Returns:
Tuple of (kernel, bias) after fusing branches.
"""
eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
eq_k, eq_b = self._fuse_bn(self.large_conv.conv, self.large_conv.bn)
if hasattr(self, "small_conv"):
small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
eq_b += small_b
@ -374,19 +377,18 @@ class ReparamLargeKernelConv(nn.Module):
for inference.
"""
eq_k, eq_b = self.get_kernel_bias()
self.lkb_reparam = create_conv2d(
self.reparam_conv = create_conv2d(
self.in_chs,
self.out_chs,
kernel_size=self.kernel_size,
stride=self.stride,
dilation=self.lkb_origin.conv.dilation,
groups=self.groups,
bias=True,
)
self.lkb_reparam.weight.data = eq_k
self.lkb_reparam.bias.data = eq_b
self.__delattr__("lkb_origin")
self.reparam_conv.weight.data = eq_k
self.reparam_conv.bias.data = eq_b
self.__delattr__("large_conv")
if hasattr(self, "small_conv"):
self.__delattr__("small_conv")
@ -532,6 +534,8 @@ class PatchEmbed(nn.Module):
stride: int,
in_chs: int,
embed_dim: int,
act_layer: nn.Module = nn.GELU,
lkc_use_act: bool = False,
inference_mode: bool = False,
) -> None:
"""Build patch embedding layer.
@ -553,13 +557,14 @@ class PatchEmbed(nn.Module):
group_size=1,
small_kernel=3,
inference_mode=inference_mode,
act_layer=None, # activation was not used in original impl
act_layer=act_layer if lkc_use_act else None, # NOTE original weights didn't use this act
),
MobileOneBlock(
in_chs=embed_dim,
out_chs=embed_dim,
kernel_size=1,
stride=1,
act_layer=act_layer,
inference_mode=inference_mode,
)
)
@ -569,6 +574,16 @@ class PatchEmbed(nn.Module):
return x
class LayerScale2d(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim, 1, 1))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class RepMixer(nn.Module):
"""Reparameterizable token mixer.
@ -580,7 +595,6 @@ class RepMixer(nn.Module):
self,
dim,
kernel_size=3,
use_layer_scale=True,
layer_scale_init_value=1e-5,
inference_mode: bool = False,
):
@ -589,7 +603,6 @@ class RepMixer(nn.Module):
Args:
dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
kernel_size: Kernel size for spatial mixing. Default: 3
use_layer_scale: If True, learnable layer scale is used. Default: ``True``
layer_scale_init_value: Initial value for layer scale. Default: 1e-5
inference_mode: If True, instantiates model in inference mode. Default: ``False``
"""
@ -626,19 +639,16 @@ class RepMixer(nn.Module):
group_size=1,
use_act=False,
)
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
if layer_scale_init_value is not None:
self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
else:
self.layer_scale = nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.reparam_conv is not None:
x = self.reparam_conv(x)
return x
else:
if self.use_layer_scale:
x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
else:
x = x + self.mixer(x) - self.norm(x)
x = x + self.layer_scale(self.mixer(x) - self.norm(x))
return x
def reparameterize(self) -> None:
@ -651,11 +661,11 @@ class RepMixer(nn.Module):
self.mixer.reparameterize()
self.norm.reparameterize()
if self.use_layer_scale:
w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
if isinstance(self.layer_scale, LayerScale2d):
w = self.mixer.id_tensor + self.layer_scale.gamma.unsqueeze(-1) * (
self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
)
b = torch.squeeze(self.layer_scale) * (
b = torch.squeeze(self.layer_scale.gamma) * (
self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
)
else:
@ -677,11 +687,12 @@ class RepMixer(nn.Module):
self.reparam_conv.weight.data = w
self.reparam_conv.bias.data = b
for para in self.parameters():
for name, para in self.named_parameters():
if 'reparam_conv' in name:
continue
para.detach_()
self.__delattr__("mixer")
self.__delattr__("norm")
self.__delattr__("layer_scale")
@ -708,19 +719,6 @@ class ConvMlp(nn.Module):
super().__init__()
out_chs = out_chs or in_chs
hidden_channels = hidden_channels or in_chs
# self.conv = nn.Sequential()
# self.conv.add_module(
# "conv",
# nn.Conv2d(
# in_chs,
# out_chs,
# kernel_size=7,
# padding=3,
# groups=in_chs,
# bias=False,
# ),
# )
# self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs))
self.conv = ConvNormAct(
in_chs,
out_chs,
@ -750,7 +748,7 @@ class ConvMlp(nn.Module):
return x
class RepCPE(nn.Module):
class RepConditionalPosEnc(nn.Module):
"""Implementation of conditional positional encoding.
For more details refer to paper:
@ -774,7 +772,7 @@ class RepCPE(nn.Module):
spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
inference_mode: Flag to instantiate block in inference mode. Default: ``False``
"""
super(RepCPE, self).__init__()
super(RepConditionalPosEnc, self).__init__()
if isinstance(spatial_shape, int):
spatial_shape = tuple([spatial_shape] * 2)
assert isinstance(spatial_shape, Tuple), (
@ -803,7 +801,7 @@ class RepCPE(nn.Module):
)
else:
self.reparam_conv = None
self.pe = nn.Conv2d(
self.pos_enc = nn.Conv2d(
self.dim,
self.dim_out,
spatial_shape,
@ -816,9 +814,8 @@ class RepCPE(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.reparam_conv is not None:
x = self.reparam_conv(x)
return x
else:
x = self.pe(x) + x
x = self.pos_enc(x) + x
return x
def reparameterize(self) -> None:
@ -831,8 +828,8 @@ class RepCPE(nn.Module):
self.spatial_shape[0],
self.spatial_shape[1],
),
dtype=self.pe.weight.dtype,
device=self.pe.weight.device,
dtype=self.pos_enc.weight.dtype,
device=self.pos_enc.weight.device,
)
for i in range(self.dim):
kernel_value[
@ -844,8 +841,8 @@ class RepCPE(nn.Module):
id_tensor = kernel_value
# Reparameterize Id tensor and conv
w_final = id_tensor + self.pe.weight
b_final = self.pe.bias
w_final = id_tensor + self.pos_enc.weight
b_final = self.pos_enc.bias
# Introduce reparam conv
self.reparam_conv = nn.Conv2d(
@ -860,9 +857,11 @@ class RepCPE(nn.Module):
self.reparam_conv.weight.data = w_final
self.reparam_conv.bias.data = b_final
for para in self.parameters():
for name, para in self.named_parameters():
if 'reparam_conv' in name:
continue
para.detach_()
self.__delattr__("pe")
self.__delattr__("pos_enc")
class RepMixerBlock(nn.Module):
@ -878,9 +877,8 @@ class RepMixerBlock(nn.Module):
kernel_size: int = 3,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
drop: float = 0.0,
proj_drop: float = 0.0,
drop_path: float = 0.0,
use_layer_scale: bool = True,
layer_scale_init_value: float = 1e-5,
inference_mode: bool = False,
):
@ -891,9 +889,8 @@ class RepMixerBlock(nn.Module):
kernel_size: Kernel size for repmixer. Default: 3
mlp_ratio: MLP expansion ratio. Default: 4.0
act_layer: Activation layer. Default: ``nn.GELU``
drop: Dropout rate. Default: 0.0
proj_drop: Dropout rate. Default: 0.0
drop_path: Drop path rate. Default: 0.0
use_layer_scale: Flag to turn on layer scale. Default: ``True``
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
inference_mode: Flag to instantiate block in inference mode. Default: ``False``
"""
@ -903,36 +900,25 @@ class RepMixerBlock(nn.Module):
self.token_mixer = RepMixer(
dim,
kernel_size=kernel_size,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
inference_mode=inference_mode,
)
assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
mlp_ratio
)
self.convffn = ConvMlp(
self.mlp = ConvMlp(
in_chs=dim,
hidden_channels=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
drop=proj_drop,
)
# Drop Path
if layer_scale_init_value is not None:
self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
else:
self.layer_scale = nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
# Layer Scale
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
def forward(self, x):
if self.use_layer_scale:
x = self.token_mixer(x)
x = x + self.drop_path(self.layer_scale * self.convffn(x))
else:
x = self.token_mixer(x)
x = x + self.drop_path(self.convffn(x))
x = x + self.drop_path(self.layer_scale(self.mlp(x)))
return x
@ -949,9 +935,8 @@ class AttentionBlock(nn.Module):
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.BatchNorm2d,
drop: float = 0.0,
proj_drop: float = 0.0,
drop_path: float = 0.0,
use_layer_scale: bool = True,
layer_scale_init_value: float = 1e-5,
):
"""Build Attention Block.
@ -961,9 +946,8 @@ class AttentionBlock(nn.Module):
mlp_ratio: MLP expansion ratio. Default: 4.0
act_layer: Activation layer. Default: ``nn.GELU``
norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
drop: Dropout rate. Default: 0.0
proj_drop: Dropout rate. Default: 0.0
drop_path: Drop path rate. Default: 0.0
use_layer_scale: Flag to turn on layer scale. Default: ``True``
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
"""
@ -971,34 +955,27 @@ class AttentionBlock(nn.Module):
self.norm = norm_layer(dim)
self.token_mixer = Attention(dim=dim)
if layer_scale_init_value is not None:
self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value)
else:
self.layer_scale_1 = nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
mlp_ratio
)
mlp_hidden_dim = int(dim * mlp_ratio)
self.convffn = ConvMlp(
self.mlp = ConvMlp(
in_chs=dim,
hidden_channels=mlp_hidden_dim,
hidden_channels=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
drop=proj_drop,
)
# Drop path
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
# Layer Scale
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
if layer_scale_init_value is not None:
self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value)
else:
self.layer_scale_2 = nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x):
if self.use_layer_scale:
x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
else:
x = x + self.drop_path(self.token_mixer(self.norm(x)))
x = x + self.drop_path(self.convffn(x))
x = x + self.drop_path1(self.layer_scale_1(self.token_mixer(self.norm(x))))
x = x + self.drop_path2(self.layer_scale_2(self.mlp(x)))
return x
@ -1017,35 +994,38 @@ class FastVitStage(nn.Module):
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.BatchNorm2d,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
use_layer_scale: bool = True,
layer_scale_init_value: float = 1e-5,
layer_scale_init_value: Optional[float] = 1e-5,
lkc_use_act=False,
inference_mode=False,
):
"""FastViT stage.
Args:
dim: Number of embedding dimensions.
num_blocks: List containing number of blocks per stage.
depth: Number of blocks in stage
token_mixer_type: Token mixer type.
kernel_size: Kernel size for repmixer.
mlp_ratio: MLP expansion ratio.
act_layer: Activation layer.
norm_layer: Normalization layer.
drop_rate: Dropout rate.
proj_drop_rate: Dropout rate.
drop_path_rate: Drop path rate.
use_layer_scale: Flag to turn on layer scale regularization.
layer_scale_init_value: Layer scale value at initialization.
inference_mode: Flag to instantiate block in inference mode.
"""
super().__init__()
self.grad_checkpointing = False
if downsample:
self.downsample = PatchEmbed(
patch_size=down_patch_size,
stride=down_stride,
in_chs=dim,
embed_dim=dim_out,
act_layer=act_layer,
lkc_use_act=lkc_use_act,
inference_mode=inference_mode,
)
else:
@ -1065,9 +1045,8 @@ class FastVitStage(nn.Module):
kernel_size=kernel_size,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
drop=drop_rate,
proj_drop=proj_drop_rate,
drop_path=drop_path_rate[block_idx],
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
inference_mode=inference_mode,
))
@ -1077,9 +1056,8 @@ class FastVitStage(nn.Module):
mlp_ratio=mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
drop=drop_rate,
proj_drop=proj_drop_rate,
drop_path=drop_path_rate[block_idx],
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
))
else:
@ -1091,6 +1069,9 @@ class FastVitStage(nn.Module):
def forward(self, x):
x = self.downsample(x)
x = self.pos_emb(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
@ -1116,21 +1097,25 @@ class FastVit(nn.Module):
down_patch_size: int = 7,
down_stride: int = 2,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
use_layer_scale: bool = True,
layer_scale_init_value: float = 1e-5,
fork_feat: bool = False,
cls_ratio: float = 2.0,
global_pool: str = 'avg',
norm_layer: nn.Module = nn.BatchNorm2d,
act_layer: nn.Module = nn.GELU,
lkc_use_act: bool = False,
inference_mode: bool = False,
) -> None:
super().__init__()
self.num_classes = 0 if fork_feat else num_classes
self.fork_feat = fork_feat
self.global_pool = global_pool
self.feature_info = []
# Convolutional stem
self.patch_embed = convolutional_stem(
self.stem = convolutional_stem(
in_chans,
embed_dims[0],
inference_mode,
@ -1138,14 +1123,16 @@ class FastVit(nn.Module):
# Build the main stages of the network architecture
prev_dim = embed_dims[0]
scale = 1
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
network = []
stages = []
for i in range(len(layers)):
downsample = downsamples[i] or prev_dim != embed_dims[i]
stage = FastVitStage(
dim=prev_dim,
dim_out=embed_dims[i],
depth=layers[i],
downsample=downsamples[i] or prev_dim != embed_dims[i],
downsample=downsample,
down_patch_size=down_patch_size,
down_stride=down_stride,
pos_emb_layer=pos_embs[i],
@ -1154,16 +1141,19 @@ class FastVit(nn.Module):
mlp_ratio=mlp_ratios[i],
act_layer=act_layer,
norm_layer=norm_layer,
drop_rate=drop_rate,
proj_drop_rate=drop_rate,
drop_path_rate=dpr[i],
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
lkc_use_act=lkc_use_act,
inference_mode=inference_mode,
)
network.append(stage)
stages.append(stage)
prev_dim = embed_dims[i]
self.network = nn.Sequential(*network)
if downsample:
scale *= 2
self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.num_features = prev_dim
# For segmentation and detection, extract intermediate output
if self.fork_feat:
@ -1181,10 +1171,10 @@ class FastVit(nn.Module):
self.add_module(layer_name, layer)
else:
# Classifier head
self.gap = nn.AdaptiveAvgPool2d(output_size=1)
self.conv_exp = MobileOneBlock(
final_features = int(embed_dims[-1] * cls_ratio)
self.final_conv = MobileOneBlock(
in_chs=embed_dims[-1],
out_chs=int(embed_dims[-1] * cls_ratio),
out_chs=final_features,
kernel_size=3,
stride=1,
group_size=1,
@ -1192,10 +1182,12 @@ class FastVit(nn.Module):
use_se=True,
num_conv_branches=1,
)
self.head = (
nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes)
if num_classes > 0
else nn.Identity()
self.num_features = final_features
self.head = ClassifierHead(
final_features,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
)
self.apply(self._init_weights)
@ -1207,23 +1199,39 @@ class FastVit(nn.Module):
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
@staticmethod
def _scrub_checkpoint(checkpoint, model):
sterile_dict = {}
for k1, v1 in checkpoint.items():
if k1 not in model.state_dict():
continue
if v1.shape == model.state_dict()[k1].shape:
sterile_dict[k1] = v1
return sterile_dict
@torch.jit.ignore
def no_weight_decay(self):
return set()
def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
return x
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem', # stem and embed
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+).downsample', (0,)),
(r'^stages\.(\d+).pos_emb', (0,)),
(r'^stages\.(\d+)\.\w+\.(\d+)', None),
]
)
def forward_tokens(self, x: torch.Tensor) -> torch.Tensor:
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
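Usage sketch (model name illustrative): checkpointing is toggled through the standard timm hook, which now reaches every stage:

import timm

model = timm.create_model('fastvit_t8', pretrained=False)
model.set_grad_checkpointing(True)  # sets grad_checkpointing on each FastVitStage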
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
# input embedding
x = self.stem(x)
outs = []
for idx, block in enumerate(self.network):
for idx, block in enumerate(self.stages):
x = block(x)
if self.fork_feat:
if idx in self.out_indices:
@ -1236,20 +1244,16 @@ class FastVit(nn.Module):
# output only the features of last layer for image classification
return x
def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
x = self.final_conv(x)
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# input embedding
x = self.forward_embeddings(x)
# through backbone
x = self.forward_tokens(x)
x = self.forward_features(x)
if self.fork_feat:
# output features of four stages for dense prediction
return x
# for image classification
x = self.conv_exp(x)
x = self.gap(x)
x = x.view(x.size(0), -1)
cls_out = self.head(x)
return cls_out
x = self.forward_head(x)
return x
def _cfg(url="", **kwargs):
@ -1257,81 +1261,63 @@ def _cfg(url="", **kwargs):
"url": url,
"num_classes": 1000,
"input_size": (3, 256, 256),
"pool_size": None,
"pool_size": (8, 8),
"crop_pct": 0.9,
"interpolation": "bicubic",
"mean": IMAGENET_DEFAULT_MEAN,
"std": IMAGENET_DEFAULT_STD,
"classifier": "head",
"classifier": "head.fc",
**kwargs,
}
default_cfgs = generate_default_cfgs({
"fastvit_t8.apple_in1k": _cfg(
url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar'
),
hf_hub_id='timm/'),
"fastvit_t12.apple_in1k": _cfg(
url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t12.pth.tar'
),
hf_hub_id='timm/'),
"fastvit_s12.apple_in1k": _cfg(
url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_s12.pth.tar'),
hf_hub_id='timm/'),
"fastvit_sa12.apple_in1k": _cfg(
url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa12.pth.tar'),
hf_hub_id='timm/'),
"fastvit_sa24.apple_in1k": _cfg(
url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa24.pth.tar'),
hf_hub_id='timm/'),
"fastvit_sa36.apple_in1k": _cfg(
url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa36.pth.tar'),
hf_hub_id='timm/'),
"fastvit_ma36.apple_in1k": _cfg(
url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_ma36.pth.tar',
hf_hub_id='timm/',
crop_pct=0.95
),
# "fastvit_t8.apple_dist_in1k": _cfg(
# url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t8.pth.tar'
# ),
# "fastvit_t12.apple_dist_in1k": _cfg(
# url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t12.pth.tar'
# ),
#
# "fastvit_s12.apple_dist_in1k": _cfg(
# url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_s12.pth.tar'),
# "fastvit_sa12.apple_dist_in1k": _cfg(
# url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa12.pth.tar'),
# "fastvit_sa24.apple_dist_in1k": _cfg(
# url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa24.pth.tar'),
# "fastvit_sa36.apple_dist_in1k": _cfg(
# url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa36.pth.tar'),
#
# "fastvit_ma36.apple_dist_in1k": _cfg(
# url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_ma36.pth.tar',
# crop_pct=0.95
# ),
"fastvit_t8.apple_dist_in1k": _cfg(
hf_hub_id='timm/'),
"fastvit_t12.apple_dist_in1k": _cfg(
hf_hub_id='timm/'),
"fastvit_s12.apple_dist_in1k": _cfg(
hf_hub_id='timm/',),
"fastvit_sa12.apple_dist_in1k": _cfg(
hf_hub_id='timm/',),
"fastvit_sa24.apple_dist_in1k": _cfg(
hf_hub_id='timm/',),
"fastvit_sa36.apple_dist_in1k": _cfg(
hf_hub_id='timm/',),
"fastvit_ma36.apple_dist_in1k": _cfg(
hf_hub_id='timm/',
crop_pct=0.95
),
})
def checkpoint_filter_fn(state_dict, model):
# FIXME temporary for remapping
state_dict = state_dict.get('state_dict', state_dict)
msd = model.state_dict()
out_dict = {}
for ka, kb, va, vb in zip(msd.keys(), state_dict.keys(), msd.values(), state_dict.values()):
if va.ndim == 4 and vb.ndim == 2:
vb = vb[:, :, None, None]
out_dict[ka] = vb
return out_dict
def _create_fastvit(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
model = build_model_with_cfg(
FastVit,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs
)
@ -1381,7 +1367,7 @@ def fastvit_sa12(pretrained=False, **kwargs):
layers=(2, 2, 6, 2),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(4, 4, 4, 4),
pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
)
return _create_fastvit('fastvit_sa12', pretrained=pretrained, **dict(model_args, **kwargs))
@ -1394,7 +1380,7 @@ def fastvit_sa24(pretrained=False, **kwargs):
layers=(4, 4, 12, 4),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(4, 4, 4, 4),
pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
)
return _create_fastvit('fastvit_sa24', pretrained=pretrained, **dict(model_args, **kwargs))
@ -1407,11 +1393,12 @@ def fastvit_sa36(pretrained=False, **kwargs):
layers=(6, 6, 18, 6),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(4, 4, 4, 4),
pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
)
return _create_fastvit('fastvit_sa36', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def fastvit_ma36(pretrained=False, **kwargs):
"""Instantiate FastViT-MA36 model variant."""
@ -1419,7 +1406,7 @@ def fastvit_ma36(pretrained=False, **kwargs):
layers=(6, 6, 18, 6),
embed_dims=(76, 152, 304, 608),
mlp_ratios=(4, 4, 4, 4),
pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention")
)
return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs))
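With per-stage feature_info now populated, feature extraction should work along these lines (a sketch; channel counts depend on the variant):

import torch
import timm

model = timm.create_model('fastvit_sa12', features_only=True, pretrained=False)
feats = model(torch.randn(1, 3, 256, 256))
for f in feats:
    print(f.shape)  # four maps, reductions 4/8/16/32 per the feature_info above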

View File

@ -126,6 +126,9 @@ class RepGhostModule(nn.Module):
self.fusion_conv = []
self.fusion_bn = []
def reparameterize(self):
self.switch_to_deploy()
class RepGhostBottleneck(nn.Module):
""" RepGhost bottleneck w/ optional SE"""

View File

@ -9,7 +9,7 @@ from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg, ParseKwargs
from .model import unwrap_model, get_state_dict, freeze, unfreeze
from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
from .model_ema import ModelEma, ModelEmaV2
from .random import random_seed
from .summary import update_summary, get_outdir

View File

@ -3,6 +3,7 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import fnmatch
from copy import deepcopy
import torch
from torchvision.ops.misc import FrozenBatchNorm2d
@ -219,3 +220,21 @@ def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
See example in docstring for `freeze`.
"""
_freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze")
def reparameterize_model(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
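""" Fuse / re-parameterize all eligible modules in a model for deployment.
Recursively applies the first available of fuse(), reparameterize(), or
switch_to_deploy() on each child module, replacing fused children in place.
"""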
if not inplace:
model = deepcopy(model)
def _fuse(m):
for child_name, child in m.named_children():
if hasattr(child, 'fuse'):
setattr(m, child_name, child.fuse())
elif hasattr(child, "reparameterize"):
child.reparameterize()
elif hasattr(child, "switch_to_deploy"):
child.switch_to_deploy()
_fuse(child)
_fuse(model)
return model
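A minimal usage sketch (model name illustrative; assumes a timm build with this commit). Reparameterization should be numerically equivalent in eval mode, which is easy to spot-check:

import torch
import timm
from timm.utils import reparameterize_model

model = timm.create_model('repvgg_a1', pretrained=False).eval()
deploy = reparameterize_model(model)  # deep-copies by default; original left intact

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    assert torch.allclose(model(x), deploy(x), atol=1e-4)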

View File

@ -26,7 +26,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLa
from timm.layers import apply_test_time_pool, set_fast_norm
from timm.models import create_model, load_checkpoint, is_model, list_models
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
decay_batch_step, check_batch_size_retry, ParseKwargs
decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model
try:
from apex import amp
@ -125,6 +125,8 @@ parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
parser.add_argument('--reparam', default=False, action='store_true',
help='Reparameterize model')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
@ -207,6 +209,9 @@ def validate(args):
if args.checkpoint:
load_checkpoint(model, args.checkpoint, args.use_ema)
if args.reparam:
model = reparameterize_model(model)
param_count = sum([m.numel() for m in model.parameters()])
_logger.info('Model %s created, param count: %d' % (args.model, param_count))