Mirror of https://github.com/huggingface/pytorch-image-models.git
MobileOne and FastViT weights on HF hub, more code cleanup and tweaks, features_only working. Add reparam flag to validate and benchmark, support reparam of all models with fuse(), reparameterize() or switch_to_deploy() methods on modules.
commit 5242ba6edc (parent 40dbaafef5)
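Usage sketch for the new reparameterization path (the model name, input size, and tolerance below are illustrative, not part of the commit). reparameterize_model() deep-copies the model by default and collapses every module exposing fuse(), reparameterize() or switch_to_deploy(); the same transform backs the new --reparam flag in validate.py and benchmark.py:

    import torch
    import timm
    from timm.utils import reparameterize_model

    model = timm.create_model('mobileone_s0', pretrained=False)
    model.eval()  # BN fusion assumes inference mode (uses running stats)

    deploy_model = reparameterize_model(model)  # fused deepcopy, original untouched

    # fused and multi-branch models should agree to float tolerance
    x = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        torch.testing.assert_close(model(x), deploy_model(x), rtol=1e-4, atol=1e-4)

On the command line the equivalent is, e.g., `python validate.py --model mobileone_s0 --reparam`.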
benchmark.py
@@ -22,7 +22,8 @@ from timm.data import resolve_data_config
 from timm.layers import set_fast_norm
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
-from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs
+from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs, \
+    reparameterize_model

 has_apex = False
 try:
@@ -116,6 +117,8 @@ parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--reparam', default=False, action='store_true',
+                    help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

 # codegen (model compilation) options
@@ -222,6 +225,7 @@ class BenchmarkRunner:
             torchscript=False,
             torchcompile=None,
             aot_autograd=False,
+            reparam=False,
             precision='float32',
             fuser='',
             num_warm_iter=10,
@@ -252,10 +256,13 @@ class BenchmarkRunner:
             drop_block_rate=kwargs.pop('drop_block', None),
             **kwargs.pop('model_kwargs', {}),
         )
+        if reparam:
+            self.model = reparameterize_model(self.model)
         self.model.to(
             device=self.device,
             dtype=self.model_dtype,
-            memory_format=torch.channels_last if self.channels_last else None)
+            memory_format=torch.channels_last if self.channels_last else None,
+        )
         self.num_classes = self.model.num_classes
         self.param_count = count_params(self.model)
         _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
timm/models/byobnet.py

@@ -12,6 +12,10 @@ RepVGG - repvgg_*
 Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
 Code and weights: https://github.com/DingXiaoH/RepVGG, licensed MIT

+MobileOne - mobileone_*
+Paper: `MobileOne: An Improved One millisecond Mobile Backbone` - https://arxiv.org/abs/2206.04040
+Code and weights: https://github.com/apple/ml-mobileone, licensed MIT
+
 In all cases the models have been modified to fit within the design of ByobNet. I've remapped
 the original weights and verified accuracies.
@@ -468,8 +472,6 @@ class RepVggBlock(nn.Module):
     """ RepVGG Block.

     Adapted from impl at https://github.com/DingXiaoH/RepVGG
-
-    This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
     """

     def __init__(
@@ -485,11 +487,24 @@ class RepVggBlock(nn.Module):
             layers: LayerFn = None,
             drop_block: Callable = None,
             drop_path_rate: float = 0.,
+            inference_mode: bool = False
     ):
         super(RepVggBlock, self).__init__()
+        self.groups = groups = num_groups(group_size, in_chs)
         layers = layers or LayerFn()
-        groups = num_groups(group_size, in_chs)
-        #self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)  # FIXME temp for remapping
+
+        if inference_mode:
+            self.reparam_conv = nn.Conv2d(
+                in_channels=in_chs,
+                out_channels=out_chs,
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                groups=groups,
+                bias=True,
+            )
+        else:
+            self.reparam_conv = None
         use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
         self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
         self.conv_kxk = layers.conv_norm_act(
@@ -497,8 +512,9 @@ class RepVggBlock(nn.Module):
             stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False,
         )
         self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
-        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
+
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.act = layers.act(inplace=True)

     def init_weights(self, zero_init_last: bool = False):
@@ -511,16 +527,109 @@ class RepVggBlock(nn.Module):
             self.attn.reset_parameters()

     def forward(self, x):
+        if self.reparam_conv is not None:
+            return self.act(self.attn(self.reparam_conv(x)))
+
         if self.identity is None:
             x = self.conv_1x1(x) + self.conv_kxk(x)
         else:
             identity = self.identity(x)
             x = self.conv_1x1(x) + self.conv_kxk(x)
             x = self.drop_path(x)  # not in the paper / official impl, experimental
-            x += identity
+            x = x + identity
         x = self.attn(x)  # no attn in the paper / official impl, experimental
         return self.act(x)

+    def reparameterize(self):
+        """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
+        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
+        architecture used at training time to obtain a plain CNN-like structure
+        for inference.
+        """
+        if self.reparam_conv is not None:
+            return
+
+        kernel, bias = self._get_kernel_bias()
+        self.reparam_conv = nn.Conv2d(
+            in_channels=self.conv_kxk.conv.in_channels,
+            out_channels=self.conv_kxk.conv.out_channels,
+            kernel_size=self.conv_kxk.conv.kernel_size,
+            stride=self.conv_kxk.conv.stride,
+            padding=self.conv_kxk.conv.padding,
+            dilation=self.conv_kxk.conv.dilation,
+            groups=self.conv_kxk.conv.groups,
+            bias=True,
+        )
+        self.reparam_conv.weight.data = kernel
+        self.reparam_conv.bias.data = bias
+
+        # Delete un-used branches
+        for name, para in self.named_parameters():
+            if 'reparam_conv' in name:
+                continue
+            para.detach_()
+        self.__delattr__('conv_kxk')
+        self.__delattr__('conv_1x1')
+        self.__delattr__('identity')
+        self.__delattr__('drop_path')
+
+    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        """ Method to obtain re-parameterized kernel and bias.
+        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
+        """
+        # get weights and bias of scale branch
+        kernel_1x1 = 0
+        bias_1x1 = 0
+        if self.conv_1x1 is not None:
+            kernel_1x1, bias_1x1 = self._fuse_bn_tensor(self.conv_1x1)
+            # Pad scale branch kernel to match conv branch kernel size.
+            pad = self.conv_kxk.conv.kernel_size[0] // 2
+            kernel_1x1 = torch.nn.functional.pad(kernel_1x1, [pad, pad, pad, pad])

+        # get weights and bias of skip branch
+        kernel_identity = 0
+        bias_identity = 0
+        if self.identity is not None:
+            kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
+
+        # get weights and bias of conv branches
+        kernel_conv, bias_conv = self._fuse_bn_tensor(self.conv_kxk)
+
+        kernel_final = kernel_conv + kernel_1x1 + kernel_identity
+        bias_final = bias_conv + bias_1x1 + bias_identity
+        return kernel_final, bias_final
+
+    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
+        """ Method to fuse batchnorm layer with preceding conv layer.
+        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
+        """
+        if isinstance(branch, ConvNormAct):
+            kernel = branch.conv.weight
+            running_mean = branch.bn.running_mean
+            running_var = branch.bn.running_var
+            gamma = branch.bn.weight
+            beta = branch.bn.bias
+            eps = branch.bn.eps
+        else:
+            assert isinstance(branch, nn.BatchNorm2d)
+            if not hasattr(self, 'id_tensor'):
+                in_chs = self.conv_kxk.conv.in_channels
+                input_dim = in_chs // self.groups
+                kernel_size = self.conv_kxk.conv.kernel_size
+                kernel_value = torch.zeros_like(self.conv_kxk.conv.weight)
+                for i in range(in_chs):
+                    kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
+                self.id_tensor = kernel_value
+            kernel = self.id_tensor
+            running_mean = branch.running_mean
+            running_var = branch.running_var
+            gamma = branch.weight
+            beta = branch.bias
+            eps = branch.eps
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape(-1, 1, 1, 1)
+        return kernel * t, beta - running_mean * gamma / std


 class MobileOneBlock(nn.Module):
     """ MobileOne building block.
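For reference, the fusion implemented above is the standard RepVGG identity: folding a BatchNorm with running stats (mu, sigma^2) and affine params (gamma, beta) into a preceding conv W yields W' = (gamma / sqrt(sigma^2 + eps)) * W and b' = beta - mu * gamma / sqrt(sigma^2 + eps), which is exactly the `kernel * t, beta - running_mean * gamma / std` returned by _fuse_bn_tensor(). The 1x1 branch is zero-padded to kxk and the identity branch is expressed as a Dirac kernel (id_tensor), so all three branches sum into the single reparam_conv.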
@@ -549,28 +658,11 @@ class MobileOneBlock(nn.Module):
             drop_path_rate: float = 0.,
     ) -> None:
-        """ Construct a MobileOneBlock module.
-
-        :param in_chs: Number of channels in the input.
-        :param out_chs: Number of channels produced by the block.
-        :param kernel_size: Size of the convolution kernel.
-        :param stride: Stride size.
-        :param dilation: Kernel dilation factor.
-        :param groups: Group number.
-        :param inference_mode: If True, instantiates model in inference mode.
-        :param use_se: Whether to use SE-ReLU activations.
-        :param num_conv_branches: Number of linear conv branches.
-        """
         super(MobileOneBlock, self).__init__()
         self.stride = stride
         self.kernel_size = kernel_size
-        self.in_channels = in_chs
-        self.out_channels = out_chs
         self.num_conv_branches = num_conv_branches
+        self.groups = groups = num_groups(group_size, in_chs)
         layers = layers or LayerFn()
-        groups = num_groups(group_size, in_chs)
-
-        # Check if SE-ReLU is requested
-        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)  # FIXME move after remap

         if inference_mode:
             self.reparam_conv = nn.Conv2d(
@@ -602,7 +694,9 @@ class MobileOneBlock(nn.Module):
             self.conv_scale = layers.conv_norm_act(
                 in_chs, out_chs, kernel_size=1,
                 stride=stride, groups=groups, apply_act=False)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
+
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.act = layers.act(inplace=True)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -623,9 +717,11 @@ class MobileOneBlock(nn.Module):
             scale_out = self.conv_scale(x)

         # Other branches
-        out = scale_out + identity_out
+        out = scale_out
         for ck in self.conv_kxk:
             out += ck(x)
+        out = self.drop_path(out)
+        out += identity_out

         return self.act(self.attn(out))
@@ -652,18 +748,18 @@ class MobileOneBlock(nn.Module):
         self.reparam_conv.bias.data = bias

         # Delete un-used branches
-        for para in self.parameters():
+        for name, para in self.named_parameters():
+            if 'reparam_conv' in name:
+                continue
             para.detach_()
         self.__delattr__('conv_kxk')
         self.__delattr__('conv_scale')
         if hasattr(self, 'identity'):
             self.__delattr__('identity')
+        self.__delattr__('drop_path')

     def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
         """ Method to obtain re-parameterized kernel and bias.
         Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

         :return: Tuple of (kernel, bias) after fusing branches.
         """
         # get weights and bias of scale branch
         kernel_scale = 0
@@ -671,7 +767,7 @@ class MobileOneBlock(nn.Module):
         if self.conv_scale is not None:
             kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
             # Pad scale branch kernel to match conv branch kernel size.
-            pad = self.kernel_size // 2
+            pad = self.conv_kxk[0].conv.kernel_size[0] // 2
             kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

         # get weights and bias of skip branch
@@ -695,9 +791,6 @@ class MobileOneBlock(nn.Module):
     def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
         """ Method to fuse batchnorm layer with preceding conv layer.
         Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
-
-        :param branch:
-        :return: Tuple of (kernel, bias) after fusing batchnorm.
         """
         if isinstance(branch, ConvNormAct):
             kernel = branch.conv.weight
@@ -709,16 +802,12 @@ class MobileOneBlock(nn.Module):
         else:
             assert isinstance(branch, nn.BatchNorm2d)
             if not hasattr(self, 'id_tensor'):
-                input_dim = self.in_channels // self.groups
-                kernel_value = torch.zeros(
-                    (self.in_channels,
-                     input_dim,
-                     self.kernel_size,
-                     self.kernel_size),
-                    dtype=branch.weight.dtype,
-                    device=branch.weight.device)
-                for i in range(self.in_channels):
-                    kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1
+                in_chs = self.conv_kxk[0].conv.in_channels
+                input_dim = in_chs // self.groups
+                kernel_size = self.conv_kxk[0].conv.kernel_size
+                kernel_value = torch.zeros_like(self.conv_kxk[0].conv.weight)
+                for i in range(in_chs):
+                    kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
                 self.id_tensor = kernel_value
             kernel = self.id_tensor
             running_mean = branch.running_mean
@@ -1226,6 +1315,16 @@ model_cfgs = dict(
         num_features=1920,
     ),

+    repvgg_a0=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(0.75, 0.75, 0.75, 2.5)),
+        stem_type='rep',
+        stem_chs=48,
+    ),
+    repvgg_a1=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1, 1, 1, 2.5)),
+        stem_type='rep',
+        stem_chs=64,
+    ),
     repvgg_a2=ByoModelCfg(
         blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)),
         stem_type='rep',
@@ -1643,15 +1742,11 @@ model_cfgs = dict(
     ),
 )

-# FIXME temporary for mobileone remap
-from .fastvit import checkpoint_filter_fn
-

 def _create_byobnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         ByobNet, variant, pretrained,
         model_cfg=model_cfgs[variant],
-        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True),
         **kwargs)
@@ -1683,6 +1778,12 @@ default_cfgs = generate_default_cfgs({
     'gernet_l.idstcv_in1k': _cfg(hf_hub_id='timm/', input_size=(3, 256, 256), pool_size=(8, 8)),

     # RepVGG weights
+    'repvgg_a0.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_a1.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
     'repvgg_a2.rvgg_in1k': _cfg(
         hf_hub_id='timm/',
         first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
@@ -1707,6 +1808,11 @@ default_cfgs = generate_default_cfgs({
     'repvgg_b3g4.rvgg_in1k': _cfg(
         hf_hub_id='timm/',
         first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_d2se.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit',
+        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0,
+    ),

     # experimental ResNet configs
     'resnet51q.ra2_in1k': _cfg(
@@ -1810,24 +1916,24 @@ default_cfgs = generate_default_cfgs({
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),

-    'mobileone_s0': _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar',
+    'mobileone_s0.apple_in1k': _cfg(
+        hf_hub_id='timm/',
         crop_pct=0.875,
     ),
-    'mobileone_s1': _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar',
+    'mobileone_s1.apple_in1k': _cfg(
+        hf_hub_id='timm/',
         crop_pct=0.9,
     ),
-    'mobileone_s2': _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar',
+    'mobileone_s2.apple_in1k': _cfg(
+        hf_hub_id='timm/',
         crop_pct=0.9,
     ),
-    'mobileone_s3': _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar',
+    'mobileone_s3.apple_in1k': _cfg(
+        hf_hub_id='timm/',
         crop_pct=0.9,
     ),
-    'mobileone_s4': _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar',
+    'mobileone_s4.apple_in1k': _cfg(
+        hf_hub_id='timm/',
         crop_pct=0.9,
     ),
 })
@@ -1857,6 +1963,22 @@ def gernet_s(pretrained=False, **kwargs) -> ByobNet:
     return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)


+@register_model
+def repvgg_a0(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-A0
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_a0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_a1(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-A1
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_a1', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def repvgg_a2(pretrained=False, **kwargs) -> ByobNet:
     """ RepVGG-A2
@@ -37,8 +37,8 @@ class ConvNorm(torch.nn.Sequential):
         b = bn.bias - bn.running_mean * bn.weight / \
             (bn.running_var + bn.eps)**0.5
         m = torch.nn.Conv2d(
-            w.size(1) * self.c.groups, w.size(0), w.shape[2:],
-            stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
+            w.size(1) * self.conv.groups, w.size(0), w.shape[2:],
+            stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)
         m.weight.data.copy_(w)
         m.bias.data.copy_(b)
         return m
timm/models/fastvit.py

@@ -1,9 +1,8 @@
 # FastViT for PyTorch
 #
 # For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main
 #
 # Original work is copyright (C) 2023 Apple Inc. All Rights Reserved.
 #
-import copy
 import os
 from functools import partial
 from typing import List, Tuple, Optional, Union
@@ -12,8 +11,10 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn
+from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
+    ClassifierHead
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -93,7 +94,7 @@ class MobileOneBlock(nn.Module):
         # Re-parameterizable skip connection
         self.reparam_conv = None

-        self.rbr_skip = (
+        self.identity = (
             nn.BatchNorm2d(num_features=in_chs)
             if out_chs == in_chs and stride == 1
             else None
@@ -101,7 +102,7 @@ class MobileOneBlock(nn.Module):

         # Re-parameterizable conv branches
         if num_conv_branches > 0:
-            self.rbr_conv = nn.ModuleList([
+            self.conv_kxk = nn.ModuleList([
                 ConvNormAct(
                     self.in_chs,
                     self.out_chs,
@@ -112,12 +113,12 @@ class MobileOneBlock(nn.Module):
                 ) for _ in range(self.num_conv_branches)
             ])
         else:
-            self.rbr_conv = None
+            self.conv_kxk = None

         # Re-parameterizable scale branch
-        self.rbr_scale = None
+        self.conv_scale = None
         if kernel_size > 1 and use_scale_branch:
-            self.rbr_scale = ConvNormAct(
+            self.conv_scale = ConvNormAct(
                 self.in_chs,
                 self.out_chs,
                 kernel_size=1,
@@ -135,20 +136,20 @@ class MobileOneBlock(nn.Module):
             return self.act(self.se(self.reparam_conv(x)))

         # Multi-branched train-time forward pass.
-        # Skip branch output
+        # Identity branch output
         identity_out = 0
-        if self.rbr_skip is not None:
-            identity_out = self.rbr_skip(x)
+        if self.identity is not None:
+            identity_out = self.identity(x)

         # Scale branch output
         scale_out = 0
-        if self.rbr_scale is not None:
-            scale_out = self.rbr_scale(x)
+        if self.conv_scale is not None:
+            scale_out = self.conv_scale(x)

-        # Other branches
+        # Other kxk conv branches
         out = scale_out + identity_out
-        if self.rbr_conv is not None:
-            for rc in self.rbr_conv:
+        if self.conv_kxk is not None:
+            for rc in self.conv_kxk:
                 out += rc(x)

         return self.act(self.se(out))
@@ -176,13 +177,15 @@ class MobileOneBlock(nn.Module):
         self.reparam_conv.bias.data = bias

         # Delete un-used branches
-        for para in self.parameters():
+        for name, para in self.named_parameters():
+            if 'reparam_conv' in name:
+                continue
             para.detach_()
-        self.__delattr__("rbr_conv")
-        self.__delattr__("rbr_scale")
-        if hasattr(self, "rbr_skip"):
-            self.__delattr__("rbr_skip")
+        self.__delattr__("conv_kxk")
+        self.__delattr__("conv_scale")
+        if hasattr(self, "identity"):
+            self.__delattr__("identity")

         self.inference_mode = True
@@ -196,8 +199,8 @@ class MobileOneBlock(nn.Module):
         # get weights and bias of scale branch
         kernel_scale = 0
         bias_scale = 0
-        if self.rbr_scale is not None:
-            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
+        if self.conv_scale is not None:
+            kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
             # Pad scale branch kernel to match conv branch kernel size.
             pad = self.kernel_size // 2
             kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
@@ -205,15 +208,15 @@ class MobileOneBlock(nn.Module):
         # get weights and bias of skip branch
         kernel_identity = 0
         bias_identity = 0
-        if self.rbr_skip is not None:
-            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
+        if self.identity is not None:
+            kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)

         # get weights and bias of conv branches
         kernel_conv = 0
         bias_conv = 0
-        if self.rbr_conv is not None:
+        if self.conv_kxk is not None:
             for ix in range(self.num_conv_branches):
-                _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
+                _kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix])
                 kernel_conv += _kernel
                 bias_conv += _bias
@@ -233,7 +236,7 @@ class MobileOneBlock(nn.Module):
         Returns:
             Tuple of (kernel, bias) after fusing batchnorm.
         """
-        if isinstance(branch, nn.Sequential):
+        if isinstance(branch, ConvNormAct):
             kernel = branch.conv.weight
             running_mean = branch.bn.running_mean
             running_var = branch.bn.running_var
@@ -306,7 +309,7 @@ class ReparamLargeKernelConv(nn.Module):
         self.kernel_size = kernel_size
         self.small_kernel = small_kernel
         if inference_mode:
-            self.lkb_reparam = create_conv2d(
+            self.reparam_conv = create_conv2d(
                 in_chs,
                 out_chs,
                 kernel_size=kernel_size,
@@ -316,8 +319,8 @@ class ReparamLargeKernelConv(nn.Module):
                 bias=True,
             )
         else:
-            self.lkb_reparam = None
-            self.lkb_origin = ConvNormAct(
+            self.reparam_conv = None
+            self.large_conv = ConvNormAct(
                 in_chs,
                 out_chs,
                 kernel_size=kernel_size,
@@ -341,10 +344,10 @@ class ReparamLargeKernelConv(nn.Module):
         self.act = act_layer() if act_layer is not None else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.lkb_reparam is not None:
-            out = self.lkb_reparam(x)
+        if self.reparam_conv is not None:
+            out = self.reparam_conv(x)
         else:
-            out = self.lkb_origin(x)
+            out = self.large_conv(x)
             if self.small_conv is not None:
                 out = out + self.small_conv(x)
         out = self.act(out)
@@ -357,7 +360,7 @@ class ReparamLargeKernelConv(nn.Module):
         Returns:
             Tuple of (kernel, bias) after fusing branches.
         """
-        eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
+        eq_k, eq_b = self._fuse_bn(self.large_conv.conv, self.large_conv.bn)
         if hasattr(self, "small_conv"):
             small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
             eq_b += small_b
@@ -374,19 +377,18 @@ class ReparamLargeKernelConv(nn.Module):
         for inference.
         """
         eq_k, eq_b = self.get_kernel_bias()
-        self.lkb_reparam = create_conv2d(
+        self.reparam_conv = create_conv2d(
             self.in_chs,
             self.out_chs,
             kernel_size=self.kernel_size,
             stride=self.stride,
-            dilation=self.lkb_origin.conv.dilation,
             groups=self.groups,
             bias=True,
         )

-        self.lkb_reparam.weight.data = eq_k
-        self.lkb_reparam.bias.data = eq_b
-        self.__delattr__("lkb_origin")
+        self.reparam_conv.weight.data = eq_k
+        self.reparam_conv.bias.data = eq_b
+        self.__delattr__("large_conv")
         if hasattr(self, "small_conv"):
             self.__delattr__("small_conv")
@@ -532,6 +534,8 @@ class PatchEmbed(nn.Module):
             stride: int,
             in_chs: int,
             embed_dim: int,
+            act_layer: nn.Module = nn.GELU,
+            lkc_use_act: bool = False,
             inference_mode: bool = False,
     ) -> None:
         """Build patch embedding layer.
"""Build patch embedding layer.
|
||||
@ -553,13 +557,14 @@ class PatchEmbed(nn.Module):
|
||||
group_size=1,
|
||||
small_kernel=3,
|
||||
inference_mode=inference_mode,
|
||||
act_layer=None, # activation was not used in original impl
|
||||
act_layer=act_layer if lkc_use_act else None, # NOTE original weights didn't use this act
|
||||
),
|
||||
MobileOneBlock(
|
||||
in_chs=embed_dim,
|
||||
out_chs=embed_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act_layer=act_layer,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
)
|
||||
@@ -569,6 +574,16 @@ class PatchEmbed(nn.Module):
         return x


+class LayerScale2d(nn.Module):
+    def __init__(self, dim, init_values=1e-5, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim, 1, 1))
+
+    def forward(self, x):
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
 class RepMixer(nn.Module):
     """Reparameterizable token mixer.

@@ -580,7 +595,6 @@ class RepMixer(nn.Module):
             self,
             dim,
             kernel_size=3,
-            use_layer_scale=True,
             layer_scale_init_value=1e-5,
             inference_mode: bool = False,
     ):
|
||||
Args:
|
||||
dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
|
||||
kernel_size: Kernel size for spatial mixing. Default: 3
|
||||
use_layer_scale: If True, learnable layer scale is used. Default: ``True``
|
||||
layer_scale_init_value: Initial value for layer scale. Default: 1e-5
|
||||
inference_mode: If True, instantiates model in inference mode. Default: ``False``
|
||||
"""
|
||||
@ -626,19 +639,16 @@ class RepMixer(nn.Module):
|
||||
group_size=1,
|
||||
use_act=False,
|
||||
)
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
|
||||
if layer_scale_init_value is not None:
|
||||
self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
|
||||
else:
|
||||
self.layer_scale = nn.Identity
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.reparam_conv is not None:
|
||||
x = self.reparam_conv(x)
|
||||
return x
|
||||
else:
|
||||
if self.use_layer_scale:
|
||||
x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
|
||||
else:
|
||||
x = x + self.mixer(x) - self.norm(x)
|
||||
x = x + self.layer_scale(self.mixer(x) - self.norm(x))
|
||||
return x
|
||||
|
||||
def reparameterize(self) -> None:
|
||||
@ -651,11 +661,11 @@ class RepMixer(nn.Module):
|
||||
self.mixer.reparameterize()
|
||||
self.norm.reparameterize()
|
||||
|
||||
if self.use_layer_scale:
|
||||
w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
|
||||
if isinstance(self.layer_scale, LayerScale2d):
|
||||
w = self.mixer.id_tensor + self.layer_scale.gamma.unsqueeze(-1) * (
|
||||
self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
|
||||
)
|
||||
b = torch.squeeze(self.layer_scale) * (
|
||||
b = torch.squeeze(self.layer_scale.gamma) * (
|
||||
self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
|
||||
)
|
||||
else:
|
||||
@@ -677,11 +687,12 @@ class RepMixer(nn.Module):
         self.reparam_conv.weight.data = w
         self.reparam_conv.bias.data = b

-        for para in self.parameters():
+        for name, para in self.named_parameters():
+            if 'reparam_conv' in name:
+                continue
             para.detach_()
         self.__delattr__("mixer")
         self.__delattr__("norm")
-        if self.use_layer_scale:
-            self.__delattr__("layer_scale")
+        self.__delattr__("layer_scale")
@@ -708,19 +719,6 @@ class ConvMlp(nn.Module):
         super().__init__()
         out_chs = out_chs or in_chs
         hidden_channels = hidden_channels or in_chs
-        # self.conv = nn.Sequential()
-        # self.conv.add_module(
-        #     "conv",
-        #     nn.Conv2d(
-        #         in_chs,
-        #         out_chs,
-        #         kernel_size=7,
-        #         padding=3,
-        #         groups=in_chs,
-        #         bias=False,
-        #     ),
-        # )
-        # self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs))
         self.conv = ConvNormAct(
             in_chs,
             out_chs,
@@ -750,7 +748,7 @@ class ConvMlp(nn.Module):
         return x


-class RepCPE(nn.Module):
+class RepConditionalPosEnc(nn.Module):
     """Implementation of conditional positional encoding.

     For more details refer to paper:
@@ -774,7 +772,7 @@ class RepConditionalPosEnc(nn.Module):
             spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
             inference_mode: Flag to instantiate block in inference mode. Default: ``False``
         """
-        super(RepCPE, self).__init__()
+        super(RepConditionalPosEnc, self).__init__()
         if isinstance(spatial_shape, int):
             spatial_shape = tuple([spatial_shape] * 2)
         assert isinstance(spatial_shape, Tuple), (
|
||||
)
|
||||
else:
|
||||
self.reparam_conv = None
|
||||
self.pe = nn.Conv2d(
|
||||
self.pos_enc = nn.Conv2d(
|
||||
self.dim,
|
||||
self.dim_out,
|
||||
spatial_shape,
|
||||
@ -816,9 +814,8 @@ class RepCPE(nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.reparam_conv is not None:
|
||||
x = self.reparam_conv(x)
|
||||
return x
|
||||
else:
|
||||
x = self.pe(x) + x
|
||||
x = self.pos_enc(x) + x
|
||||
return x
|
||||
|
||||
def reparameterize(self) -> None:
|
||||
@ -831,8 +828,8 @@ class RepCPE(nn.Module):
|
||||
self.spatial_shape[0],
|
||||
self.spatial_shape[1],
|
||||
),
|
||||
dtype=self.pe.weight.dtype,
|
||||
device=self.pe.weight.device,
|
||||
dtype=self.pos_enc.weight.dtype,
|
||||
device=self.pos_enc.weight.device,
|
||||
)
|
||||
for i in range(self.dim):
|
||||
kernel_value[
|
||||
@ -844,8 +841,8 @@ class RepCPE(nn.Module):
|
||||
id_tensor = kernel_value
|
||||
|
||||
# Reparameterize Id tensor and conv
|
||||
w_final = id_tensor + self.pe.weight
|
||||
b_final = self.pe.bias
|
||||
w_final = id_tensor + self.pos_enc.weight
|
||||
b_final = self.pos_enc.bias
|
||||
|
||||
# Introduce reparam conv
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
@ -860,9 +857,11 @@ class RepCPE(nn.Module):
|
||||
self.reparam_conv.weight.data = w_final
|
||||
self.reparam_conv.bias.data = b_final
|
||||
|
||||
for para in self.parameters():
|
||||
for name, para in self.named_parameters():
|
||||
if 'reparam_conv' in name:
|
||||
continue
|
||||
para.detach_()
|
||||
self.__delattr__("pe")
|
||||
self.__delattr__("pos_enc")
|
||||
|
||||
|
||||
class RepMixerBlock(nn.Module):
|
||||
@ -878,9 +877,8 @@ class RepMixerBlock(nn.Module):
|
||||
kernel_size: int = 3,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
use_layer_scale: bool = True,
|
||||
layer_scale_init_value: float = 1e-5,
|
||||
inference_mode: bool = False,
|
||||
):
|
||||
@ -891,9 +889,8 @@ class RepMixerBlock(nn.Module):
|
||||
kernel_size: Kernel size for repmixer. Default: 3
|
||||
mlp_ratio: MLP expansion ratio. Default: 4.0
|
||||
act_layer: Activation layer. Default: ``nn.GELU``
|
||||
drop: Dropout rate. Default: 0.0
|
||||
proj_drop: Dropout rate. Default: 0.0
|
||||
drop_path: Drop path rate. Default: 0.0
|
||||
use_layer_scale: Flag to turn on layer scale. Default: ``True``
|
||||
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
|
||||
inference_mode: Flag to instantiate block in inference mode. Default: ``False``
|
||||
"""
|
||||
@ -903,36 +900,25 @@ class RepMixerBlock(nn.Module):
|
||||
self.token_mixer = RepMixer(
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
|
||||
assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
|
||||
mlp_ratio
|
||||
)
|
||||
self.convffn = ConvMlp(
|
||||
self.mlp = ConvMlp(
|
||||
in_chs=dim,
|
||||
hidden_channels=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
drop=proj_drop,
|
||||
)
|
||||
|
||||
# Drop Path
|
||||
if layer_scale_init_value is not None:
|
||||
self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
|
||||
else:
|
||||
self.layer_scale = nn.Identity()
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
# Layer Scale
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_layer_scale:
|
||||
x = self.token_mixer(x)
|
||||
x = x + self.drop_path(self.layer_scale * self.convffn(x))
|
||||
else:
|
||||
x = self.token_mixer(x)
|
||||
x = x + self.drop_path(self.convffn(x))
|
||||
x = x + self.drop_path(self.layer_scale(self.mlp(x)))
|
||||
return x
|
||||
|
||||
|
||||
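Note the config-level consequence of this refactor: disabling layer scale is now expressed as layer_scale_init_value=None instead of use_layer_scale=False, and because LayerScale2d and nn.Identity() are interchangeable callables the forward() no longer needs to branch.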
@@ -949,9 +935,8 @@ class AttentionBlock(nn.Module):
             mlp_ratio: float = 4.0,
             act_layer: nn.Module = nn.GELU,
             norm_layer: nn.Module = nn.BatchNorm2d,
-            drop: float = 0.0,
+            proj_drop: float = 0.0,
             drop_path: float = 0.0,
-            use_layer_scale: bool = True,
             layer_scale_init_value: float = 1e-5,
     ):
         """Build Attention Block.
@@ -961,9 +946,8 @@ class AttentionBlock(nn.Module):
             mlp_ratio: MLP expansion ratio. Default: 4.0
             act_layer: Activation layer. Default: ``nn.GELU``
             norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
-            drop: Dropout rate. Default: 0.0
+            proj_drop: Dropout rate. Default: 0.0
             drop_path: Drop path rate. Default: 0.0
-            use_layer_scale: Flag to turn on layer scale. Default: ``True``
             layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
         """
@@ -971,34 +955,27 @@ class AttentionBlock(nn.Module):

         self.norm = norm_layer(dim)
         self.token_mixer = Attention(dim=dim)
+        if layer_scale_init_value is not None:
+            self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value)
+        else:
+            self.layer_scale_1 = nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

         assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
             mlp_ratio
         )
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.convffn = ConvMlp(
+        self.mlp = ConvMlp(
             in_chs=dim,
-            hidden_channels=mlp_hidden_dim,
+            hidden_channels=int(dim * mlp_ratio),
             act_layer=act_layer,
-            drop=drop,
+            drop=proj_drop,
         )

-        # Drop path
-        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
-        # Layer Scale
-        self.use_layer_scale = use_layer_scale
-        if use_layer_scale:
-            self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
-            self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
+        if layer_scale_init_value is not None:
+            self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value)
+        else:
+            self.layer_scale_2 = nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

     def forward(self, x):
-        if self.use_layer_scale:
-            x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
-            x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
-        else:
-            x = x + self.drop_path(self.token_mixer(self.norm(x)))
-            x = x + self.drop_path(self.convffn(x))
+        x = x + self.drop_path1(self.layer_scale_1(self.token_mixer(self.norm(x))))
+        x = x + self.drop_path2(self.layer_scale_2(self.mlp(x)))
         return x
@@ -1017,35 +994,38 @@ class FastVitStage(nn.Module):
             mlp_ratio: float = 4.0,
             act_layer: nn.Module = nn.GELU,
             norm_layer: nn.Module = nn.BatchNorm2d,
-            drop_rate: float = 0.0,
+            proj_drop_rate: float = 0.0,
             drop_path_rate: float = 0.0,
-            use_layer_scale: bool = True,
-            layer_scale_init_value: float = 1e-5,
+            layer_scale_init_value: Optional[float] = 1e-5,
+            lkc_use_act=False,
             inference_mode=False,
     ):
         """FastViT stage.

         Args:
             dim: Number of embedding dimensions.
-            num_blocks: List containing number of blocks per stage.
+            depth: Number of blocks in stage
             token_mixer_type: Token mixer type.
             kernel_size: Kernel size for repmixer.
             mlp_ratio: MLP expansion ratio.
             act_layer: Activation layer.
             norm_layer: Normalization layer.
-            drop_rate: Dropout rate.
+            proj_drop_rate: Dropout rate.
             drop_path_rate: Drop path rate.
-            use_layer_scale: Flag to turn on layer scale regularization.
             layer_scale_init_value: Layer scale value at initialization.
             inference_mode: Flag to instantiate block in inference mode.
         """
         super().__init__()
+        self.grad_checkpointing = False

         if downsample:
             self.downsample = PatchEmbed(
                 patch_size=down_patch_size,
                 stride=down_stride,
                 in_chs=dim,
                 embed_dim=dim_out,
+                act_layer=act_layer,
+                lkc_use_act=lkc_use_act,
                 inference_mode=inference_mode,
             )
         else:
@@ -1065,9 +1045,8 @@ class FastVitStage(nn.Module):
                 kernel_size=kernel_size,
                 mlp_ratio=mlp_ratio,
                 act_layer=act_layer,
-                drop=drop_rate,
+                proj_drop=proj_drop_rate,
                 drop_path=drop_path_rate[block_idx],
-                use_layer_scale=use_layer_scale,
                 layer_scale_init_value=layer_scale_init_value,
                 inference_mode=inference_mode,
             ))
@@ -1077,9 +1056,8 @@ class FastVitStage(nn.Module):
                 mlp_ratio=mlp_ratio,
                 act_layer=act_layer,
                 norm_layer=norm_layer,
-                drop=drop_rate,
+                proj_drop=proj_drop_rate,
                 drop_path=drop_path_rate[block_idx],
-                use_layer_scale=use_layer_scale,
                 layer_scale_init_value=layer_scale_init_value,
             ))
         else:
@@ -1091,6 +1069,9 @@ class FastVitStage(nn.Module):
     def forward(self, x):
         x = self.downsample(x)
         x = self.pos_emb(x)
-        x = self.blocks(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
         return x
@@ -1116,21 +1097,25 @@ class FastVit(nn.Module):
             down_patch_size: int = 7,
             down_stride: int = 2,
             drop_rate: float = 0.0,
+            proj_drop_rate: float = 0.0,
             drop_path_rate: float = 0.0,
-            use_layer_scale: bool = True,
             layer_scale_init_value: float = 1e-5,
             fork_feat: bool = False,
             cls_ratio: float = 2.0,
+            global_pool: str = 'avg',
             norm_layer: nn.Module = nn.BatchNorm2d,
             act_layer: nn.Module = nn.GELU,
+            lkc_use_act: bool = False,
             inference_mode: bool = False,
     ) -> None:
         super().__init__()
         self.num_classes = 0 if fork_feat else num_classes
         self.fork_feat = fork_feat
+        self.global_pool = global_pool
+        self.feature_info = []

         # Convolutional stem
-        self.patch_embed = convolutional_stem(
+        self.stem = convolutional_stem(
             in_chans,
             embed_dims[0],
             inference_mode,
@@ -1138,14 +1123,16 @@ class FastVit(nn.Module):

         # Build the main stages of the network architecture
         prev_dim = embed_dims[0]
+        scale = 1
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
-        network = []
+        stages = []
         for i in range(len(layers)):
+            downsample = downsamples[i] or prev_dim != embed_dims[i]
             stage = FastVitStage(
                 dim=prev_dim,
                 dim_out=embed_dims[i],
                 depth=layers[i],
-                downsample=downsamples[i] or prev_dim != embed_dims[i],
+                downsample=downsample,
                 down_patch_size=down_patch_size,
                 down_stride=down_stride,
                 pos_emb_layer=pos_embs[i],
|
||||
mlp_ratio=mlp_ratios[i],
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop_rate=drop_rate,
|
||||
proj_drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
lkc_use_act=lkc_use_act,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
network.append(stage)
|
||||
stages.append(stage)
|
||||
prev_dim = embed_dims[i]
|
||||
|
||||
self.network = nn.Sequential(*network)
|
||||
if downsample:
|
||||
scale *= 2
|
||||
self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')]
|
||||
self.stages = nn.Sequential(*stages)
|
||||
self.num_features = prev_dim
|
||||
|
||||
# For segmentation and detection, extract intermediate output
|
||||
if self.fork_feat:
|
||||
@@ -1181,10 +1171,10 @@ class FastVit(nn.Module):
                 self.add_module(layer_name, layer)
         else:
             # Classifier head
-            self.gap = nn.AdaptiveAvgPool2d(output_size=1)
-            self.conv_exp = MobileOneBlock(
+            final_features = int(embed_dims[-1] * cls_ratio)
+            self.final_conv = MobileOneBlock(
                 in_chs=embed_dims[-1],
-                out_chs=int(embed_dims[-1] * cls_ratio),
+                out_chs=final_features,
                 kernel_size=3,
                 stride=1,
                 group_size=1,
@@ -1192,10 +1182,12 @@ class FastVit(nn.Module):
                 use_se=True,
                 num_conv_branches=1,
             )
-            self.head = (
-                nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes)
-                if num_classes > 0
-                else nn.Identity()
+            self.num_features = final_features
+            self.head = ClassifierHead(
+                final_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=drop_rate,
             )

         self.apply(self._init_weights)
@@ -1207,23 +1199,39 @@ class FastVit(nn.Module):
         if isinstance(m, nn.Linear) and m.bias is not None:
             nn.init.constant_(m.bias, 0)

-    @staticmethod
-    def _scrub_checkpoint(checkpoint, model):
-        sterile_dict = {}
-        for k1, v1 in checkpoint.items():
-            if k1 not in model.state_dict():
-                continue
-            if v1.shape == model.state_dict()[k1].shape:
-                sterile_dict[k1] = v1
-        return sterile_dict
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return set()

-    def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.patch_embed(x)
-        return x
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+).pos_emb', (0,)),
+                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+            ]
+        )

-    def forward_tokens(self, x: torch.Tensor) -> torch.Tensor:
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool)
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        # input embedding
+        x = self.stem(x)
         outs = []
-        for idx, block in enumerate(self.network):
+        for idx, block in enumerate(self.stages):
             x = block(x)
             if self.fork_feat:
                 if idx in self.out_indices:
@@ -1236,20 +1244,16 @@ class FastVit(nn.Module):
             # output only the features of last layer for image classification
             return x

+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
+        x = self.final_conv(x)
+        return self.head(x, pre_logits=True) if pre_logits else self.head(x)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # input embedding
-        x = self.forward_embeddings(x)
-        # through backbone
-        x = self.forward_tokens(x)
+        x = self.forward_features(x)
         if self.fork_feat:
             # output features of four stages for dense prediction
             return x
-        # for image classification
-        x = self.conv_exp(x)
-        x = self.gap(x)
-        x = x.view(x.size(0), -1)
-        cls_out = self.head(x)
-        return cls_out
+        x = self.forward_head(x)
+        return x


 def _cfg(url="", **kwargs):
@@ -1257,81 +1261,63 @@ def _cfg(url="", **kwargs):
         "url": url,
         "num_classes": 1000,
         "input_size": (3, 256, 256),
-        "pool_size": None,
+        "pool_size": (8, 8),
         "crop_pct": 0.9,
         "interpolation": "bicubic",
         "mean": IMAGENET_DEFAULT_MEAN,
         "std": IMAGENET_DEFAULT_STD,
-        "classifier": "head",
+        "classifier": "head.fc",
         **kwargs,
     }


 default_cfgs = generate_default_cfgs({
     "fastvit_t8.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar'
-    ),
+        hf_hub_id='timm/'),
     "fastvit_t12.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t12.pth.tar'
-    ),
+        hf_hub_id='timm/'),

     "fastvit_s12.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_s12.pth.tar'),
+        hf_hub_id='timm/'),
     "fastvit_sa12.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa12.pth.tar'),
+        hf_hub_id='timm/'),
     "fastvit_sa24.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa24.pth.tar'),
+        hf_hub_id='timm/'),
     "fastvit_sa36.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa36.pth.tar'),
+        hf_hub_id='timm/'),

     "fastvit_ma36.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_ma36.pth.tar',
+        hf_hub_id='timm/',
         crop_pct=0.95
     ),

-    # "fastvit_t8.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t8.pth.tar'
-    # ),
-    # "fastvit_t12.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t12.pth.tar'
-    # ),
-    #
-    # "fastvit_s12.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_s12.pth.tar'),
-    # "fastvit_sa12.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa12.pth.tar'),
-    # "fastvit_sa24.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa24.pth.tar'),
-    # "fastvit_sa36.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa36.pth.tar'),
-    #
-    # "fastvit_ma36.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_ma36.pth.tar',
-    #     crop_pct=0.95
-    # ),
+    "fastvit_t8.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "fastvit_t12.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/'),
+
+    "fastvit_s12.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "fastvit_sa12.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "fastvit_sa24.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "fastvit_sa36.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/'),
+
+    "fastvit_ma36.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95
+    ),
 })


 def checkpoint_filter_fn(state_dict, model):
     # FIXME temporary for remapping
     state_dict = state_dict.get('state_dict', state_dict)
     msd = model.state_dict()
     out_dict = {}
     for ka, kb, va, vb in zip(msd.keys(), state_dict.keys(), msd.values(), state_dict.values()):
         if va.ndim == 4 and vb.ndim == 2:
             vb = vb[:, :, None, None]
         out_dict[ka] = vb

     return out_dict


 def _create_fastvit(variant, pretrained=False, **kwargs):
+    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
     model = build_model_with_cfg(
         FastVit,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs
     )
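The out_indices / feature_cfg plumbing above, together with the feature_info entries registered in FastVit.__init__, is what enables feature extraction ("features_only working" in the commit message). A minimal sketch, with illustrative model name and input size:

    import torch
    import timm

    model = timm.create_model('fastvit_t8', features_only=True, out_indices=(0, 1, 2, 3))
    model.eval()

    with torch.no_grad():
        feats = model(torch.randn(1, 3, 256, 256))

    print(model.feature_info.channels())   # per-stage channel counts
    print(model.feature_info.reduction())  # per-stage output strides
    for f in feats:
        print(tuple(f.shape))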
@@ -1381,7 +1367,7 @@ def fastvit_sa12(pretrained=False, **kwargs):
         layers=(2, 2, 6, 2),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
     return _create_fastvit('fastvit_sa12', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -1394,7 +1380,7 @@ def fastvit_sa24(pretrained=False, **kwargs):
         layers=(4, 4, 12, 4),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
     return _create_fastvit('fastvit_sa24', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -1407,11 +1393,12 @@ def fastvit_sa36(pretrained=False, **kwargs):
         layers=(6, 6, 18, 6),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
     return _create_fastvit('fastvit_sa36', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def fastvit_ma36(pretrained=False, **kwargs):
     """Instantiate FastViT-MA36 model variant."""
@@ -1419,7 +1406,7 @@ def fastvit_ma36(pretrained=False, **kwargs):
         layers=(6, 6, 18, 6),
         embed_dims=(76, 152, 304, 608),
         mlp_ratios=(4, 4, 4, 4),
-        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention")
     )
     return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs))
timm/models/repghost.py

@@ -126,6 +126,9 @@ class RepGhostModule(nn.Module):
         self.fusion_conv = []
         self.fusion_bn = []

+    def reparameterize(self):
+        self.switch_to_deploy()
+

 class RepGhostBottleneck(nn.Module):
     """ RepGhost bottleneck w/ optional SE"""
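RepGhost already had its own switch_to_deploy(); this one-line reparameterize() alias gives it the same entry point as the other reparameterizable models touched in this commit, and the generic walker in timm.utils.reparameterize_model (below) dispatches on any of the three method names.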
timm/utils/__init__.py

@@ -9,7 +9,7 @@ from .jit import set_jit_legacy, set_jit_fuser
 from .log import setup_default_logging, FormatterNoInfo
 from .metrics import AverageMeter, accuracy
 from .misc import natural_key, add_bool_arg, ParseKwargs
-from .model import unwrap_model, get_state_dict, freeze, unfreeze
+from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
 from .model_ema import ModelEma, ModelEmaV2
 from .random import random_seed
 from .summary import update_summary, get_outdir
timm/utils/model.py

@@ -3,6 +3,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import fnmatch
+from copy import deepcopy

 import torch
 from torchvision.ops.misc import FrozenBatchNorm2d
@@ -219,3 +220,21 @@ def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
     See example in docstring for `freeze`.
     """
     _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze")
+
+
+def reparameterize_model(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
+    if not inplace:
+        model = deepcopy(model)
+
+    def _fuse(m):
+        for child_name, child in m.named_children():
+            if hasattr(child, 'fuse'):
+                setattr(m, child_name, child.fuse())
+            elif hasattr(child, "reparameterize"):
+                child.reparameterize()
+            elif hasattr(child, "switch_to_deploy"):
+                child.switch_to_deploy()
+            _fuse(child)
+
+    _fuse(model)
+    return model
|
||||
from timm.layers import apply_test_time_pool, set_fast_norm
|
||||
from timm.models import create_model, load_checkpoint, is_model, list_models
|
||||
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
|
||||
decay_batch_step, check_batch_size_retry, ParseKwargs
|
||||
decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
@@ -125,6 +125,8 @@ parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--reparam', default=False, action='store_true',
+                    help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

@@ -207,6 +209,9 @@ def validate(args):
     if args.checkpoint:
         load_checkpoint(model, args.checkpoint, args.use_ema)

+    if args.reparam:
+        model = reparameterize_model(model)
+
     param_count = sum([m.numel() for m in model.parameters()])
     _logger.info('Model %s created, param count: %d' % (args.model, param_count))