From a503639bcce2f0d3e8c6a7459cabb7dd6aafa4c2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 30 May 2024 10:17:09 -0700 Subject: [PATCH] Add mobileclip fastvit model defs, support extra SE. Add forward_intermediates API to fastvit --- timm/models/fastvit.py | 405 +++++++++++++++++++++++++++++++---------- 1 file changed, 307 insertions(+), 98 deletions(-) diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py index 74b6cc28..7c918887 100644 --- a/timm/models/fastvit.py +++ b/timm/models/fastvit.py @@ -7,7 +7,7 @@ # import os from functools import partial -from typing import Tuple, Optional, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \ ClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -40,19 +41,19 @@ class MobileOneBlock(nn.Module): """ def __init__( - self, - in_chs: int, - out_chs: int, - kernel_size: int, - stride: int = 1, - dilation: int = 1, - group_size: int = 0, - inference_mode: bool = False, - use_se: bool = False, - use_act: bool = True, - use_scale_branch: bool = True, - num_conv_branches: int = 1, - act_layer: nn.Module = nn.GELU, + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + group_size: int = 0, + inference_mode: bool = False, + use_se: bool = False, + use_act: bool = True, + use_scale_branch: bool = True, + num_conv_branches: int = 1, + act_layer: nn.Module = nn.GELU, ) -> None: """Construct a MobileOneBlock module. @@ -280,15 +281,16 @@ class ReparamLargeKernelConv(nn.Module): """ def __init__( - self, - in_chs: int, - out_chs: int, - kernel_size: int, - stride: int, - group_size: int, - small_kernel: Optional[int] = None, - inference_mode: bool = False, - act_layer: Optional[nn.Module] = None, + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int, + group_size: int, + small_kernel: Optional[int] = None, + use_se: bool = False, + act_layer: Optional[nn.Module] = None, + inference_mode: bool = False, ) -> None: """Construct a ReparamLargeKernelConv module. @@ -299,8 +301,8 @@ class ReparamLargeKernelConv(nn.Module): stride: Stride size. Default: 1 group_size: Group size. Default: 1 small_kernel: Kernel size of small kernel conv branch. - inference_mode: If True, instantiates model in inference mode. Default: ``False`` act_layer: Activation module. Default: ``nn.GELU`` + inference_mode: If True, instantiates model in inference mode. 
Default: ``False`` """ super(ReparamLargeKernelConv, self).__init__() self.stride = stride @@ -342,6 +344,7 @@ class ReparamLargeKernelConv(nn.Module): groups=self.groups, apply_act=False, ) + self.se = SqueezeExcite(out_chs, rd_ratio=0.25) if use_se else nn.Identity() # FIXME output of this act was not used in original impl, likely due to bug self.act = act_layer() if act_layer is not None else nn.Identity() @@ -352,6 +355,7 @@ class ReparamLargeKernelConv(nn.Module): out = self.large_conv(x) if self.small_conv is not None: out = out + self.small_conv(x) + out = self.se(out) out = self.act(out) return out @@ -472,12 +476,12 @@ class Attention(nn.Module): fused_attn: torch.jit.Final[bool] def __init__( - self, - dim: int, - head_dim: int = 32, - qkv_bias: bool = False, - attn_drop: float = 0.0, - proj_drop: float = 0.0, + self, + dim: int, + head_dim: int = 32, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, ) -> None: """Build MHSA module that can handle 3D or 4D input tensors. @@ -535,14 +539,15 @@ class PatchEmbed(nn.Module): """Convolutional patch embedding layer.""" def __init__( - self, - patch_size: int, - stride: int, - in_chs: int, - embed_dim: int, - act_layer: nn.Module = nn.GELU, - lkc_use_act: bool = False, - inference_mode: bool = False, + self, + patch_size: int, + stride: int, + in_chs: int, + embed_dim: int, + act_layer: nn.Module = nn.GELU, + lkc_use_act: bool = False, + use_se: bool = False, + inference_mode: bool = False, ) -> None: """Build patch embedding layer. @@ -562,14 +567,16 @@ class PatchEmbed(nn.Module): stride=stride, group_size=1, small_kernel=3, - inference_mode=inference_mode, + use_se=use_se, act_layer=act_layer if lkc_use_act else None, # NOTE original weights didn't use this act + inference_mode=inference_mode, ), MobileOneBlock( in_chs=embed_dim, out_chs=embed_dim, kernel_size=1, stride=1, + use_se=False, act_layer=act_layer, inference_mode=inference_mode, ) @@ -598,11 +605,11 @@ class RepMixer(nn.Module): """ def __init__( - self, - dim, - kernel_size=3, - layer_scale_init_value=1e-5, - inference_mode: bool = False, + self, + dim, + kernel_size=3, + layer_scale_init_value=1e-5, + inference_mode: bool = False, ): """Build RepMixer Module. @@ -648,7 +655,7 @@ class RepMixer(nn.Module): if layer_scale_init_value is not None: self.layer_scale = LayerScale2d(dim, layer_scale_init_value) else: - self.layer_scale = nn.Identity + self.layer_scale = nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: if self.reparam_conv is not None: @@ -706,12 +713,12 @@ class ConvMlp(nn.Module): """Convolutional FFN Module.""" def __init__( - self, - in_chs: int, - hidden_channels: Optional[int] = None, - out_chs: Optional[int] = None, - act_layer: nn.Module = nn.GELU, - drop: float = 0.0, + self, + in_chs: int, + hidden_channels: Optional[int] = None, + out_chs: Optional[int] = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, ) -> None: """Build convolutional FFN module. 
@@ -764,11 +771,11 @@ class RepConditionalPosEnc(nn.Module): """ def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - spatial_shape: Union[int, Tuple[int, int]] = (7, 7), - inference_mode=False, + self, + dim: int, + dim_out: Optional[int] = None, + spatial_shape: Union[int, Tuple[int, int]] = (7, 7), + inference_mode=False, ) -> None: """Build reparameterizable conditional positional encoding @@ -878,15 +885,15 @@ class RepMixerBlock(nn.Module): """ def __init__( - self, - dim: int, - kernel_size: int = 3, - mlp_ratio: float = 4.0, - act_layer: nn.Module = nn.GELU, - proj_drop: float = 0.0, - drop_path: float = 0.0, - layer_scale_init_value: float = 1e-5, - inference_mode: bool = False, + self, + dim: int, + kernel_size: int = 3, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + proj_drop: float = 0.0, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-5, + inference_mode: bool = False, ): """Build RepMixer Block. @@ -936,14 +943,14 @@ class AttentionBlock(nn.Module): """ def __init__( - self, - dim: int, - mlp_ratio: float = 4.0, - act_layer: nn.Module = nn.GELU, - norm_layer: nn.Module = nn.BatchNorm2d, - proj_drop: float = 0.0, - drop_path: float = 0.0, - layer_scale_init_value: float = 1e-5, + self, + dim: int, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.BatchNorm2d, + proj_drop: float = 0.0, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-5, ): """Build Attention Block. @@ -993,6 +1000,7 @@ class FastVitStage(nn.Module): depth: int, token_mixer_type: str, downsample: bool = True, + se_downsample: bool = False, down_patch_size: int = 7, down_stride: int = 2, pos_emb_layer: Optional[nn.Module] = None, @@ -1030,6 +1038,7 @@ class FastVitStage(nn.Module): stride=down_stride, in_chs=dim, embed_dim=dim_out, + use_se=se_downsample, act_layer=act_layer, lkc_use_act=lkc_use_act, inference_mode=inference_mode, @@ -1090,29 +1099,30 @@ class FastVit(nn.Module): """ def __init__( - self, - in_chans: int = 3, - layers: Tuple[int, ...] = (2, 2, 6, 2), - token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"), - embed_dims: Tuple[int, ...] = (64, 128, 256, 512), - mlp_ratios: Tuple[float, ...] = (4,) * 4, - downsamples: Tuple[bool, ...] = (False, True, True, True), - repmixer_kernel_size: int = 3, - num_classes: int = 1000, - pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4, - down_patch_size: int = 7, - down_stride: int = 2, - drop_rate: float = 0.0, - proj_drop_rate: float = 0.0, - drop_path_rate: float = 0.0, - layer_scale_init_value: float = 1e-5, - fork_feat: bool = False, - cls_ratio: float = 2.0, - global_pool: str = 'avg', - norm_layer: nn.Module = nn.BatchNorm2d, - act_layer: nn.Module = nn.GELU, - lkc_use_act: bool = False, - inference_mode: bool = False, + self, + in_chans: int = 3, + layers: Tuple[int, ...] = (2, 2, 6, 2), + token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"), + embed_dims: Tuple[int, ...] = (64, 128, 256, 512), + mlp_ratios: Tuple[float, ...] = (4,) * 4, + downsamples: Tuple[bool, ...] = (False, True, True, True), + se_downsamples: Tuple[bool, ...] = (False, False, False, False), + repmixer_kernel_size: int = 3, + num_classes: int = 1000, + pos_embs: Tuple[Optional[nn.Module], ...] 
= (None,) * 4, + down_patch_size: int = 7, + down_stride: int = 2, + drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + layer_scale_init_value: float = 1e-5, + lkc_use_act: bool = False, + fork_feat: bool = False, + cls_ratio: float = 2.0, + global_pool: str = 'avg', + norm_layer: nn.Module = nn.BatchNorm2d, + act_layer: nn.Module = nn.GELU, + inference_mode: bool = False, ) -> None: super().__init__() self.num_classes = 0 if fork_feat else num_classes @@ -1140,6 +1150,7 @@ class FastVit(nn.Module): dim_out=embed_dims[i], depth=layers[i], downsample=downsample, + se_downsample=se_downsamples[i], down_patch_size=down_patch_size, down_stride=down_stride, pos_emb_layer=pos_embs[i], @@ -1160,6 +1171,7 @@ class FastVit(nn.Module): scale *= 2 self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) + self.num_stages = len(self.stages) self.num_features = prev_dim # For segmentation and detection, extract intermediate output @@ -1236,6 +1248,66 @@ class FastVit(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n stages if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over stages when the last desired intermediate is hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + List of intermediate features if ``intermediates_only`` is True, otherwise a tuple of the final features and the list of intermediates. + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + last_idx = self.num_stages - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + feat_idx = 0 + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.final_conv(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x: torch.Tensor) -> torch.Tensor: # input embedding x = self.stem(x) @@ -1297,8 +1369,7 @@ default_cfgs = generate_default_cfgs({ "fastvit_ma36.apple_in1k": _cfg( hf_hub_id='timm/', - crop_pct=0.95 - ), + crop_pct=0.95), "fastvit_t8.apple_dist_in1k": _cfg( hf_hub_id='timm/'), @@ -1318,15 +1389,111 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='timm/', crop_pct=0.95 ), + + "fastvit_mci0.apple_mclip": _cfg( + #hf_hub_id='timm/', + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt', + crop_pct=0.95, + num_classes=512, # CLIP proj dim + mean=(0., 0., 0.), std=(1., 1., 1.) + ), + "fastvit_mci1.apple_mclip": _cfg( + # hf_hub_id='timm/', + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt', + crop_pct=0.95, + num_classes=512, # CLIP proj dim + mean=(0., 0., 0.), std=(1., 1., 1.) + ), + "fastvit_mci2.apple_mclip": _cfg( + # hf_hub_id='timm/', + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt', + crop_pct=0.95, + num_classes=512, # CLIP proj dim + mean=(0., 0., 0.), std=(1., 1., 1.) + ), }) +def _checkpoint_filter_fn(state_dict, model): + """ Remap original checkpoints -> timm """ + if 'stem.0.conv_kxk.0.conv.weight' in state_dict: + return state_dict # non-original checkpoint, no remapping needed + + state_dict = state_dict.get('state_dict', state_dict) + if 'image_encoder.model.head.proj' in state_dict: + # remap MobileCLIP checkpoints + prefix = 'image_encoder.model.' 
+ else: + prefix = '' + + import re + import bisect + + # find stage ends by locating downsample layers + stage_ends = [] + for k, v in state_dict.items(): + match = re.match(r'^(.*?)network\.(\d+)\.proj.*', k) + if match: + stage_ends.append(int(match.group(2))) + stage_ends = list(sorted(set(stage_ends))) + + out_dict = {} + for k, v in state_dict.items(): + if prefix: + if prefix not in k: + continue + k = k.replace(prefix, '') + + # remap renamed layers + k = k.replace('patch_embed', 'stem') + k = k.replace('rbr_conv', 'conv_kxk') + k = k.replace('rbr_scale', 'conv_scale') + k = k.replace('rbr_skip', 'identity') + k = k.replace('conv_exp', 'final_conv') # to match byobnet, regnet, nfnet + k = k.replace('lkb_origin', 'large_conv') + k = k.replace('convffn', 'mlp') + k = k.replace('se.reduce', 'se.fc1') + k = k.replace('se.expand', 'se.fc2') + k = re.sub(r'layer_scale_([0-9])', r'layer_scale_\1.gamma', k) + if k.endswith('layer_scale'): + k = k.replace('layer_scale', 'layer_scale.gamma') + k = k.replace('dist_head', 'head_dist') + if k.startswith('head.'): + if k == 'head.proj' and hasattr(model.head, 'fc') and isinstance(model.head.fc, nn.Linear): + # if CLIP projection, map to head.fc w/ bias = zeros + k = k.replace('head.proj', 'head.fc.weight') + v = v.T + out_dict['head.fc.bias'] = torch.zeros(v.shape[0]) + else: + k = k.replace('head.', 'head.fc.') + + # remap flat sequential network to stages + match = re.match(r'^network\.(\d+)', k) + stage_idx, net_idx = None, None + if match: + net_idx = int(match.group(1)) + stage_idx = bisect.bisect_right(stage_ends, net_idx) + if stage_idx is not None: + net_prefix = f'network.{net_idx}' + stage_prefix = f'stages.{stage_idx}' + if net_prefix + '.proj' in k: + k = k.replace(net_prefix + '.proj', stage_prefix + '.downsample.proj') + elif net_prefix + '.pe' in k: + k = k.replace(net_prefix + '.pe', stage_prefix + '.pos_emb.pos_enc') + else: + k = k.replace(net_prefix, stage_prefix + '.blocks') + + out_dict[k] = v + return out_dict + + def _create_fastvit(variant, pretrained=False, **kwargs): out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) model = build_model_with_cfg( FastVit, variant, pretrained, + pretrained_filter_fn=_checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs ) @@ -1419,3 +1586,45 @@ def fastvit_ma36(pretrained=False, **kwargs): token_mixers=("repmixer", "repmixer", "repmixer", "attention") ) return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_mci0(pretrained=False, **kwargs): + """Instantiate MCi0 model variant.""" + model_args = dict( + layers=(2, 6, 10, 2), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(3, 3, 3, 3), + se_downsamples=(False, False, True, True), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + ) + return _create_fastvit('fastvit_mci0', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_mci1(pretrained=False, **kwargs): + """Instantiate MCi1 model variant.""" + model_args = dict( + layers=(4, 12, 20, 4), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(3, 3, 3, 3), + se_downsamples=(False, False, True, True), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + ) + return _create_fastvit('fastvit_mci1', pretrained=pretrained, **dict(model_args, **kwargs)) + 
+ +@register_model +def fastvit_mci2(pretrained=False, **kwargs): + """Instantiate MCi2 model variant.""" + model_args = dict( + layers=(4, 12, 24, 4), + embed_dims=(80, 160, 320, 640), + mlp_ratios=(3, 3, 3, 3), + se_downsamples=(False, False, True, True), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + ) + return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs)) \ No newline at end of file
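
Usage sketch (editor's illustration, not part of the patch): the snippet below exercises the new forward_intermediates / prune_intermediate_layers API on one of the MobileCLIP variants added above. The variant name and the 256x256 input size are assumptions taken from the defs in this diff, and pretrained=False avoids depending on the Apple weight URLs.

    import torch
    import timm

    # 'fastvit_mci0' is one of the MobileCLIP image-tower defs added above
    model = timm.create_model('fastvit_mci0', pretrained=False, num_classes=0).eval()

    x = torch.randn(1, 3, 256, 256)  # assumed input size, per the fastvit default cfgs

    with torch.no_grad():
        # final features (after final_conv) plus all four stage outputs, NCHW
        final, intermediates = model.forward_intermediates(x)
        for i, feat in enumerate(intermediates):
            print(i, feat.shape)  # strides 4/8/16/32, channels follow embed_dims

        # only the first two stage outputs; stop_early skips stages 2 and 3,
        # intermediates_only also skips final_conv
        feats = model.forward_intermediates(
            x, indices=[0, 1], stop_early=True, intermediates_only=True)

    # drop the classifier head and stages 2/3 for good; after pruning, call
    # forward_intermediates(..., intermediates_only=True) rather than model(x),
    # since the retained final_conv still expects the last stage's channel count
    model.prune_intermediate_layers(indices=[0, 1], prune_head=True)

Since se_downsamples is a constructor argument, the extra SE placement can also be toggled on other variants at creation time, e.g. timm.create_model('fastvit_sa12', se_downsamples=(False, False, True, True)); model kwargs flow through _create_fastvit's **kwargs into the FastVit constructor.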