From 5242ba6edcf5dc977226be97089d1e99def01ee8 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 23 Aug 2023 14:16:43 -0700
Subject: [PATCH] MobileOne and FastViT weights on HF hub, more code cleanup
 and tweaks, features_only working. Add reparam flag to validate and
 benchmark, support reparam of all models with fuse(), reparameterize() or
 switch_to_deploy() methods on modules

---
 benchmark.py                     |  11 +-
 timm/models/byobnet.py           | 252 ++++++++++++-----
 timm/models/efficientvit_msra.py |   4 +-
 timm/models/fastvit.py           | 453 +++++++++++++++----------------
 timm/models/repghost.py          |   3 +
 timm/utils/__init__.py           |   2 +-
 timm/utils/model.py              |  19 ++
 validate.py                      |   7 +-
 8 files changed, 447 insertions(+), 304 deletions(-)

diff --git a/benchmark.py b/benchmark.py
index 2cce3e2c..c31708f5 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -22,7 +22,8 @@ from timm.data import resolve_data_config
 from timm.layers import set_fast_norm
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
-from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs
+from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs,\
+    reparameterize_model
 
 has_apex = False
 try:
@@ -116,6 +117,8 @@ parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--reparam', default=False, action='store_true',
+                    help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
 
 # codegen (model compilation) options
@@ -222,6 +225,7 @@ class BenchmarkRunner:
             torchscript=False,
             torchcompile=None,
             aot_autograd=False,
+            reparam=False,
             precision='float32',
             fuser='',
             num_warm_iter=10,
@@ -252,10 +256,13 @@ class BenchmarkRunner:
             drop_block_rate=kwargs.pop('drop_block', None),
             **kwargs.pop('model_kwargs', {}),
         )
+        if reparam:
+            self.model = reparameterize_model(self.model)
         self.model.to(
             device=self.device,
             dtype=self.model_dtype,
-            memory_format=torch.channels_last if self.channels_last else None)
+            memory_format=torch.channels_last if self.channels_last else None,
+        )
         self.num_classes = self.model.num_classes
         self.param_count = count_params(self.model)
         _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py
index 713f4d3b..7464b901 100644
--- a/timm/models/byobnet.py
+++ b/timm/models/byobnet.py
@@ -12,6 +12,10 @@ RepVGG - repvgg_*
 Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
 Code and weights: https://github.com/DingXiaoH/RepVGG, licensed MIT
 
+MobileOne - mobileone_*
+Paper: `MobileOne: An Improved One millisecond Mobile Backbone` - https://arxiv.org/abs/2206.04040
+Code and weights: https://github.com/apple/ml-mobileone, licensed MIT
+
 In all cases the models have been modified to fit within the design of ByobNet. I've remapped
 the original weights and verified accuracies.
 
@@ -468,8 +472,6 @@ class RepVggBlock(nn.Module):
     """ RepVGG Block.
 
     Adapted from impl at https://github.com/DingXiaoH/RepVGG
-
-    This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
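+
+    This block now supports reparameterization: construct with `inference_mode=True`
+    to get the fused single-conv form directly, or call `reparameterize()` on a
+    trained block to collapse its branches for deployment.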
""" def __init__( @@ -485,20 +487,34 @@ class RepVggBlock(nn.Module): layers: LayerFn = None, drop_block: Callable = None, drop_path_rate: float = 0., + inference_mode: bool = False ): super(RepVggBlock, self).__init__() + self.groups = groups = num_groups(group_size, in_chs) layers = layers or LayerFn() - groups = num_groups(group_size, in_chs) - #self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) # FIXME temp for remapping - use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1] - self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None - self.conv_kxk = layers.conv_norm_act( - in_chs, out_chs, kernel_size, - stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False, - ) - self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False) + + if inference_mode: + self.reparam_conv = nn.Conv2d( + in_channels=in_chs, + out_channels=out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=True, + ) + else: + self.reparam_conv = None + use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1] + self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None + self.conv_kxk = layers.conv_norm_act( + in_chs, out_chs, kernel_size, + stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False, + ) + self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity() + self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) - self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity() self.act = layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): @@ -511,16 +527,109 @@ class RepVggBlock(nn.Module): self.attn.reset_parameters() def forward(self, x): + if self.reparam_conv is not None: + return self.act(self.attn(self.reparam_conv(x))) + if self.identity is None: x = self.conv_1x1(x) + self.conv_kxk(x) else: identity = self.identity(x) x = self.conv_1x1(x) + self.conv_kxk(x) x = self.drop_path(x) # not in the paper / official impl, experimental - x = x + identity + x += identity x = self.attn(x) # no attn in the paper / official impl, experimental return self.act(x) + def reparameterize(self): + """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` - + https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched + architecture used at training time to obtain a plain CNN-like structure + for inference. 
+ """ + if self.reparam_conv is not None: + return + + kernel, bias = self._get_kernel_bias() + self.reparam_conv = nn.Conv2d( + in_channels=self.conv_kxk.conv.in_channels, + out_channels=self.conv_kxk.conv.out_channels, + kernel_size=self.conv_kxk.conv.kernel_size, + stride=self.conv_kxk.conv.stride, + padding=self.conv_kxk.conv.padding, + dilation=self.conv_kxk.conv.dilation, + groups=self.conv_kxk.conv.groups, + bias=True, + ) + self.reparam_conv.weight.data = kernel + self.reparam_conv.bias.data = bias + + # Delete un-used branches + for name, para in self.named_parameters(): + if 'reparam_conv' in name: + continue + para.detach_() + self.__delattr__('conv_kxk') + self.__delattr__('conv_1x1') + self.__delattr__('identity') + self.__delattr__('drop_path') + + def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: + """ Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 + """ + # get weights and bias of scale branch + kernel_1x1 = 0 + bias_1x1 = 0 + if self.conv_1x1 is not None: + kernel_1x1, bias_1x1 = self._fuse_bn_tensor(self.conv_1x1) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.conv_kxk.conv.kernel_size[0] // 2 + kernel_1x1 = torch.nn.functional.pad(kernel_1x1, [pad, pad, pad, pad]) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.identity is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity) + + # get weights and bias of conv branches + kernel_conv, bias_conv = self._fuse_bn_tensor(self.conv_kxk) + + kernel_final = kernel_conv + kernel_1x1 + kernel_identity + bias_final = bias_conv + bias_1x1 + bias_identity + return kernel_final, bias_final + + def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]: + """ Method to fuse batchnorm layer with preceeding conv layer. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 + """ + if isinstance(branch, ConvNormAct): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, 'id_tensor'): + in_chs = self.conv_kxk.conv.in_channels + input_dim = in_chs // self.groups + kernel_size = self.conv_kxk.conv.kernel_size + kernel_value = torch.zeros_like(self.conv_kxk.conv.weight) + for i in range(in_chs): + kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1 + self.id_tensor = kernel_value + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + class MobileOneBlock(nn.Module): """ MobileOne building block. @@ -549,28 +658,11 @@ class MobileOneBlock(nn.Module): drop_path_rate: float = 0., ) -> None: """ Construct a MobileOneBlock module. - - :param in_chs: Number of channels in the input. - :param out_chs: Number of channels produced by the block. - :param kernel_size: Size of the convolution kernel. - :param stride: Stride size. - :param dilation: Kernel dilation factor. - :param groups: Group number. - :param inference_mode: If True, instantiates model in inference mode. - :param use_se: Whether to use SE-ReLU activations. 
- :param num_conv_branches: Number of linear conv branches. """ super(MobileOneBlock, self).__init__() - self.stride = stride - self.kernel_size = kernel_size - self.in_channels = in_chs - self.out_channels = out_chs self.num_conv_branches = num_conv_branches + self.groups = groups = num_groups(group_size, in_chs) layers = layers or LayerFn() - groups = num_groups(group_size, in_chs) - - # Check if SE-ReLU is requested - self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) # FIXME move after remap if inference_mode: self.reparam_conv = nn.Conv2d( @@ -602,7 +694,9 @@ class MobileOneBlock(nn.Module): self.conv_scale = layers.conv_norm_act( in_chs, out_chs, kernel_size=1, stride=stride, groups=groups, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity() + self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) self.act = layers.act(inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -623,9 +717,11 @@ class MobileOneBlock(nn.Module): scale_out = self.conv_scale(x) # Other branches - out = scale_out + identity_out + out = scale_out for ck in self.conv_kxk: out += ck(x) + out = self.drop_path(out) + out += identity_out return self.act(self.attn(out)) @@ -652,18 +748,18 @@ class MobileOneBlock(nn.Module): self.reparam_conv.bias.data = bias # Delete un-used branches - for para in self.parameters(): + for name, para in self.named_parameters(): + if 'reparam_conv' in name: + continue para.detach_() self.__delattr__('conv_kxk') self.__delattr__('conv_scale') - if hasattr(self, 'identity'): - self.__delattr__('identity') + self.__delattr__('identity') + self.__delattr__('drop_path') def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: """ Method to obtain re-parameterized kernel and bias. Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 - - :return: Tuple of (kernel, bias) after fusing branches. """ # get weights and bias of scale branch kernel_scale = 0 @@ -671,7 +767,7 @@ class MobileOneBlock(nn.Module): if self.conv_scale is not None: kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale) # Pad scale branch kernel to match conv branch kernel size. - pad = self.kernel_size // 2 + pad = self.conv_kxk[0].conv.kernel_size[0] // 2 kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad]) # get weights and bias of skip branch @@ -695,9 +791,6 @@ class MobileOneBlock(nn.Module): def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]: """ Method to fuse batchnorm layer with preceeding conv layer. Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 - - :param branch: - :return: Tuple of (kernel, bias) after fusing batchnorm. 
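+
+        Folding BN into the preceding conv uses std = sqrt(running_var + eps):
+        the fused kernel is W * (gamma / std) and the fused bias is
+        beta - running_mean * gamma / std.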
""" if isinstance(branch, ConvNormAct): kernel = branch.conv.weight @@ -709,16 +802,12 @@ class MobileOneBlock(nn.Module): else: assert isinstance(branch, nn.BatchNorm2d) if not hasattr(self, 'id_tensor'): - input_dim = self.in_channels // self.groups - kernel_value = torch.zeros( - (self.in_channels, - input_dim, - self.kernel_size, - self.kernel_size), - dtype=branch.weight.dtype, - device=branch.weight.device) - for i in range(self.in_channels): - kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1 + in_chs = self.conv_kxk[0].conv.in_channels + input_dim = in_chs // self.groups + kernel_size = self.conv_kxk[0].conv.kernel_size + kernel_value = torch.zeros_like(self.conv_kxk[0].conv.weight) + for i in range(in_chs): + kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1 self.id_tensor = kernel_value kernel = self.id_tensor running_mean = branch.running_mean @@ -1226,6 +1315,16 @@ model_cfgs = dict( num_features=1920, ), + repvgg_a0=ByoModelCfg( + blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(0.75, 0.75, 0.75, 2.5)), + stem_type='rep', + stem_chs=48, + ), + repvgg_a1=ByoModelCfg( + blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1, 1, 1, 2.5)), + stem_type='rep', + stem_chs=64, + ), repvgg_a2=ByoModelCfg( blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)), stem_type='rep', @@ -1643,15 +1742,11 @@ model_cfgs = dict( ), ) -# FIXME temporary for mobileone remap -from .fastvit import checkpoint_filter_fn - def _create_byobnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( ByobNet, variant, pretrained, model_cfg=model_cfgs[variant], - pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True), **kwargs) @@ -1683,6 +1778,12 @@ default_cfgs = generate_default_cfgs({ 'gernet_l.idstcv_in1k': _cfg(hf_hub_id='timm/', input_size=(3, 256, 256), pool_size=(8, 8)), # RepVGG weights + 'repvgg_a0.rvgg_in1k': _cfg( + hf_hub_id='timm/', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'), + 'repvgg_a1.rvgg_in1k': _cfg( + hf_hub_id='timm/', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'), 'repvgg_a2.rvgg_in1k': _cfg( hf_hub_id='timm/', first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'), @@ -1707,6 +1808,11 @@ default_cfgs = generate_default_cfgs({ 'repvgg_b3g4.rvgg_in1k': _cfg( hf_hub_id='timm/', first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'), + 'repvgg_d2se.rvgg_in1k': _cfg( + hf_hub_id='timm/', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit', + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, + ), # experimental ResNet configs 'resnet51q.ra2_in1k': _cfg( @@ -1810,24 +1916,24 @@ default_cfgs = generate_default_cfgs({ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0), - 'mobileone_s0': _cfg( - url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar', + 'mobileone_s0.apple_in1k': _cfg( + hf_hub_id='timm/', crop_pct=0.875, ), - 'mobileone_s1': _cfg( - url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar', + 'mobileone_s1.apple_in1k': _cfg( + hf_hub_id='timm/', crop_pct=0.9, ), - 'mobileone_s2': _cfg( - 
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar', + 'mobileone_s2.apple_in1k': _cfg( + hf_hub_id='timm/', crop_pct=0.9, ), - 'mobileone_s3': _cfg( - url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar', + 'mobileone_s3.apple_in1k': _cfg( + hf_hub_id='timm/', crop_pct=0.9, ), - 'mobileone_s4': _cfg( - url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar', + 'mobileone_s4.apple_in1k': _cfg( + hf_hub_id='timm/', crop_pct=0.9, ), }) @@ -1857,6 +1963,22 @@ def gernet_s(pretrained=False, **kwargs) -> ByobNet: return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs) +@register_model +def repvgg_a0(pretrained=False, **kwargs) -> ByobNet: + """ RepVGG-A0 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_a0', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_a1(pretrained=False, **kwargs) -> ByobNet: + """ RepVGG-A1 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_a1', pretrained=pretrained, **kwargs) + + @register_model def repvgg_a2(pretrained=False, **kwargs) -> ByobNet: """ RepVGG-A2 diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index 8940df0f..0edb09c2 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -37,8 +37,8 @@ class ConvNorm(torch.nn.Sequential): b = bn.bias - bn.running_mean * bn.weight / \ (bn.running_var + bn.eps)**0.5 m = torch.nn.Conv2d( - w.size(1) * self.c.groups, w.size(0), w.shape[2:], - stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + w.size(1) * self.conv.groups, w.size(0), w.shape[2:], + stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups) m.weight.data.copy_(w) m.bias.data.copy_(b) return m diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py index 8e5051de..30075d68 100644 --- a/timm/models/fastvit.py +++ b/timm/models/fastvit.py @@ -1,9 +1,8 @@ +# FastViT for PyTorch # # For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main -# # Original work is copyright (C) 2023 Apple Inc. All Rights Reserved. 
# -import copy import os from functools import partial from typing import List, Tuple, Optional, Union @@ -12,8 +11,10 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn +from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \ + ClassifierHead from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -93,7 +94,7 @@ class MobileOneBlock(nn.Module): # Re-parameterizable skip connection self.reparam_conv = None - self.rbr_skip = ( + self.identity = ( nn.BatchNorm2d(num_features=in_chs) if out_chs == in_chs and stride == 1 else None @@ -101,7 +102,7 @@ class MobileOneBlock(nn.Module): # Re-parameterizable conv branches if num_conv_branches > 0: - self.rbr_conv = nn.ModuleList([ + self.conv_kxk = nn.ModuleList([ ConvNormAct( self.in_chs, self.out_chs, @@ -112,12 +113,12 @@ class MobileOneBlock(nn.Module): ) for _ in range(self.num_conv_branches) ]) else: - self.rbr_conv = None + self.conv_kxk = None # Re-parameterizable scale branch - self.rbr_scale = None + self.conv_scale = None if kernel_size > 1 and use_scale_branch: - self.rbr_scale = ConvNormAct( + self.conv_scale = ConvNormAct( self.in_chs, self.out_chs, kernel_size=1, @@ -135,20 +136,20 @@ class MobileOneBlock(nn.Module): return self.act(self.se(self.reparam_conv(x))) # Multi-branched train-time forward pass. - # Skip branch output + # Identity branch output identity_out = 0 - if self.rbr_skip is not None: - identity_out = self.rbr_skip(x) + if self.identity is not None: + identity_out = self.identity(x) # Scale branch output scale_out = 0 - if self.rbr_scale is not None: - scale_out = self.rbr_scale(x) + if self.conv_scale is not None: + scale_out = self.conv_scale(x) - # Other branches + # Other kxk conv branches out = scale_out + identity_out - if self.rbr_conv is not None: - for rc in self.rbr_conv: + if self.conv_kxk is not None: + for rc in self.conv_kxk: out += rc(x) return self.act(self.se(out)) @@ -176,13 +177,15 @@ class MobileOneBlock(nn.Module): self.reparam_conv.bias.data = bias # Delete un-used branches - for para in self.parameters(): + for name, para in self.named_parameters(): + if 'reparam_conv' in name: + continue para.detach_() - self.__delattr__("rbr_conv") - self.__delattr__("rbr_scale") - if hasattr(self, "rbr_skip"): - self.__delattr__("rbr_skip") + self.__delattr__("conv_kxk") + self.__delattr__("conv_scale") + if hasattr(self, "identity"): + self.__delattr__("identity") self.inference_mode = True @@ -196,8 +199,8 @@ class MobileOneBlock(nn.Module): # get weights and bias of scale branch kernel_scale = 0 bias_scale = 0 - if self.rbr_scale is not None: - kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale) + if self.conv_scale is not None: + kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale) # Pad scale branch kernel to match conv branch kernel size. 
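+            # Zero-padding the 1x1 kernel by kernel_size // 2 on each side gives
+            # an equivalent kxk kernel with the 1x1 weights at its center, so the
+            # scale branch can be summed elementwise with the kxk branch kernels.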
pad = self.kernel_size // 2 kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad]) @@ -205,15 +208,15 @@ class MobileOneBlock(nn.Module): # get weights and bias of skip branch kernel_identity = 0 bias_identity = 0 - if self.rbr_skip is not None: - kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip) + if self.identity is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity) # get weights and bias of conv branches kernel_conv = 0 bias_conv = 0 - if self.rbr_conv is not None: + if self.conv_kxk is not None: for ix in range(self.num_conv_branches): - _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix]) + _kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix]) kernel_conv += _kernel bias_conv += _bias @@ -233,7 +236,7 @@ class MobileOneBlock(nn.Module): Returns: Tuple of (kernel, bias) after fusing batchnorm. """ - if isinstance(branch, nn.Sequential): + if isinstance(branch, ConvNormAct): kernel = branch.conv.weight running_mean = branch.bn.running_mean running_var = branch.bn.running_var @@ -306,7 +309,7 @@ class ReparamLargeKernelConv(nn.Module): self.kernel_size = kernel_size self.small_kernel = small_kernel if inference_mode: - self.lkb_reparam = create_conv2d( + self.reparam_conv = create_conv2d( in_chs, out_chs, kernel_size=kernel_size, @@ -316,8 +319,8 @@ class ReparamLargeKernelConv(nn.Module): bias=True, ) else: - self.lkb_reparam = None - self.lkb_origin = ConvNormAct( + self.reparam_conv = None + self.large_conv = ConvNormAct( in_chs, out_chs, kernel_size=kernel_size, @@ -341,10 +344,10 @@ class ReparamLargeKernelConv(nn.Module): self.act = act_layer() if act_layer is not None else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.lkb_reparam is not None: - out = self.lkb_reparam(x) + if self.reparam_conv is not None: + out = self.reparam_conv(x) else: - out = self.lkb_origin(x) + out = self.large_conv(x) if self.small_conv is not None: out = out + self.small_conv(x) out = self.act(out) @@ -357,7 +360,7 @@ class ReparamLargeKernelConv(nn.Module): Returns: Tuple of (kernel, bias) after fusing branches. """ - eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) + eq_k, eq_b = self._fuse_bn(self.large_conv.conv, self.large_conv.bn) if hasattr(self, "small_conv"): small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn) eq_b += small_b @@ -374,19 +377,18 @@ class ReparamLargeKernelConv(nn.Module): for inference. """ eq_k, eq_b = self.get_kernel_bias() - self.lkb_reparam = create_conv2d( + self.reparam_conv = create_conv2d( self.in_chs, self.out_chs, kernel_size=self.kernel_size, stride=self.stride, - dilation=self.lkb_origin.conv.dilation, groups=self.groups, bias=True, ) - self.lkb_reparam.weight.data = eq_k - self.lkb_reparam.bias.data = eq_b - self.__delattr__("lkb_origin") + self.reparam_conv.weight.data = eq_k + self.reparam_conv.bias.data = eq_b + self.__delattr__("large_conv") if hasattr(self, "small_conv"): self.__delattr__("small_conv") @@ -532,6 +534,8 @@ class PatchEmbed(nn.Module): stride: int, in_chs: int, embed_dim: int, + act_layer: nn.Module = nn.GELU, + lkc_use_act: bool = False, inference_mode: bool = False, ) -> None: """Build patch embedding layer. 
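The center-padding identity that `_get_kernel_bias` and `ReparamLargeKernelConv.get_kernel_bias` rely on can be checked in isolation. A minimal standalone sketch (channel counts and kernel sizes here are illustrative, not taken from the models):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: 8 channels, a 7x7 large kernel and a 3x3 small kernel.
large_k = torch.randn(8, 8, 7, 7)
small_k = torch.randn(8, 8, 3, 3)
pad = (7 - 3) // 2
eq_k = large_k + F.pad(small_k, [pad] * 4)  # center the 3x3 taps inside the 7x7 kernel

x = torch.randn(1, 8, 32, 32)
two_branch = F.conv2d(x, large_k, padding=3) + F.conv2d(x, small_k, padding=1)
fused = F.conv2d(x, eq_k, padding=3)
print(torch.allclose(two_branch, fused, atol=1e-4))  # True: one conv replaces two
```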
@@ -553,13 +557,14 @@
                 group_size=1,
                 small_kernel=3,
                 inference_mode=inference_mode,
-                act_layer=None,  # activation was not used in original impl
+                act_layer=act_layer if lkc_use_act else None,  # NOTE original weights didn't use this act
             ),
             MobileOneBlock(
                 in_chs=embed_dim,
                 out_chs=embed_dim,
                 kernel_size=1,
                 stride=1,
+                act_layer=act_layer,
                 inference_mode=inference_mode,
             )
         )
@@ -569,6 +574,16 @@
         return x
 
 
+class LayerScale2d(nn.Module):
+    def __init__(self, dim, init_values=1e-5, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim, 1, 1))
+
+    def forward(self, x):
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
 class RepMixer(nn.Module):
     """Reparameterizable token mixer.
 
@@ -580,7 +595,6 @@
             self,
             dim,
             kernel_size=3,
-            use_layer_scale=True,
             layer_scale_init_value=1e-5,
             inference_mode: bool = False,
     ):
@@ -589,7 +603,6 @@
         Args:
             dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
             kernel_size: Kernel size for spatial mixing. Default: 3
-            use_layer_scale: If True, learnable layer scale is used. Default: ``True``
             layer_scale_init_value: Initial value for layer scale. Default: 1e-5
             inference_mode: If True, instantiates model in inference mode. Default: ``False``
         """
@@ -626,20 +639,17 @@
                 group_size=1,
                 use_act=False,
             )
-            self.use_layer_scale = use_layer_scale
-            if use_layer_scale:
-                self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
+            if layer_scale_init_value is not None:
+                self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
+            else:
+                self.layer_scale = nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.reparam_conv is not None:
             x = self.reparam_conv(x)
-            return x
         else:
-            if self.use_layer_scale:
-                x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
-            else:
-                x = x + self.mixer(x) - self.norm(x)
-            return x
+            x = x + self.layer_scale(self.mixer(x) - self.norm(x))
+        return x
 
     def reparameterize(self) -> None:
         """Reparameterize mixer and norm into a single
@@ -651,11 +661,11 @@
         self.mixer.reparameterize()
         self.norm.reparameterize()
 
-        if self.use_layer_scale:
-            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
+        if isinstance(self.layer_scale, LayerScale2d):
+            w = self.mixer.id_tensor + self.layer_scale.gamma.unsqueeze(-1) * (
                 self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
             )
-            b = torch.squeeze(self.layer_scale) * (
+            b = torch.squeeze(self.layer_scale.gamma) * (
                 self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
             )
         else:
@@ -677,12 +687,13 @@
         self.reparam_conv.weight.data = w
         self.reparam_conv.bias.data = b
 
-        for para in self.parameters():
+        for name, para in self.named_parameters():
+            if 'reparam_conv' in name:
+                continue
             para.detach_()
         self.__delattr__("mixer")
         self.__delattr__("norm")
-        if self.use_layer_scale:
-            self.__delattr__("layer_scale")
+        self.__delattr__("layer_scale")
 
 
 class ConvMlp(nn.Module):
@@ -708,19 +719,6 @@
         super().__init__()
         out_chs = out_chs or in_chs
         hidden_channels = hidden_channels or in_chs
-        # self.conv = nn.Sequential()
-        # self.conv.add_module(
-        #     "conv",
-        #     nn.Conv2d(
-        #         in_chs,
-        #         out_chs,
-        #         kernel_size=7,
-        #         padding=3,
-        #         groups=in_chs,
-        #         bias=False,
-        #     ),
-        # )
-        # self.conv.add_module("bn", 
nn.BatchNorm2d(num_features=out_chs)) self.conv = ConvNormAct( in_chs, out_chs, @@ -750,7 +748,7 @@ class ConvMlp(nn.Module): return x -class RepCPE(nn.Module): +class RepConditionalPosEnc(nn.Module): """Implementation of conditional positional encoding. For more details refer to paper: @@ -774,7 +772,7 @@ class RepCPE(nn.Module): spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) inference_mode: Flag to instantiate block in inference mode. Default: ``False`` """ - super(RepCPE, self).__init__() + super(RepConditionalPosEnc, self).__init__() if isinstance(spatial_shape, int): spatial_shape = tuple([spatial_shape] * 2) assert isinstance(spatial_shape, Tuple), ( @@ -803,7 +801,7 @@ class RepCPE(nn.Module): ) else: self.reparam_conv = None - self.pe = nn.Conv2d( + self.pos_enc = nn.Conv2d( self.dim, self.dim_out, spatial_shape, @@ -816,10 +814,9 @@ class RepCPE(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: if self.reparam_conv is not None: x = self.reparam_conv(x) - return x else: - x = self.pe(x) + x - return x + x = self.pos_enc(x) + x + return x def reparameterize(self) -> None: # Build equivalent Id tensor @@ -831,8 +828,8 @@ class RepCPE(nn.Module): self.spatial_shape[0], self.spatial_shape[1], ), - dtype=self.pe.weight.dtype, - device=self.pe.weight.device, + dtype=self.pos_enc.weight.dtype, + device=self.pos_enc.weight.device, ) for i in range(self.dim): kernel_value[ @@ -844,8 +841,8 @@ class RepCPE(nn.Module): id_tensor = kernel_value # Reparameterize Id tensor and conv - w_final = id_tensor + self.pe.weight - b_final = self.pe.bias + w_final = id_tensor + self.pos_enc.weight + b_final = self.pos_enc.bias # Introduce reparam conv self.reparam_conv = nn.Conv2d( @@ -860,9 +857,11 @@ class RepCPE(nn.Module): self.reparam_conv.weight.data = w_final self.reparam_conv.bias.data = b_final - for para in self.parameters(): + for name, para in self.named_parameters(): + if 'reparam_conv' in name: + continue para.detach_() - self.__delattr__("pe") + self.__delattr__("pos_enc") class RepMixerBlock(nn.Module): @@ -878,9 +877,8 @@ class RepMixerBlock(nn.Module): kernel_size: int = 3, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, - drop: float = 0.0, + proj_drop: float = 0.0, drop_path: float = 0.0, - use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, inference_mode: bool = False, ): @@ -891,9 +889,8 @@ class RepMixerBlock(nn.Module): kernel_size: Kernel size for repmixer. Default: 3 mlp_ratio: MLP expansion ratio. Default: 4.0 act_layer: Activation layer. Default: ``nn.GELU`` - drop: Dropout rate. Default: 0.0 + proj_drop: Dropout rate. Default: 0.0 drop_path: Drop path rate. Default: 0.0 - use_layer_scale: Flag to turn on layer scale. Default: ``True`` layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 inference_mode: Flag to instantiate block in inference mode. 
Default: ``False`` """ @@ -903,36 +900,25 @@ class RepMixerBlock(nn.Module): self.token_mixer = RepMixer( dim, kernel_size=kernel_size, - use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, ) - assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( - mlp_ratio - ) - self.convffn = ConvMlp( + self.mlp = ConvMlp( in_chs=dim, hidden_channels=int(dim * mlp_ratio), act_layer=act_layer, - drop=drop, + drop=proj_drop, ) - - # Drop Path + if layer_scale_init_value is not None: + self.layer_scale = LayerScale2d(dim, layer_scale_init_value) + else: + self.layer_scale = nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - # Layer Scale - self.use_layer_scale = use_layer_scale - if use_layer_scale: - self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1))) - def forward(self, x): - if self.use_layer_scale: - x = self.token_mixer(x) - x = x + self.drop_path(self.layer_scale * self.convffn(x)) - else: - x = self.token_mixer(x) - x = x + self.drop_path(self.convffn(x)) + x = self.token_mixer(x) + x = x + self.drop_path(self.layer_scale(self.mlp(x))) return x @@ -949,9 +935,8 @@ class AttentionBlock(nn.Module): mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.BatchNorm2d, - drop: float = 0.0, + proj_drop: float = 0.0, drop_path: float = 0.0, - use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, ): """Build Attention Block. @@ -961,9 +946,8 @@ class AttentionBlock(nn.Module): mlp_ratio: MLP expansion ratio. Default: 4.0 act_layer: Activation layer. Default: ``nn.GELU`` norm_layer: Normalization layer. Default: ``nn.BatchNorm2d`` - drop: Dropout rate. Default: 0.0 + proj_drop: Dropout rate. Default: 0.0 drop_path: Drop path rate. Default: 0.0 - use_layer_scale: Flag to turn on layer scale. Default: ``True`` layer_scale_init_value: Layer scale value at initialization. 
Default: 1e-5 """ @@ -971,34 +955,27 @@ class AttentionBlock(nn.Module): self.norm = norm_layer(dim) self.token_mixer = Attention(dim=dim) + if layer_scale_init_value is not None: + self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value) + else: + self.layer_scale_1 = nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( - mlp_ratio - ) - mlp_hidden_dim = int(dim * mlp_ratio) - self.convffn = ConvMlp( + self.mlp = ConvMlp( in_chs=dim, - hidden_channels=mlp_hidden_dim, + hidden_channels=int(dim * mlp_ratio), act_layer=act_layer, - drop=drop, + drop=proj_drop, ) - - # Drop path - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - # Layer Scale - self.use_layer_scale = use_layer_scale - if use_layer_scale: - self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1))) - self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1))) + if layer_scale_init_value is not None: + self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value) + else: + self.layer_scale_2 = nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x): - if self.use_layer_scale: - x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x))) - x = x + self.drop_path(self.layer_scale_2 * self.convffn(x)) - else: - x = x + self.drop_path(self.token_mixer(self.norm(x))) - x = x + self.drop_path(self.convffn(x)) + x = x + self.drop_path1(self.layer_scale_1(self.token_mixer(self.norm(x)))) + x = x + self.drop_path2(self.layer_scale_2(self.mlp(x))) return x @@ -1017,35 +994,38 @@ class FastVitStage(nn.Module): mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.BatchNorm2d, - drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, drop_path_rate: float = 0.0, - use_layer_scale: bool = True, - layer_scale_init_value: float = 1e-5, + layer_scale_init_value: Optional[float] = 1e-5, + lkc_use_act=False, inference_mode=False, ): """FastViT stage. Args: dim: Number of embedding dimensions. - num_blocks: List containing number of blocks per stage. + depth: Number of blocks in stage token_mixer_type: Token mixer type. kernel_size: Kernel size for repmixer. mlp_ratio: MLP expansion ratio. act_layer: Activation layer. norm_layer: Normalization layer. - drop_rate: Dropout rate. + proj_drop_rate: Dropout rate. drop_path_rate: Drop path rate. - use_layer_scale: Flag to turn on layer scale regularization. layer_scale_init_value: Layer scale value at initialization. inference_mode: Flag to instantiate block in inference mode. 
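+            lkc_use_act: If True, apply the activation after the large-kernel conv in the downsample PatchEmbed.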
""" super().__init__() + self.grad_checkpointing = False + if downsample: self.downsample = PatchEmbed( patch_size=down_patch_size, stride=down_stride, in_chs=dim, embed_dim=dim_out, + act_layer=act_layer, + lkc_use_act=lkc_use_act, inference_mode=inference_mode, ) else: @@ -1065,9 +1045,8 @@ class FastVitStage(nn.Module): kernel_size=kernel_size, mlp_ratio=mlp_ratio, act_layer=act_layer, - drop=drop_rate, + proj_drop=proj_drop_rate, drop_path=drop_path_rate[block_idx], - use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, )) @@ -1077,9 +1056,8 @@ class FastVitStage(nn.Module): mlp_ratio=mlp_ratio, act_layer=act_layer, norm_layer=norm_layer, - drop=drop_rate, + proj_drop=proj_drop_rate, drop_path=drop_path_rate[block_idx], - use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, )) else: @@ -1091,7 +1069,10 @@ class FastVitStage(nn.Module): def forward(self, x): x = self.downsample(x) x = self.pos_emb(x) - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) return x @@ -1116,21 +1097,25 @@ class FastVit(nn.Module): down_patch_size: int = 7, down_stride: int = 2, drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, drop_path_rate: float = 0.0, - use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, fork_feat: bool = False, cls_ratio: float = 2.0, + global_pool: str = 'avg', norm_layer: nn.Module = nn.BatchNorm2d, act_layer: nn.Module = nn.GELU, + lkc_use_act: bool = False, inference_mode: bool = False, ) -> None: super().__init__() self.num_classes = 0 if fork_feat else num_classes self.fork_feat = fork_feat + self.global_pool = global_pool + self.feature_info = [] # Convolutional stem - self.patch_embed = convolutional_stem( + self.stem = convolutional_stem( in_chans, embed_dims[0], inference_mode, @@ -1138,14 +1123,16 @@ class FastVit(nn.Module): # Build the main stages of the network architecture prev_dim = embed_dims[0] + scale = 1 dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)] - network = [] + stages = [] for i in range(len(layers)): + downsample = downsamples[i] or prev_dim != embed_dims[i] stage = FastVitStage( dim=prev_dim, dim_out=embed_dims[i], depth=layers[i], - downsample=downsamples[i] or prev_dim != embed_dims[i], + downsample=downsample, down_patch_size=down_patch_size, down_stride=down_stride, pos_emb_layer=pos_embs[i], @@ -1154,16 +1141,19 @@ class FastVit(nn.Module): mlp_ratio=mlp_ratios[i], act_layer=act_layer, norm_layer=norm_layer, - drop_rate=drop_rate, + proj_drop_rate=drop_rate, drop_path_rate=dpr[i], - use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, + lkc_use_act=lkc_use_act, inference_mode=inference_mode, ) - network.append(stage) + stages.append(stage) prev_dim = embed_dims[i] - - self.network = nn.Sequential(*network) + if downsample: + scale *= 2 + self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) + self.num_features = prev_dim # For segmentation and detection, extract intermediate output if self.fork_feat: @@ -1181,10 +1171,10 @@ class FastVit(nn.Module): self.add_module(layer_name, layer) else: # Classifier head - self.gap = nn.AdaptiveAvgPool2d(output_size=1) - self.conv_exp = MobileOneBlock( + final_features = int(embed_dims[-1] * cls_ratio) + self.final_conv = MobileOneBlock( in_chs=embed_dims[-1], - 
out_chs=int(embed_dims[-1] * cls_ratio), + out_chs=final_features, kernel_size=3, stride=1, group_size=1, @@ -1192,10 +1182,12 @@ class FastVit(nn.Module): use_se=True, num_conv_branches=1, ) - self.head = ( - nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes) - if num_classes > 0 - else nn.Identity() + self.num_features = final_features + self.head = ClassifierHead( + final_features, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, ) self.apply(self._init_weights) @@ -1207,23 +1199,39 @@ class FastVit(nn.Module): if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) - @staticmethod - def _scrub_checkpoint(checkpoint, model): - sterile_dict = {} - for k1, v1 in checkpoint.items(): - if k1 not in model.state_dict(): - continue - if v1.shape == model.state_dict()[k1].shape: - sterile_dict[k1] = v1 - return sterile_dict + @torch.jit.ignore + def no_weight_decay(self): + return set() - def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor: - x = self.patch_embed(x) - return x + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', # stem and embed + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+).pos_emb', (0,)), + (r'^stages\.(\d+)\.\w+\.(\d+)', None), + ] + ) - def forward_tokens(self, x: torch.Tensor) -> torch.Tensor: + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + # input embedding + x = self.stem(x) outs = [] - for idx, block in enumerate(self.network): + for idx, block in enumerate(self.stages): x = block(x) if self.fork_feat: if idx in self.out_indices: @@ -1236,20 +1244,16 @@ class FastVit(nn.Module): # output only the features of last layer for image classification return x + def forward_head(self, x: torch.Tensor, pre_logits: bool = False): + x = self.final_conv(x) + return self.head(x, pre_logits=True) if pre_logits else self.head(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: - # input embedding - x = self.forward_embeddings(x) - # through backbone - x = self.forward_tokens(x) + x = self.forward_features(x) if self.fork_feat: - # output features of four stages for dense prediction return x - # for image classification - x = self.conv_exp(x) - x = self.gap(x) - x = x.view(x.size(0), -1) - cls_out = self.head(x) - return cls_out + x = self.forward_head(x) + return x def _cfg(url="", **kwargs): @@ -1257,81 +1261,63 @@ def _cfg(url="", **kwargs): "url": url, "num_classes": 1000, "input_size": (3, 256, 256), - "pool_size": None, + "pool_size": (8, 8), "crop_pct": 0.9, "interpolation": "bicubic", "mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD, - "classifier": "head", + "classifier": "head.fc", **kwargs, } default_cfgs = generate_default_cfgs({ "fastvit_t8.apple_in1k": _cfg( - url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar' - ), + hf_hub_id='timm/'), "fastvit_t12.apple_in1k": _cfg( - url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t12.pth.tar' - ), + hf_hub_id='timm/'), "fastvit_s12.apple_in1k": _cfg( - 
url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_s12.pth.tar'), + hf_hub_id='timm/'), "fastvit_sa12.apple_in1k": _cfg( - url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa12.pth.tar'), + hf_hub_id='timm/'), "fastvit_sa24.apple_in1k": _cfg( - url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa24.pth.tar'), + hf_hub_id='timm/'), "fastvit_sa36.apple_in1k": _cfg( - url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa36.pth.tar'), + hf_hub_id='timm/'), "fastvit_ma36.apple_in1k": _cfg( - url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_ma36.pth.tar', + hf_hub_id='timm/', crop_pct=0.95 ), - # "fastvit_t8.apple_dist_in1k": _cfg( - # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t8.pth.tar' - # ), - # "fastvit_t12.apple_dist_in1k": _cfg( - # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t12.pth.tar' - # ), - # - # "fastvit_s12.apple_dist_in1k": _cfg( - # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_s12.pth.tar'), - # "fastvit_sa12.apple_dist_in1k": _cfg( - # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa12.pth.tar'), - # "fastvit_sa24.apple_dist_in1k": _cfg( - # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa24.pth.tar'), - # "fastvit_sa36.apple_dist_in1k": _cfg( - # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa36.pth.tar'), - # - # "fastvit_ma36.apple_dist_in1k": _cfg( - # url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_ma36.pth.tar', - # crop_pct=0.95 - # ), + "fastvit_t8.apple_dist_in1k": _cfg( + hf_hub_id='timm/'), + "fastvit_t12.apple_dist_in1k": _cfg( + hf_hub_id='timm/'), + + "fastvit_s12.apple_dist_in1k": _cfg( + hf_hub_id='timm/',), + "fastvit_sa12.apple_dist_in1k": _cfg( + hf_hub_id='timm/',), + "fastvit_sa24.apple_dist_in1k": _cfg( + hf_hub_id='timm/',), + "fastvit_sa36.apple_dist_in1k": _cfg( + hf_hub_id='timm/',), + + "fastvit_ma36.apple_dist_in1k": _cfg( + hf_hub_id='timm/', + crop_pct=0.95 + ), }) -def checkpoint_filter_fn(state_dict, model): - # FIXME temporary for remapping - state_dict = state_dict.get('state_dict', state_dict) - msd = model.state_dict() - out_dict = {} - for ka, kb, va, vb in zip(msd.keys(), state_dict.keys(), msd.values(), state_dict.values()): - if va.ndim == 4 and vb.ndim == 2: - vb = vb[:, :, None, None] - out_dict[ka] = vb - - return out_dict - - def _create_fastvit(variant, pretrained=False, **kwargs): out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) model = build_model_with_cfg( FastVit, variant, pretrained, - pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs ) @@ -1381,7 +1367,7 @@ def fastvit_sa12(pretrained=False, **kwargs): layers=(2, 2, 6, 2), embed_dims=(64, 128, 256, 512), mlp_ratios=(4, 4, 4, 4), - pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))), + pos_embs=(None, 
None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), token_mixers=("repmixer", "repmixer", "repmixer", "attention"), ) return _create_fastvit('fastvit_sa12', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -1394,7 +1380,7 @@ def fastvit_sa24(pretrained=False, **kwargs): layers=(4, 4, 12, 4), embed_dims=(64, 128, 256, 512), mlp_ratios=(4, 4, 4, 4), - pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), token_mixers=("repmixer", "repmixer", "repmixer", "attention"), ) return _create_fastvit('fastvit_sa24', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -1407,11 +1393,12 @@ def fastvit_sa36(pretrained=False, **kwargs): layers=(6, 6, 18, 6), embed_dims=(64, 128, 256, 512), mlp_ratios=(4, 4, 4, 4), - pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), token_mixers=("repmixer", "repmixer", "repmixer", "attention"), ) return _create_fastvit('fastvit_sa36', pretrained=pretrained, **dict(model_args, **kwargs)) + @register_model def fastvit_ma36(pretrained=False, **kwargs): """Instantiate FastViT-MA36 model variant.""" @@ -1419,7 +1406,7 @@ def fastvit_ma36(pretrained=False, **kwargs): layers=(6, 6, 18, 6), embed_dims=(76, 152, 304, 608), mlp_ratios=(4, 4, 4, 4), - pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), token_mixers=("repmixer", "repmixer", "repmixer", "attention") ) return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs)) \ No newline at end of file diff --git a/timm/models/repghost.py b/timm/models/repghost.py index ae719c89..da697b70 100644 --- a/timm/models/repghost.py +++ b/timm/models/repghost.py @@ -126,6 +126,9 @@ class RepGhostModule(nn.Module): self.fusion_conv = [] self.fusion_bn = [] + def reparameterize(self): + self.switch_to_deploy() + class RepGhostBottleneck(nn.Module): """ RepGhost bottleneck w/ optional SE""" diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 7727adff..63fcf4c5 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -9,7 +9,7 @@ from .jit import set_jit_legacy, set_jit_fuser from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy from .misc import natural_key, add_bool_arg, ParseKwargs -from .model import unwrap_model, get_state_dict, freeze, unfreeze +from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model from .model_ema import ModelEma, ModelEmaV2 from .random import random_seed from .summary import update_summary, get_outdir diff --git a/timm/utils/model.py b/timm/utils/model.py index d74ee5b7..894453a8 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -3,6 +3,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ import fnmatch +from copy import deepcopy import torch from torchvision.ops.misc import FrozenBatchNorm2d @@ -219,3 +220,21 @@ def unfreeze(root_module, submodules=[], include_bn_running_stats=True): See example in docstring for `freeze`. 
""" _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze") + + +def reparameterize_model(model: torch.nn.Module, inplace=False) -> torch.nn.Module: + if not inplace: + model = deepcopy(model) + + def _fuse(m): + for child_name, child in m.named_children(): + if hasattr(child, 'fuse'): + setattr(m, child_name, child.fuse()) + elif hasattr(child, "reparameterize"): + child.reparameterize() + elif hasattr(child, "switch_to_deploy"): + child.switch_to_deploy() + _fuse(child) + + _fuse(model) + return model diff --git a/validate.py b/validate.py index 794d1ae8..8798f80e 100755 --- a/validate.py +++ b/validate.py @@ -26,7 +26,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLa from timm.layers import apply_test_time_pool, set_fast_norm from timm.models import create_model, load_checkpoint, is_model, list_models from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \ - decay_batch_step, check_batch_size_retry, ParseKwargs + decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model try: from apex import amp @@ -125,6 +125,8 @@ parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") parser.add_argument('--fast-norm', default=False, action='store_true', help='enable experimental fast-norm') +parser.add_argument('--reparam', default=False, action='store_true', + help='Reparameterize model') parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs) @@ -207,6 +209,9 @@ def validate(args): if args.checkpoint: load_checkpoint(model, args.checkpoint, args.use_ema) + if args.reparam: + model = reparameterize_model(model) + param_count = sum([m.numel() for m in model.parameters()]) _logger.info('Model %s created, param count: %d' % (args.model, param_count))