MobileOne and FastViT weights on HF hub, more code cleanup and tweaks, features_only working. Add reparam flag to validate and benchmark, support reparam of all models with fuse(), reparameterize() or switch_to_deploy() methods on modules
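For reference, the helper this message describes presumably walks the module tree and invokes whichever of the three conversion methods a module exposes. A minimal sketch of that behavior (illustrative, not a verbatim copy of the `timm.utils.reparameterize_model` implementation):

import copy
import torch.nn as nn

def reparameterize_model(model: nn.Module, inplace: bool = False) -> nn.Module:
    # Work on a copy by default so the training-time model stays intact.
    if not inplace:
        model = copy.deepcopy(model)

    def _fuse(module: nn.Module):
        for child_name, child in module.named_children():
            if hasattr(child, 'fuse'):
                # fuse() returns a replacement module, so re-assign it
                setattr(module, child_name, child.fuse())
            elif hasattr(child, 'reparameterize'):
                child.reparameterize()
            elif hasattr(child, 'switch_to_deploy'):
                child.switch_to_deploy()
            _fuse(child)

    _fuse(model)
    return model

The benchmark diff below wires the same conversion to a new `--reparam` CLI flag.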

Ross Wightman 2023-08-23 14:16:43 -07:00 committed by Ross Wightman
parent 40dbaafef5
commit 5242ba6edc
8 changed files with 447 additions and 304 deletions

View File

@@ -22,7 +22,8 @@ from timm.data import resolve_data_config
 from timm.layers import set_fast_norm
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
-from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs
+from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs,\
+    reparameterize_model

 has_apex = False
 try:
@@ -116,6 +117,8 @@ parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--reparam', default=False, action='store_true',
+                    help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

 # codegen (model compilation) options
@@ -222,6 +225,7 @@ class BenchmarkRunner:
             torchscript=False,
             torchcompile=None,
             aot_autograd=False,
+            reparam=False,
             precision='float32',
             fuser='',
             num_warm_iter=10,
@@ -252,10 +256,13 @@ class BenchmarkRunner:
             drop_block_rate=kwargs.pop('drop_block', None),
             **kwargs.pop('model_kwargs', {}),
         )
+        if reparam:
+            self.model = reparameterize_model(self.model)
         self.model.to(
             device=self.device,
             dtype=self.model_dtype,
-            memory_format=torch.channels_last if self.channels_last else None)
+            memory_format=torch.channels_last if self.channels_last else None,
+        )
         self.num_classes = self.model.num_classes
         self.param_count = count_params(self.model)
         _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))

View File

@@ -12,6 +12,10 @@ RepVGG - repvgg_*
 Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
 Code and weights: https://github.com/DingXiaoH/RepVGG, licensed MIT

+MobileOne - mobileone_*
+Paper: `MobileOne: An Improved One millisecond Mobile Backbone` - https://arxiv.org/abs/2206.04040
+Code and weights: https://github.com/apple/ml-mobileone, licensed MIT
+
 In all cases the models have been modified to fit within the design of ByobNet. I've remapped
 the original weights and verified accuracies.
@@ -468,8 +472,6 @@ class RepVggBlock(nn.Module):
     """ RepVGG Block.

     Adapted from impl at https://github.com/DingXiaoH/RepVGG
-
-    This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
     """

     def __init__(
@@ -485,11 +487,24 @@ class RepVggBlock(nn.Module):
             layers: LayerFn = None,
             drop_block: Callable = None,
             drop_path_rate: float = 0.,
+            inference_mode: bool = False
     ):
         super(RepVggBlock, self).__init__()
+        self.groups = groups = num_groups(group_size, in_chs)
         layers = layers or LayerFn()
-        groups = num_groups(group_size, in_chs)
-        #self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)  # FIXME temp for remapping
+
+        if inference_mode:
+            self.reparam_conv = nn.Conv2d(
+                in_channels=in_chs,
+                out_channels=out_chs,
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                groups=groups,
+                bias=True,
+            )
+        else:
+            self.reparam_conv = None

         use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
         self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
         self.conv_kxk = layers.conv_norm_act(
@@ -497,8 +512,9 @@ class RepVggBlock(nn.Module):
             stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False,
         )
         self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
-        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.act = layers.act(inplace=True)

     def init_weights(self, zero_init_last: bool = False):
@@ -511,16 +527,109 @@ class RepVggBlock(nn.Module):
             self.attn.reset_parameters()

     def forward(self, x):
+        if self.reparam_conv is not None:
+            return self.act(self.attn(self.reparam_conv(x)))
+
         if self.identity is None:
             x = self.conv_1x1(x) + self.conv_kxk(x)
         else:
             identity = self.identity(x)
             x = self.conv_1x1(x) + self.conv_kxk(x)
             x = self.drop_path(x)  # not in the paper / official impl, experimental
-            x = x + identity
+            x += identity
         x = self.attn(x)  # no attn in the paper / official impl, experimental
         return self.act(x)

+    def reparameterize(self):
+        """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
+        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
+        architecture used at training time to obtain a plain CNN-like structure
+        for inference.
+        """
+        if self.reparam_conv is not None:
+            return
+
+        kernel, bias = self._get_kernel_bias()
+        self.reparam_conv = nn.Conv2d(
+            in_channels=self.conv_kxk.conv.in_channels,
+            out_channels=self.conv_kxk.conv.out_channels,
+            kernel_size=self.conv_kxk.conv.kernel_size,
+            stride=self.conv_kxk.conv.stride,
+            padding=self.conv_kxk.conv.padding,
+            dilation=self.conv_kxk.conv.dilation,
+            groups=self.conv_kxk.conv.groups,
+            bias=True,
+        )
+        self.reparam_conv.weight.data = kernel
+        self.reparam_conv.bias.data = bias
+
+        # Delete un-used branches
+        for name, para in self.named_parameters():
+            if 'reparam_conv' in name:
+                continue
+            para.detach_()
+        self.__delattr__('conv_kxk')
+        self.__delattr__('conv_1x1')
+        self.__delattr__('identity')
+        self.__delattr__('drop_path')
+
+    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        """ Method to obtain re-parameterized kernel and bias.
+        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
+        """
+        # get weights and bias of scale branch
+        kernel_1x1 = 0
+        bias_1x1 = 0
+        if self.conv_1x1 is not None:
+            kernel_1x1, bias_1x1 = self._fuse_bn_tensor(self.conv_1x1)
+            # Pad scale branch kernel to match conv branch kernel size.
+            pad = self.conv_kxk.conv.kernel_size[0] // 2
+            kernel_1x1 = torch.nn.functional.pad(kernel_1x1, [pad, pad, pad, pad])
+
+        # get weights and bias of skip branch
+        kernel_identity = 0
+        bias_identity = 0
+        if self.identity is not None:
+            kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
+
+        # get weights and bias of conv branches
+        kernel_conv, bias_conv = self._fuse_bn_tensor(self.conv_kxk)
+
+        kernel_final = kernel_conv + kernel_1x1 + kernel_identity
+        bias_final = bias_conv + bias_1x1 + bias_identity
+        return kernel_final, bias_final
+
+    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
+        """ Method to fuse batchnorm layer with preceeding conv layer.
+        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
+        """
+        if isinstance(branch, ConvNormAct):
+            kernel = branch.conv.weight
+            running_mean = branch.bn.running_mean
+            running_var = branch.bn.running_var
+            gamma = branch.bn.weight
+            beta = branch.bn.bias
+            eps = branch.bn.eps
+        else:
+            assert isinstance(branch, nn.BatchNorm2d)
+            if not hasattr(self, 'id_tensor'):
+                in_chs = self.conv_kxk.conv.in_channels
+                input_dim = in_chs // self.groups
+                kernel_size = self.conv_kxk.conv.kernel_size
+                kernel_value = torch.zeros_like(self.conv_kxk.conv.weight)
+                for i in range(in_chs):
+                    kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
+                self.id_tensor = kernel_value
+            kernel = self.id_tensor
+            running_mean = branch.running_mean
+            running_var = branch.running_var
+            gamma = branch.weight
+            beta = branch.bias
+            eps = branch.eps
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape(-1, 1, 1, 1)
+        return kernel * t, beta - running_mean * gamma / std
+

 class MobileOneBlock(nn.Module):
     """ MobileOne building block.
@@ -549,28 +658,11 @@ class MobileOneBlock(nn.Module):
             drop_path_rate: float = 0.,
     ) -> None:
         """ Construct a MobileOneBlock module.
-
-        :param in_chs: Number of channels in the input.
-        :param out_chs: Number of channels produced by the block.
-        :param kernel_size: Size of the convolution kernel.
-        :param stride: Stride size.
-        :param dilation: Kernel dilation factor.
-        :param groups: Group number.
-        :param inference_mode: If True, instantiates model in inference mode.
-        :param use_se: Whether to use SE-ReLU activations.
-        :param num_conv_branches: Number of linear conv branches.
         """
         super(MobileOneBlock, self).__init__()
-        self.stride = stride
-        self.kernel_size = kernel_size
-        self.in_channels = in_chs
-        self.out_channels = out_chs
         self.num_conv_branches = num_conv_branches
+        self.groups = groups = num_groups(group_size, in_chs)
         layers = layers or LayerFn()
-        groups = num_groups(group_size, in_chs)
-
-        # Check if SE-ReLU is requested
-        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)  # FIXME move after remap

         if inference_mode:
             self.reparam_conv = nn.Conv2d(
@@ -602,7 +694,9 @@ class MobileOneBlock(nn.Module):
             self.conv_scale = layers.conv_norm_act(
                 in_chs, out_chs, kernel_size=1,
                 stride=stride, groups=groups, apply_act=False)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()

+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.act = layers.act(inplace=True)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -623,9 +717,11 @@ class MobileOneBlock(nn.Module):
             scale_out = self.conv_scale(x)

         # Other branches
-        out = scale_out + identity_out
+        out = scale_out
         for ck in self.conv_kxk:
             out += ck(x)
+        out = self.drop_path(out)
+        out += identity_out
         return self.act(self.attn(out))
@@ -652,18 +748,18 @@ class MobileOneBlock(nn.Module):
         self.reparam_conv.bias.data = bias

         # Delete un-used branches
-        for para in self.parameters():
+        for name, para in self.named_parameters():
+            if 'reparam_conv' in name:
+                continue
             para.detach_()
         self.__delattr__('conv_kxk')
         self.__delattr__('conv_scale')
-        if hasattr(self, 'identity'):
-            self.__delattr__('identity')
+        self.__delattr__('identity')
+        self.__delattr__('drop_path')

     def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
         """ Method to obtain re-parameterized kernel and bias.
         Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
-        :return: Tuple of (kernel, bias) after fusing branches.
         """
         # get weights and bias of scale branch
         kernel_scale = 0
@@ -671,7 +767,7 @@ class MobileOneBlock(nn.Module):
         if self.conv_scale is not None:
             kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
             # Pad scale branch kernel to match conv branch kernel size.
-            pad = self.kernel_size // 2
+            pad = self.conv_kxk[0].conv.kernel_size[0] // 2
             kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

         # get weights and bias of skip branch
@@ -695,9 +791,6 @@ class MobileOneBlock(nn.Module):
     def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
         """ Method to fuse batchnorm layer with preceeding conv layer.
         Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
-
-        :param branch:
-        :return: Tuple of (kernel, bias) after fusing batchnorm.
         """
         if isinstance(branch, ConvNormAct):
             kernel = branch.conv.weight
@@ -709,16 +802,12 @@ class MobileOneBlock(nn.Module):
         else:
             assert isinstance(branch, nn.BatchNorm2d)
             if not hasattr(self, 'id_tensor'):
-                input_dim = self.in_channels // self.groups
-                kernel_value = torch.zeros(
-                    (self.in_channels,
-                     input_dim,
-                     self.kernel_size,
-                     self.kernel_size),
-                    dtype=branch.weight.dtype,
-                    device=branch.weight.device)
-                for i in range(self.in_channels):
-                    kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1
+                in_chs = self.conv_kxk[0].conv.in_channels
+                input_dim = in_chs // self.groups
+                kernel_size = self.conv_kxk[0].conv.kernel_size
+                kernel_value = torch.zeros_like(self.conv_kxk[0].conv.weight)
+                for i in range(in_chs):
+                    kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
                 self.id_tensor = kernel_value
             kernel = self.id_tensor
             running_mean = branch.running_mean
@@ -1226,6 +1315,16 @@ model_cfgs = dict(
         num_features=1920,
     ),

+    repvgg_a0=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(0.75, 0.75, 0.75, 2.5)),
+        stem_type='rep',
+        stem_chs=48,
+    ),
+    repvgg_a1=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1, 1, 1, 2.5)),
+        stem_type='rep',
+        stem_chs=64,
+    ),
     repvgg_a2=ByoModelCfg(
         blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)),
         stem_type='rep',
@@ -1643,15 +1742,11 @@ model_cfgs = dict(
     ),
 )

-# FIXME temporary for mobileone remap
-from .fastvit import checkpoint_filter_fn
-

 def _create_byobnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         ByobNet, variant, pretrained,
         model_cfg=model_cfgs[variant],
-        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True),
         **kwargs)

@@ -1683,6 +1778,12 @@ default_cfgs = generate_default_cfgs({
     'gernet_l.idstcv_in1k': _cfg(hf_hub_id='timm/', input_size=(3, 256, 256), pool_size=(8, 8)),

     # RepVGG weights
+    'repvgg_a0.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_a1.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
     'repvgg_a2.rvgg_in1k': _cfg(
         hf_hub_id='timm/',
         first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
@@ -1707,6 +1808,11 @@ default_cfgs = generate_default_cfgs({
     'repvgg_b3g4.rvgg_in1k': _cfg(
         hf_hub_id='timm/',
         first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_d2se.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit',
+        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0,
+    ),

     # experimental ResNet configs
     'resnet51q.ra2_in1k': _cfg(
@@ -1810,24 +1916,24 @@ default_cfgs = generate_default_cfgs({
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),

-    'mobileone_s0': _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar',
+    'mobileone_s0.apple_in1k': _cfg(
+        hf_hub_id='timm/',
         crop_pct=0.875,
     ),
-    'mobileone_s1': _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar',
+    'mobileone_s1.apple_in1k': _cfg(
+        hf_hub_id='timm/',
         crop_pct=0.9,
     ),
-    'mobileone_s2': _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar',
+    'mobileone_s2.apple_in1k': _cfg(
+        hf_hub_id='timm/',
         crop_pct=0.9,
     ),
-    'mobileone_s3': _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar',
+    'mobileone_s3.apple_in1k': _cfg(
+        hf_hub_id='timm/',
        crop_pct=0.9,
     ),
-    'mobileone_s4': _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar',
+    'mobileone_s4.apple_in1k': _cfg(
+        hf_hub_id='timm/',
         crop_pct=0.9,
     ),
 })
@@ -1857,6 +1963,22 @@ def gernet_s(pretrained=False, **kwargs) -> ByobNet:
     return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)


+@register_model
+def repvgg_a0(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-A0
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_a0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_a1(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-A1
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_a1', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def repvgg_a2(pretrained=False, **kwargs) -> ByobNet:
     """ RepVGG-A2

View File

@@ -37,8 +37,8 @@ class ConvNorm(torch.nn.Sequential):
         b = bn.bias - bn.running_mean * bn.weight / \
             (bn.running_var + bn.eps)**0.5
         m = torch.nn.Conv2d(
-            w.size(1) * self.c.groups, w.size(0), w.shape[2:],
-            stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
+            w.size(1) * self.conv.groups, w.size(0), w.shape[2:],
+            stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)
         m.weight.data.copy_(w)
         m.bias.data.copy_(b)
         return m

View File

@@ -1,9 +1,8 @@
+# FastViT for PyTorch
 #
 # For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main
-#
 # Original work is copyright (C) 2023 Apple Inc. All Rights Reserved.
 #
-import copy
 import os
 from functools import partial
 from typing import List, Tuple, Optional, Union
@@ -12,8 +11,10 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn
+from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
+    ClassifierHead
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -93,7 +94,7 @@ class MobileOneBlock(nn.Module):
         # Re-parameterizable skip connection
         self.reparam_conv = None

-        self.rbr_skip = (
+        self.identity = (
             nn.BatchNorm2d(num_features=in_chs)
             if out_chs == in_chs and stride == 1
             else None
@@ -101,7 +102,7 @@ class MobileOneBlock(nn.Module):

         # Re-parameterizable conv branches
         if num_conv_branches > 0:
-            self.rbr_conv = nn.ModuleList([
+            self.conv_kxk = nn.ModuleList([
                 ConvNormAct(
                     self.in_chs,
                     self.out_chs,
@@ -112,12 +113,12 @@ class MobileOneBlock(nn.Module):
                 ) for _ in range(self.num_conv_branches)
             ])
         else:
-            self.rbr_conv = None
+            self.conv_kxk = None

         # Re-parameterizable scale branch
-        self.rbr_scale = None
+        self.conv_scale = None
         if kernel_size > 1 and use_scale_branch:
-            self.rbr_scale = ConvNormAct(
+            self.conv_scale = ConvNormAct(
                 self.in_chs,
                 self.out_chs,
                 kernel_size=1,
@@ -135,20 +136,20 @@ class MobileOneBlock(nn.Module):
             return self.act(self.se(self.reparam_conv(x)))

         # Multi-branched train-time forward pass.
-        # Skip branch output
+        # Identity branch output
         identity_out = 0
-        if self.rbr_skip is not None:
-            identity_out = self.rbr_skip(x)
+        if self.identity is not None:
+            identity_out = self.identity(x)

         # Scale branch output
         scale_out = 0
-        if self.rbr_scale is not None:
-            scale_out = self.rbr_scale(x)
+        if self.conv_scale is not None:
+            scale_out = self.conv_scale(x)

-        # Other branches
+        # Other kxk conv branches
         out = scale_out + identity_out
-        if self.rbr_conv is not None:
-            for rc in self.rbr_conv:
+        if self.conv_kxk is not None:
+            for rc in self.conv_kxk:
                 out += rc(x)

         return self.act(self.se(out))
@@ -176,13 +177,15 @@ class MobileOneBlock(nn.Module):
         self.reparam_conv.bias.data = bias

         # Delete un-used branches
-        for para in self.parameters():
+        for name, para in self.named_parameters():
+            if 'reparam_conv' in name:
+                continue
             para.detach_()
-        self.__delattr__("rbr_conv")
-        self.__delattr__("rbr_scale")
-        if hasattr(self, "rbr_skip"):
-            self.__delattr__("rbr_skip")
+        self.__delattr__("conv_kxk")
+        self.__delattr__("conv_scale")
+        if hasattr(self, "identity"):
+            self.__delattr__("identity")

         self.inference_mode = True
@@ -196,8 +199,8 @@ class MobileOneBlock(nn.Module):
         # get weights and bias of scale branch
         kernel_scale = 0
         bias_scale = 0
-        if self.rbr_scale is not None:
-            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
+        if self.conv_scale is not None:
+            kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
             # Pad scale branch kernel to match conv branch kernel size.
             pad = self.kernel_size // 2
             kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
@@ -205,15 +208,15 @@ class MobileOneBlock(nn.Module):
         # get weights and bias of skip branch
         kernel_identity = 0
         bias_identity = 0
-        if self.rbr_skip is not None:
-            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
+        if self.identity is not None:
+            kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)

         # get weights and bias of conv branches
         kernel_conv = 0
         bias_conv = 0
-        if self.rbr_conv is not None:
+        if self.conv_kxk is not None:
             for ix in range(self.num_conv_branches):
-                _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
+                _kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix])
                 kernel_conv += _kernel
                 bias_conv += _bias
@@ -233,7 +236,7 @@ class MobileOneBlock(nn.Module):
         Returns:
             Tuple of (kernel, bias) after fusing batchnorm.
         """
-        if isinstance(branch, nn.Sequential):
+        if isinstance(branch, ConvNormAct):
             kernel = branch.conv.weight
             running_mean = branch.bn.running_mean
             running_var = branch.bn.running_var
@@ -306,7 +309,7 @@ class ReparamLargeKernelConv(nn.Module):
         self.kernel_size = kernel_size
         self.small_kernel = small_kernel
         if inference_mode:
-            self.lkb_reparam = create_conv2d(
+            self.reparam_conv = create_conv2d(
                 in_chs,
                 out_chs,
                 kernel_size=kernel_size,
@@ -316,8 +319,8 @@ class ReparamLargeKernelConv(nn.Module):
                 bias=True,
             )
         else:
-            self.lkb_reparam = None
-            self.lkb_origin = ConvNormAct(
+            self.reparam_conv = None
+            self.large_conv = ConvNormAct(
                 in_chs,
                 out_chs,
                 kernel_size=kernel_size,
@@ -341,10 +344,10 @@ class ReparamLargeKernelConv(nn.Module):
         self.act = act_layer() if act_layer is not None else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.lkb_reparam is not None:
-            out = self.lkb_reparam(x)
+        if self.reparam_conv is not None:
+            out = self.reparam_conv(x)
         else:
-            out = self.lkb_origin(x)
+            out = self.large_conv(x)
             if self.small_conv is not None:
                 out = out + self.small_conv(x)
         out = self.act(out)
@@ -357,7 +360,7 @@ class ReparamLargeKernelConv(nn.Module):
         Returns:
             Tuple of (kernel, bias) after fusing branches.
         """
-        eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
+        eq_k, eq_b = self._fuse_bn(self.large_conv.conv, self.large_conv.bn)
         if hasattr(self, "small_conv"):
             small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
             eq_b += small_b
@@ -374,19 +377,18 @@ class ReparamLargeKernelConv(nn.Module):
         for inference.
         """
         eq_k, eq_b = self.get_kernel_bias()
-        self.lkb_reparam = create_conv2d(
+        self.reparam_conv = create_conv2d(
             self.in_chs,
             self.out_chs,
             kernel_size=self.kernel_size,
             stride=self.stride,
-            dilation=self.lkb_origin.conv.dilation,
             groups=self.groups,
             bias=True,
         )

-        self.lkb_reparam.weight.data = eq_k
-        self.lkb_reparam.bias.data = eq_b
-        self.__delattr__("lkb_origin")
+        self.reparam_conv.weight.data = eq_k
+        self.reparam_conv.bias.data = eq_b
+        self.__delattr__("large_conv")
         if hasattr(self, "small_conv"):
             self.__delattr__("small_conv")
@@ -532,6 +534,8 @@ class PatchEmbed(nn.Module):
             stride: int,
             in_chs: int,
             embed_dim: int,
+            act_layer: nn.Module = nn.GELU,
+            lkc_use_act: bool = False,
             inference_mode: bool = False,
     ) -> None:
         """Build patch embedding layer.
@@ -553,13 +557,14 @@ class PatchEmbed(nn.Module):
                 group_size=1,
                 small_kernel=3,
                 inference_mode=inference_mode,
-                act_layer=None,  # activation was not used in original impl
+                act_layer=act_layer if lkc_use_act else None,  # NOTE original weights didn't use this act
             ),
             MobileOneBlock(
                 in_chs=embed_dim,
                 out_chs=embed_dim,
                 kernel_size=1,
                 stride=1,
+                act_layer=act_layer,
                 inference_mode=inference_mode,
             )
         )
@@ -569,6 +574,16 @@ class PatchEmbed(nn.Module):
         return x


+class LayerScale2d(nn.Module):
+    def __init__(self, dim, init_values=1e-5, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim, 1, 1))
+
+    def forward(self, x):
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
 class RepMixer(nn.Module):
     """Reparameterizable token mixer.
@@ -580,7 +595,6 @@ class RepMixer(nn.Module):
             self,
             dim,
             kernel_size=3,
-            use_layer_scale=True,
             layer_scale_init_value=1e-5,
             inference_mode: bool = False,
     ):
@@ -589,7 +603,6 @@ class RepMixer(nn.Module):
         Args:
             dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
             kernel_size: Kernel size for spatial mixing. Default: 3
-            use_layer_scale: If True, learnable layer scale is used. Default: ``True``
             layer_scale_init_value: Initial value for layer scale. Default: 1e-5
             inference_mode: If True, instantiates model in inference mode. Default: ``False``
         """
@@ -626,19 +639,16 @@ class RepMixer(nn.Module):
                 group_size=1,
                 use_act=False,
             )
-        self.use_layer_scale = use_layer_scale
-        if use_layer_scale:
-            self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
+        if layer_scale_init_value is not None:
+            self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
+        else:
+            self.layer_scale = nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.reparam_conv is not None:
             x = self.reparam_conv(x)
-            return x
         else:
-            if self.use_layer_scale:
-                x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
-            else:
-                x = x + self.mixer(x) - self.norm(x)
+            x = x + self.layer_scale(self.mixer(x) - self.norm(x))
         return x

     def reparameterize(self) -> None:
@@ -651,11 +661,11 @@ class RepMixer(nn.Module):
         self.mixer.reparameterize()
         self.norm.reparameterize()

-        if self.use_layer_scale:
-            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
+        if isinstance(self.layer_scale, LayerScale2d):
+            w = self.mixer.id_tensor + self.layer_scale.gamma.unsqueeze(-1) * (
                 self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
             )
-            b = torch.squeeze(self.layer_scale) * (
+            b = torch.squeeze(self.layer_scale.gamma) * (
                 self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
             )
         else:
@@ -677,11 +687,12 @@ class RepMixer(nn.Module):
         self.reparam_conv.weight.data = w
         self.reparam_conv.bias.data = b

-        for para in self.parameters():
+        for name, para in self.named_parameters():
+            if 'reparam_conv' in name:
+                continue
             para.detach_()
         self.__delattr__("mixer")
         self.__delattr__("norm")
-        if self.use_layer_scale:
-            self.__delattr__("layer_scale")
+        self.__delattr__("layer_scale")
@@ -708,19 +719,6 @@ class ConvMlp(nn.Module):
         super().__init__()
         out_chs = out_chs or in_chs
         hidden_channels = hidden_channels or in_chs
-        # self.conv = nn.Sequential()
-        # self.conv.add_module(
-        #     "conv",
-        #     nn.Conv2d(
-        #         in_chs,
-        #         out_chs,
-        #         kernel_size=7,
-        #         padding=3,
-        #         groups=in_chs,
-        #         bias=False,
-        #     ),
-        # )
-        # self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs))
         self.conv = ConvNormAct(
             in_chs,
             out_chs,
@@ -750,7 +748,7 @@ class ConvMlp(nn.Module):
         return x


-class RepCPE(nn.Module):
+class RepConditionalPosEnc(nn.Module):
     """Implementation of conditional positional encoding.

     For more details refer to paper:
@@ -774,7 +772,7 @@ class RepCPE(nn.Module):
             spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
             inference_mode: Flag to instantiate block in inference mode. Default: ``False``
         """
-        super(RepCPE, self).__init__()
+        super(RepConditionalPosEnc, self).__init__()
         if isinstance(spatial_shape, int):
             spatial_shape = tuple([spatial_shape] * 2)
         assert isinstance(spatial_shape, Tuple), (
@@ -803,7 +801,7 @@ class RepCPE(nn.Module):
             )
         else:
             self.reparam_conv = None
-            self.pe = nn.Conv2d(
+            self.pos_enc = nn.Conv2d(
                 self.dim,
                 self.dim_out,
                 spatial_shape,
@@ -816,9 +814,8 @@ class RepCPE(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.reparam_conv is not None:
             x = self.reparam_conv(x)
-            return x
         else:
-            x = self.pe(x) + x
+            x = self.pos_enc(x) + x
         return x

     def reparameterize(self) -> None:
@@ -831,8 +828,8 @@ class RepCPE(nn.Module):
                 self.spatial_shape[0],
                 self.spatial_shape[1],
             ),
-            dtype=self.pe.weight.dtype,
-            device=self.pe.weight.device,
+            dtype=self.pos_enc.weight.dtype,
+            device=self.pos_enc.weight.device,
         )
         for i in range(self.dim):
             kernel_value[
@@ -844,8 +841,8 @@ class RepCPE(nn.Module):
         id_tensor = kernel_value

         # Reparameterize Id tensor and conv
-        w_final = id_tensor + self.pe.weight
-        b_final = self.pe.bias
+        w_final = id_tensor + self.pos_enc.weight
+        b_final = self.pos_enc.bias

         # Introduce reparam conv
         self.reparam_conv = nn.Conv2d(
@@ -860,9 +857,11 @@ class RepCPE(nn.Module):
         self.reparam_conv.weight.data = w_final
         self.reparam_conv.bias.data = b_final

-        for para in self.parameters():
+        for name, para in self.named_parameters():
+            if 'reparam_conv' in name:
+                continue
             para.detach_()
-        self.__delattr__("pe")
+        self.__delattr__("pos_enc")


 class RepMixerBlock(nn.Module):
@@ -878,9 +877,8 @@ class RepMixerBlock(nn.Module):
             kernel_size: int = 3,
             mlp_ratio: float = 4.0,
             act_layer: nn.Module = nn.GELU,
-            drop: float = 0.0,
+            proj_drop: float = 0.0,
             drop_path: float = 0.0,
-            use_layer_scale: bool = True,
             layer_scale_init_value: float = 1e-5,
             inference_mode: bool = False,
     ):
@@ -891,9 +889,8 @@ class RepMixerBlock(nn.Module):
             kernel_size: Kernel size for repmixer. Default: 3
             mlp_ratio: MLP expansion ratio. Default: 4.0
             act_layer: Activation layer. Default: ``nn.GELU``
-            drop: Dropout rate. Default: 0.0
+            proj_drop: Dropout rate. Default: 0.0
             drop_path: Drop path rate. Default: 0.0
-            use_layer_scale: Flag to turn on layer scale. Default: ``True``
             layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
             inference_mode: Flag to instantiate block in inference mode. Default: ``False``
         """
@@ -903,36 +900,25 @@ class RepMixerBlock(nn.Module):
         self.token_mixer = RepMixer(
             dim,
             kernel_size=kernel_size,
-            use_layer_scale=use_layer_scale,
             layer_scale_init_value=layer_scale_init_value,
             inference_mode=inference_mode,
         )

-        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
-            mlp_ratio
-        )
-        self.convffn = ConvMlp(
+        self.mlp = ConvMlp(
             in_chs=dim,
             hidden_channels=int(dim * mlp_ratio),
             act_layer=act_layer,
-            drop=drop,
+            drop=proj_drop,
         )
-
-        # Drop Path
+        if layer_scale_init_value is not None:
+            self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
+        else:
+            self.layer_scale = nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

-        # Layer Scale
-        self.use_layer_scale = use_layer_scale
-        if use_layer_scale:
-            self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
-
     def forward(self, x):
-        if self.use_layer_scale:
-            x = self.token_mixer(x)
-            x = x + self.drop_path(self.layer_scale * self.convffn(x))
-        else:
-            x = self.token_mixer(x)
-            x = x + self.drop_path(self.convffn(x))
+        x = self.token_mixer(x)
+        x = x + self.drop_path(self.layer_scale(self.mlp(x)))
         return x
@@ -949,9 +935,8 @@ class AttentionBlock(nn.Module):
             mlp_ratio: float = 4.0,
             act_layer: nn.Module = nn.GELU,
             norm_layer: nn.Module = nn.BatchNorm2d,
-            drop: float = 0.0,
+            proj_drop: float = 0.0,
             drop_path: float = 0.0,
-            use_layer_scale: bool = True,
             layer_scale_init_value: float = 1e-5,
     ):
         """Build Attention Block.
@@ -961,9 +946,8 @@ class AttentionBlock(nn.Module):
             mlp_ratio: MLP expansion ratio. Default: 4.0
             act_layer: Activation layer. Default: ``nn.GELU``
             norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
-            drop: Dropout rate. Default: 0.0
+            proj_drop: Dropout rate. Default: 0.0
             drop_path: Drop path rate. Default: 0.0
-            use_layer_scale: Flag to turn on layer scale. Default: ``True``
             layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
         """
@@ -971,34 +955,27 @@ class AttentionBlock(nn.Module):
         self.norm = norm_layer(dim)
         self.token_mixer = Attention(dim=dim)
+        if layer_scale_init_value is not None:
+            self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value)
+        else:
+            self.layer_scale_1 = nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

-        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
-            mlp_ratio
-        )
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.convffn = ConvMlp(
+        self.mlp = ConvMlp(
             in_chs=dim,
-            hidden_channels=mlp_hidden_dim,
+            hidden_channels=int(dim * mlp_ratio),
             act_layer=act_layer,
-            drop=drop,
+            drop=proj_drop,
         )
-
-        # Drop path
-        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
-        # Layer Scale
-        self.use_layer_scale = use_layer_scale
-        if use_layer_scale:
-            self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
-            self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
+        if layer_scale_init_value is not None:
+            self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value)
+        else:
+            self.layer_scale_2 = nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

     def forward(self, x):
-        if self.use_layer_scale:
-            x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
-            x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
-        else:
-            x = x + self.drop_path(self.token_mixer(self.norm(x)))
-            x = x + self.drop_path(self.convffn(x))
+        x = x + self.drop_path1(self.layer_scale_1(self.token_mixer(self.norm(x))))
+        x = x + self.drop_path2(self.layer_scale_2(self.mlp(x)))
         return x
@@ -1017,35 +994,38 @@ class FastVitStage(nn.Module):
             mlp_ratio: float = 4.0,
             act_layer: nn.Module = nn.GELU,
             norm_layer: nn.Module = nn.BatchNorm2d,
-            drop_rate: float = 0.0,
+            proj_drop_rate: float = 0.0,
             drop_path_rate: float = 0.0,
-            use_layer_scale: bool = True,
-            layer_scale_init_value: float = 1e-5,
+            layer_scale_init_value: Optional[float] = 1e-5,
+            lkc_use_act=False,
             inference_mode=False,
     ):
         """FastViT stage.

         Args:
             dim: Number of embedding dimensions.
-            num_blocks: List containing number of blocks per stage.
+            depth: Number of blocks in stage
             token_mixer_type: Token mixer type.
             kernel_size: Kernel size for repmixer.
             mlp_ratio: MLP expansion ratio.
             act_layer: Activation layer.
             norm_layer: Normalization layer.
-            drop_rate: Dropout rate.
+            proj_drop_rate: Dropout rate.
             drop_path_rate: Drop path rate.
-            use_layer_scale: Flag to turn on layer scale regularization.
             layer_scale_init_value: Layer scale value at initialization.
             inference_mode: Flag to instantiate block in inference mode.
         """
         super().__init__()
+        self.grad_checkpointing = False

         if downsample:
             self.downsample = PatchEmbed(
                 patch_size=down_patch_size,
                 stride=down_stride,
                 in_chs=dim,
                 embed_dim=dim_out,
+                act_layer=act_layer,
+                lkc_use_act=lkc_use_act,
                 inference_mode=inference_mode,
             )
         else:
@@ -1065,9 +1045,8 @@ class FastVitStage(nn.Module):
                     kernel_size=kernel_size,
                     mlp_ratio=mlp_ratio,
                     act_layer=act_layer,
-                    drop=drop_rate,
+                    proj_drop=proj_drop_rate,
                     drop_path=drop_path_rate[block_idx],
-                    use_layer_scale=use_layer_scale,
                     layer_scale_init_value=layer_scale_init_value,
                     inference_mode=inference_mode,
                 ))
@@ -1077,9 +1056,8 @@ class FastVitStage(nn.Module):
                     mlp_ratio=mlp_ratio,
                     act_layer=act_layer,
                     norm_layer=norm_layer,
-                    drop=drop_rate,
+                    proj_drop=proj_drop_rate,
                     drop_path=drop_path_rate[block_idx],
-                    use_layer_scale=use_layer_scale,
                     layer_scale_init_value=layer_scale_init_value,
                 ))
         else:
@@ -1091,6 +1069,9 @@ class FastVitStage(nn.Module):
     def forward(self, x):
         x = self.downsample(x)
         x = self.pos_emb(x)
-        x = self.blocks(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
         return x
@@ -1116,21 +1097,25 @@ class FastVit(nn.Module):
             down_patch_size: int = 7,
             down_stride: int = 2,
             drop_rate: float = 0.0,
+            proj_drop_rate: float = 0.0,
             drop_path_rate: float = 0.0,
-            use_layer_scale: bool = True,
             layer_scale_init_value: float = 1e-5,
             fork_feat: bool = False,
             cls_ratio: float = 2.0,
+            global_pool: str = 'avg',
             norm_layer: nn.Module = nn.BatchNorm2d,
             act_layer: nn.Module = nn.GELU,
+            lkc_use_act: bool = False,
             inference_mode: bool = False,
     ) -> None:
         super().__init__()
         self.num_classes = 0 if fork_feat else num_classes
         self.fork_feat = fork_feat
+        self.global_pool = global_pool
+        self.feature_info = []

         # Convolutional stem
-        self.patch_embed = convolutional_stem(
+        self.stem = convolutional_stem(
             in_chans,
             embed_dims[0],
             inference_mode,
@@ -1138,14 +1123,16 @@ class FastVit(nn.Module):

         # Build the main stages of the network architecture
         prev_dim = embed_dims[0]
+        scale = 1
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
-        network = []
+        stages = []
         for i in range(len(layers)):
+            downsample = downsamples[i] or prev_dim != embed_dims[i]
             stage = FastVitStage(
                 dim=prev_dim,
                 dim_out=embed_dims[i],
                 depth=layers[i],
-                downsample=downsamples[i] or prev_dim != embed_dims[i],
+                downsample=downsample,
                 down_patch_size=down_patch_size,
                 down_stride=down_stride,
                 pos_emb_layer=pos_embs[i],
@@ -1154,16 +1141,19 @@ class FastVit(nn.Module):
                 mlp_ratio=mlp_ratios[i],
                 act_layer=act_layer,
                 norm_layer=norm_layer,
-                drop_rate=drop_rate,
+                proj_drop_rate=drop_rate,
                 drop_path_rate=dpr[i],
-                use_layer_scale=use_layer_scale,
                 layer_scale_init_value=layer_scale_init_value,
+                lkc_use_act=lkc_use_act,
                 inference_mode=inference_mode,
             )
-            network.append(stage)
+            stages.append(stage)
             prev_dim = embed_dims[i]
-
-        self.network = nn.Sequential(*network)
+            if downsample:
+                scale *= 2
+            self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')]
+        self.stages = nn.Sequential(*stages)
+        self.num_features = prev_dim

         # For segmentation and detection, extract intermediate output
         if self.fork_feat:
@@ -1181,10 +1171,10 @@ class FastVit(nn.Module):
                     self.add_module(layer_name, layer)
         else:
             # Classifier head
-            self.gap = nn.AdaptiveAvgPool2d(output_size=1)
-            self.conv_exp = MobileOneBlock(
+            final_features = int(embed_dims[-1] * cls_ratio)
+            self.final_conv = MobileOneBlock(
                 in_chs=embed_dims[-1],
-                out_chs=int(embed_dims[-1] * cls_ratio),
+                out_chs=final_features,
                 kernel_size=3,
                 stride=1,
                 group_size=1,
@@ -1192,10 +1182,12 @@ class FastVit(nn.Module):
                 use_se=True,
                 num_conv_branches=1,
             )
-            self.head = (
-                nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes)
-                if num_classes > 0
-                else nn.Identity()
+            self.num_features = final_features
+            self.head = ClassifierHead(
+                final_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=drop_rate,
             )

         self.apply(self._init_weights)
@@ -1207,23 +1199,39 @@ class FastVit(nn.Module):
         if isinstance(m, nn.Linear) and m.bias is not None:
             nn.init.constant_(m.bias, 0)

-    @staticmethod
-    def _scrub_checkpoint(checkpoint, model):
-        sterile_dict = {}
-        for k1, v1 in checkpoint.items():
-            if k1 not in model.state_dict():
-                continue
-            if v1.shape == model.state_dict()[k1].shape:
-                sterile_dict[k1] = v1
-        return sterile_dict
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return set()

-    def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.patch_embed(x)
-        return x
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+).pos_emb', (0,)),
+                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+            ]
+        )

-    def forward_tokens(self, x: torch.Tensor) -> torch.Tensor:
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool)
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        # input embedding
+        x = self.stem(x)
         outs = []
-        for idx, block in enumerate(self.network):
+        for idx, block in enumerate(self.stages):
             x = block(x)
             if self.fork_feat:
                 if idx in self.out_indices:
@@ -1236,20 +1244,16 @@ class FastVit(nn.Module):
         # output only the features of last layer for image classification
         return x

+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
+        x = self.final_conv(x)
+        return self.head(x, pre_logits=True) if pre_logits else self.head(x)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # input embedding
-        x = self.forward_embeddings(x)
-        # through backbone
-        x = self.forward_tokens(x)
+        x = self.forward_features(x)
         if self.fork_feat:
-            # output features of four stages for dense prediction
             return x
-        # for image classification
-        x = self.conv_exp(x)
-        x = self.gap(x)
-        x = x.view(x.size(0), -1)
-        cls_out = self.head(x)
-        return cls_out
+        x = self.forward_head(x)
+        return x
@@ -1257,81 +1261,63 @@
 def _cfg(url="", **kwargs):
     return {
         "url": url,
         "num_classes": 1000,
         "input_size": (3, 256, 256),
-        "pool_size": None,
+        "pool_size": (8, 8),
         "crop_pct": 0.9,
         "interpolation": "bicubic",
         "mean": IMAGENET_DEFAULT_MEAN,
         "std": IMAGENET_DEFAULT_STD,
-        "classifier": "head",
+        "classifier": "head.fc",
         **kwargs,
     }

 default_cfgs = generate_default_cfgs({
     "fastvit_t8.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar'),
+        hf_hub_id='timm/'),
     "fastvit_t12.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t12.pth.tar'),
+        hf_hub_id='timm/'),
     "fastvit_s12.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_s12.pth.tar'),
+        hf_hub_id='timm/'),
     "fastvit_sa12.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa12.pth.tar'),
+        hf_hub_id='timm/'),
     "fastvit_sa24.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa24.pth.tar'),
+        hf_hub_id='timm/'),
     "fastvit_sa36.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa36.pth.tar'),
+        hf_hub_id='timm/'),
     "fastvit_ma36.apple_in1k": _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_ma36.pth.tar',
+        hf_hub_id='timm/',
         crop_pct=0.95
     ),

-    # "fastvit_t8.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t8.pth.tar'),
-    # "fastvit_t12.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t12.pth.tar'),
-    # "fastvit_s12.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_s12.pth.tar'),
-    # "fastvit_sa12.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa12.pth.tar'),
-    # "fastvit_sa24.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa24.pth.tar'),
-    # "fastvit_sa36.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa36.pth.tar'),
-    # "fastvit_ma36.apple_dist_in1k": _cfg(
-    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_ma36.pth.tar',
-    #     crop_pct=0.95
-    # ),
+    "fastvit_t8.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "fastvit_t12.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "fastvit_s12.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "fastvit_sa12.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "fastvit_sa24.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "fastvit_sa36.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "fastvit_ma36.apple_dist_in1k": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95
+    ),
 })
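With the cfgs now pointing at hf_hub_id='timm/', pretrained weights resolve from the Hugging Face Hub instead of Apple's CDN. Note also pool_size=(8, 8), which matches the 8x8 final feature map of the default 256x256 input at stride 32, and the classifier key moving to 'head.fc' with the refactored head. A sketch of loading one of the listed tags, assuming the weights are published under the timm org as these cfgs imply:

import timm
from timm.data import resolve_data_config

model = timm.create_model('fastvit_t8.apple_dist_in1k', pretrained=True).eval()
# pull input_size / crop_pct / interpolation straight from the pretrained cfg
data_cfg = resolve_data_config({}, model=model)
print(data_cfg['input_size'], data_cfg['crop_pct'])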
-def checkpoint_filter_fn(state_dict, model):
-    # FIXME temporary for remapping
-    state_dict = state_dict.get('state_dict', state_dict)
-    msd = model.state_dict()
-    out_dict = {}
-    for ka, kb, va, vb in zip(msd.keys(), state_dict.keys(), msd.values(), state_dict.values()):
-        if va.ndim == 4 and vb.ndim == 2:
-            vb = vb[:, :, None, None]
-        out_dict[ka] = vb
-    return out_dict
-
-
 def _create_fastvit(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
     model = build_model_with_cfg(
         FastVit,
         variant,
         pretrained,
-        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs
     )
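checkpoint_filter_fn and its pretrained_filter_fn hookup are gone; presumably the key remapping it papered over was baked into the converted hub checkpoints, so loading no longer needs a filter. The out_indices pop remains the hook for selecting feature stages, roughly:

import timm

# hypothetical: out_indices is popped in _create_fastvit and forwarded to feature_cfg
model = timm.create_model('fastvit_sa24', features_only=True, out_indices=(2, 3))
print(model.feature_info.channels())  # channels of the last two stages only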
@@ -1381,7 +1367,7 @@ def fastvit_sa12(pretrained=False, **kwargs):
         layers=(2, 2, 6, 2),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
     return _create_fastvit('fastvit_sa12', pretrained=pretrained, **dict(model_args, **kwargs))

@@ -1394,7 +1380,7 @@ def fastvit_sa24(pretrained=False, **kwargs):
         layers=(4, 4, 12, 4),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
     return _create_fastvit('fastvit_sa24', pretrained=pretrained, **dict(model_args, **kwargs))

@@ -1407,11 +1393,12 @@ def fastvit_sa36(pretrained=False, **kwargs):
         layers=(6, 6, 18, 6),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
     return _create_fastvit('fastvit_sa36', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def fastvit_ma36(pretrained=False, **kwargs):
     """Instantiate FastViT-MA36 model variant."""

@@ -1419,7 +1406,7 @@ def fastvit_ma36(pretrained=False, **kwargs):
         layers=(6, 6, 18, 6),
         embed_dims=(76, 152, 304, 608),
         mlp_ratios=(4, 4, 4, 4),
-        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention")
     )
     return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs))
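Note the **dict(model_args, **kwargs) merge in every variant: caller kwargs take precedence over the registered defaults. A hypothetical override, assuming drop_path_rate is accepted by the FastVit constructor as it is for most timm models:

import timm

# caller kwargs win over the registered variant defaults via dict(model_args, **kwargs)
model = timm.create_model('fastvit_sa36', drop_path_rate=0.1)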

View File: timm/models/repghost.py

@@ -126,6 +126,9 @@ class RepGhostModule(nn.Module):
         self.fusion_conv = []
         self.fusion_bn = []

+    def reparameterize(self):
+        self.switch_to_deploy()
+

 class RepGhostBottleneck(nn.Module):
     """ RepGhost bottleneck w/ optional SE"""

View File: timm/utils/__init__.py

@@ -9,7 +9,7 @@ from .jit import set_jit_legacy, set_jit_fuser
 from .log import setup_default_logging, FormatterNoInfo
 from .metrics import AverageMeter, accuracy
 from .misc import natural_key, add_bool_arg, ParseKwargs
-from .model import unwrap_model, get_state_dict, freeze, unfreeze
+from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
 from .model_ema import ModelEma, ModelEmaV2
 from .random import random_seed
 from .summary import update_summary, get_outdir

View File: timm/utils/model.py

@@ -3,6 +3,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import fnmatch
+from copy import deepcopy

 import torch
 from torchvision.ops.misc import FrozenBatchNorm2d

@@ -219,3 +220,21 @@ def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
     See example in docstring for `freeze`.
     """
     _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze")
+
+
+def reparameterize_model(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
+    if not inplace:
+        model = deepcopy(model)
+
+    def _fuse(m):
+        for child_name, child in m.named_children():
+            if hasattr(child, 'fuse'):
+                setattr(m, child_name, child.fuse())
+            elif hasattr(child, "reparameterize"):
+                child.reparameterize()
+            elif hasattr(child, "switch_to_deploy"):
+                child.switch_to_deploy()
+            _fuse(child)
+
+    _fuse(model)
+    return model
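The helper walks named_children() recursively, preferring fuse() (which returns a replacement module) over the in-place reparameterize()/switch_to_deploy() hooks, and deep-copies by default so the training-mode model survives. A minimal equivalence check, assuming mobileone_s1 is one of the rep-style models this commit targets; the tolerance is deliberately loose since fusing batch norm shifts floating-point results slightly:

import torch
import timm
from timm.utils import reparameterize_model

model = timm.create_model('mobileone_s1').eval()  # eval mode so BN uses running stats
deploy = reparameterize_model(model)              # inplace=False: original is deep-copied

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    assert torch.allclose(model(x), deploy(x), atol=1e-4)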

View File: validate.py

@@ -26,7 +26,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLa
 from timm.layers import apply_test_time_pool, set_fast_norm
 from timm.models import create_model, load_checkpoint, is_model, list_models
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
-    decay_batch_step, check_batch_size_retry, ParseKwargs
+    decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model

 try:
     from apex import amp

@@ -125,6 +125,8 @@ parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--reparam', default=False, action='store_true',
+                    help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

@@ -207,6 +209,9 @@ def validate(args):
     if args.checkpoint:
         load_checkpoint(model, args.checkpoint, args.use_ema)

+    if args.reparam:
+        model = reparameterize_model(model)
+
     param_count = sum([m.numel() for m in model.parameters()])
     _logger.info('Model %s created, param count: %d' % (args.model, param_count))
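Since reparameterization happens after any checkpoint load and before the parameter count is logged, validation exercises and reports the deploy-mode graph. An assumed invocation, with the usual dataset arguments elided, would look like python validate.py --model mobileone_s1 --reparam, going through the same reparameterize_model() path as the Python API above.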