Add MobileCLIP FastViT model defs, support extra SE blocks. Add forward_intermediates API to FastViT

Ross Wightman 2024-05-30 10:17:09 -07:00
parent 5dce710101
commit a503639bcc
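A minimal usage sketch of the new forward_intermediates API (illustrative only; assumes a timm build that includes this commit, using fastvit_t8 at its default 256x256 input):

import torch
import timm

model = timm.create_model('fastvit_t8', pretrained=False).eval()
x = torch.randn(1, 3, 256, 256)

# only the last two stage outputs, as NCHW feature maps
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)

# or final features (after final_conv) plus all stage intermediates
final, intermediates = model.forward_intermediates(x)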


@@ -7,7 +7,7 @@
#
import os
from functools import partial
from typing import Tuple, Optional, Union
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -40,19 +41,19 @@ class MobileOneBlock(nn.Module):
"""
def __init__(
self,
in_chs: int,
out_chs: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
group_size: int = 0,
inference_mode: bool = False,
use_se: bool = False,
use_act: bool = True,
use_scale_branch: bool = True,
num_conv_branches: int = 1,
act_layer: nn.Module = nn.GELU,
) -> None:
"""Construct a MobileOneBlock module.
@@ -280,15 +281,16 @@ class ReparamLargeKernelConv(nn.Module):
"""
def __init__(
self,
in_chs: int,
out_chs: int,
kernel_size: int,
stride: int,
group_size: int,
small_kernel: Optional[int] = None,
use_se: bool = False,
act_layer: Optional[nn.Module] = None,
inference_mode: bool = False,
) -> None:
"""Construct a ReparamLargeKernelConv module.
@@ -299,8 +301,8 @@ class ReparamLargeKernelConv(nn.Module):
stride: Stride size. Default: 1
group_size: Group size. Default: 1
small_kernel: Kernel size of small kernel conv branch.
use_se: If True, apply a SqueezeExcite module after the conv branches. Default: ``False``
act_layer: Activation module. Default: ``nn.GELU``
inference_mode: If True, instantiates model in inference mode. Default: ``False``
"""
super(ReparamLargeKernelConv, self).__init__()
self.stride = stride
@@ -342,6 +344,7 @@
groups=self.groups,
apply_act=False,
)
self.se = SqueezeExcite(out_chs, rd_ratio=0.25) if use_se else nn.Identity()
# FIXME output of this act was not used in original impl, likely due to bug
self.act = act_layer() if act_layer is not None else nn.Identity()
@@ -352,6 +355,7 @@
out = self.large_conv(x)
if self.small_conv is not None:
out = out + self.small_conv(x)
out = self.se(out)
out = self.act(out)
return out
@@ -472,12 +476,12 @@ class Attention(nn.Module):
fused_attn: torch.jit.Final[bool]
def __init__(
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
"""Build MHSA module that can handle 3D or 4D input tensors.
@@ -535,14 +539,15 @@ class PatchEmbed(nn.Module):
"""Convolutional patch embedding layer."""
def __init__(
self,
patch_size: int,
stride: int,
in_chs: int,
embed_dim: int,
act_layer: nn.Module = nn.GELU,
lkc_use_act: bool = False,
use_se: bool = False,
inference_mode: bool = False,
) -> None:
"""Build patch embedding layer.
@@ -562,14 +567,16 @@
stride=stride,
group_size=1,
small_kernel=3,
use_se=use_se,
act_layer=act_layer if lkc_use_act else None, # NOTE original weights didn't use this act
inference_mode=inference_mode,
),
MobileOneBlock(
in_chs=embed_dim,
out_chs=embed_dim,
kernel_size=1,
stride=1,
use_se=False,
act_layer=act_layer,
inference_mode=inference_mode,
)
@@ -598,11 +605,11 @@ class RepMixer(nn.Module):
"""
def __init__(
self,
dim,
kernel_size=3,
layer_scale_init_value=1e-5,
inference_mode: bool = False,
):
"""Build RepMixer Module.
@@ -648,7 +655,7 @@
if layer_scale_init_value is not None:
self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
else:
self.layer_scale = nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.reparam_conv is not None:
@@ -706,12 +713,12 @@ class ConvMlp(nn.Module):
"""Convolutional FFN Module."""
def __init__(
self,
in_chs: int,
hidden_channels: Optional[int] = None,
out_chs: Optional[int] = None,
act_layer: nn.Module = nn.GELU,
drop: float = 0.0,
) -> None:
"""Build convolutional FFN module.
@@ -764,11 +771,11 @@ class RepConditionalPosEnc(nn.Module):
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
inference_mode=False,
) -> None:
"""Build reparameterizable conditional positional encoding
@@ -878,15 +885,15 @@ class RepMixerBlock(nn.Module):
"""
def __init__(
self,
dim: int,
kernel_size: int = 3,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
proj_drop: float = 0.0,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-5,
inference_mode: bool = False,
):
"""Build RepMixer Block.
@@ -936,14 +943,14 @@ class AttentionBlock(nn.Module):
"""
def __init__(
self,
dim: int,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.BatchNorm2d,
proj_drop: float = 0.0,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-5,
):
"""Build Attention Block.
@@ -993,6 +1000,7 @@ class FastVitStage(nn.Module):
depth: int,
token_mixer_type: str,
downsample: bool = True,
se_downsample: bool = False,
down_patch_size: int = 7,
down_stride: int = 2,
pos_emb_layer: Optional[nn.Module] = None,
@@ -1030,6 +1038,7 @@
stride=down_stride,
in_chs=dim,
embed_dim=dim_out,
use_se=se_downsample,
act_layer=act_layer,
lkc_use_act=lkc_use_act,
inference_mode=inference_mode,
@@ -1090,29 +1099,30 @@ class FastVit(nn.Module):
"""
def __init__(
self,
in_chans: int = 3,
layers: Tuple[int, ...] = (2, 2, 6, 2),
token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
mlp_ratios: Tuple[float, ...] = (4,) * 4,
downsamples: Tuple[bool, ...] = (False, True, True, True),
se_downsamples: Tuple[bool, ...] = (False, False, False, False),
repmixer_kernel_size: int = 3,
num_classes: int = 1000,
pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
down_patch_size: int = 7,
down_stride: int = 2,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
layer_scale_init_value: float = 1e-5,
lkc_use_act: bool = False,
fork_feat: bool = False,
cls_ratio: float = 2.0,
global_pool: str = 'avg',
norm_layer: nn.Module = nn.BatchNorm2d,
act_layer: nn.Module = nn.GELU,
inference_mode: bool = False,
) -> None:
super().__init__()
self.num_classes = 0 if fork_feat else num_classes
@@ -1140,6 +1150,7 @@
dim_out=embed_dims[i],
depth=layers[i],
downsample=downsample,
se_downsample=se_downsamples[i],
down_patch_size=down_patch_size,
down_stride=down_stride,
pos_emb_layer=pos_embs[i],
@@ -1160,6 +1171,7 @@
scale *= 2
self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.num_stages = len(self.stages)
self.num_features = prev_dim
# For segmentation and detection, extract intermediate output
@@ -1236,6 +1248,66 @@
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
    List of intermediate feature tensors when intermediates_only is True,
    otherwise a tuple of (final features, list of intermediate features).
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
last_idx = self.num_stages - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
feat_idx = 0
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.final_conv(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
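# Hedged usage sketch (not part of this file): pruning pairs with
# forward_intermediates when only early features are needed, e.g.
#   model.prune_intermediate_layers(indices=(0, 1), prune_head=True)  # keep stages 0-1, drop head
#   feats = model.forward_intermediates(x, indices=(0, 1), intermediates_only=True)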
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
# input embedding
x = self.stem(x)
@@ -1297,8 +1369,7 @@ default_cfgs = generate_default_cfgs({
"fastvit_ma36.apple_in1k": _cfg(
hf_hub_id='timm/',
crop_pct=0.95),
"fastvit_t8.apple_dist_in1k": _cfg(
hf_hub_id='timm/'),
@@ -1318,15 +1389,111 @@
hf_hub_id='timm/',
crop_pct=0.95
),
"fastvit_mci0.apple_mclip": _cfg(
# hf_hub_id='timm/',
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt',
crop_pct=0.95,
num_classes=512, # CLIP proj dim
mean=(0., 0., 0.), std=(1., 1., 1.)
),
"fastvit_mci1.apple_mclip": _cfg(
# hf_hub_id='timm/',
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt',
crop_pct=0.95,
num_classes=512, # CLIP proj dim
mean=(0., 0., 0.), std=(1., 1., 1.)
),
"fastvit_mci2.apple_mclip": _cfg(
# hf_hub_id='timm/',
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt',
crop_pct=0.95,
num_classes=512, # CLIP proj dim
mean=(0., 0., 0.), std=(1., 1., 1.)
),
})
def _checkpoint_filter_fn(state_dict, model):
""" Remap original checkpoints -> timm """
if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
return state_dict # non-original checkpoint, no remapping needed
state_dict = state_dict.get('state_dict', state_dict)
if 'image_encoder.model.head.proj' in state_dict:
# remap MobileCLIP checkpoints
prefix = 'image_encoder.model.'
else:
prefix = ''
import re
import bisect
# find stage ends by locating downsample layers
stage_ends = []
for k, v in state_dict.items():
match = re.match(r'^(.*?)network\.(\d+)\.proj.*', k)
if match:
stage_ends.append(int(match.group(2)))
stage_ends = list(sorted(set(stage_ends)))
out_dict = {}
for k, v in state_dict.items():
if prefix:
if prefix not in k:
continue
k = k.replace(prefix, '')
# remap renamed layers
k = k.replace('patch_embed', 'stem')
k = k.replace('rbr_conv', 'conv_kxk')
k = k.replace('rbr_scale', 'conv_scale')
k = k.replace('rbr_skip', 'identity')
k = k.replace('conv_exp', 'final_conv') # to match byobnet, regnet, nfnet
k = k.replace('lkb_origin', 'large_conv')
k = k.replace('convffn', 'mlp')
k = k.replace('se.reduce', 'se.fc1')
k = k.replace('se.expand', 'se.fc2')
k = re.sub(r'layer_scale_([0-9])', r'layer_scale_\1.gamma', k)
if k.endswith('layer_scale'):
k = k.replace('layer_scale', 'layer_scale.gamma')
k = k.replace('dist_head', 'head_dist')
if k.startswith('head.'):
if k == 'head.proj' and hasattr(model.head, 'fc') and isinstance(model.head.fc, nn.Linear):
# if CLIP projection, map to head.fc w/ bias = zeros
k = k.replace('head.proj', 'head.fc.weight')
v = v.T
out_dict['head.fc.bias'] = torch.zeros(v.shape[0])
else:
k = k.replace('head.', 'head.fc.')
# remap flat sequential network to stages
match = re.match(r'^network\.(\d+)', k)
stage_idx, net_idx = None, None
if match:
net_idx = int(match.group(1))
stage_idx = bisect.bisect_right(stage_ends, net_idx)
if stage_idx is not None:
net_prefix = f'network.{net_idx}'
stage_prefix = f'stages.{stage_idx}'
if net_prefix + '.proj' in k:
k = k.replace(net_prefix + '.proj', stage_prefix + '.downsample.proj')
elif net_prefix + '.pe' in k:
k = k.replace(net_prefix + '.pe', stage_prefix + '.pos_emb.pos_enc')
else:
k = k.replace(net_prefix, stage_prefix + '.blocks')
out_dict[k] = v
return out_dict
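# Illustrative remaps under the rules above (hypothetical keys; stage_ends == [2, 4, 6]
# is assumed purely for the example):
#   'image_encoder.model.network.0.0.token_mixer.reparam_conv.weight'
#       -> 'stages.0.blocks.0.token_mixer.reparam_conv.weight'
#   'image_encoder.model.network.2.proj.0.lkb_origin.conv.weight'
#       -> 'stages.1.downsample.proj.0.large_conv.conv.weight'
#   'image_encoder.model.head.proj'
#       -> 'head.fc.weight' (transposed, with a zero 'head.fc.bias' synthesized)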
def _create_fastvit(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
model = build_model_with_cfg(
FastVit,
variant,
pretrained,
pretrained_filter_fn=_checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs
)
@@ -1419,3 +1586,45 @@ def fastvit_ma36(pretrained=False, **kwargs):
token_mixers=("repmixer", "repmixer", "repmixer", "attention")
)
return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def fastvit_mci0(pretrained=False, **kwargs):
"""Instantiate MCi0 model variant."""
model_args = dict(
layers=(2, 6, 10, 2),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(3, 3, 3, 3),
se_downsamples=(False, False, True, True),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
)
return _create_fastvit('fastvit_mci0', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def fastvit_mci1(pretrained=False, **kwargs):
"""Instantiate MCi1 model variant."""
model_args = dict(
layers=(4, 12, 20, 4),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(3, 3, 3, 3),
se_downsamples=(False, False, True, True),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
)
return _create_fastvit('fastvit_mci1', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def fastvit_mci2(pretrained=False, **kwargs):
"""Instantiate MCi2 model variant."""
model_args = dict(
layers=(4, 12, 24, 4),
embed_dims=(80, 160, 320, 640),
mlp_ratios=(3, 3, 3, 3),
se_downsamples=(False, False, True, True),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
)
return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs))
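And a hedged sketch of using one of the new MobileCLIP towers as a CLIP image encoder (a sketch, assuming the default 256x256 input; the mean=(0,0,0)/std=(1,1,1) cfg above means inputs are only scaled to [0, 1], and weight download from the Apple URL is assumed to resolve via the remap filter):

import torch
import timm

model = timm.create_model('fastvit_mci0.apple_mclip', pretrained=True).eval()
with torch.no_grad():
    emb = model(torch.randn(1, 3, 256, 256))
print(emb.shape)  # torch.Size([1, 512]) -- CLIP projection dim, per num_classes=512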