Add mobileclip fastvit model defs, support extra SE. Add forward_intermediates API to fastvit
parent 5dce710101
commit a503639bcc
@@ -7,7 +7,7 @@
 #
 import os
 from functools import partial
-from typing import Tuple, Optional, Union
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
     ClassifierHead
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs

@@ -40,19 +41,19 @@ class MobileOneBlock(nn.Module):
     """

     def __init__(
             self,
             in_chs: int,
             out_chs: int,
             kernel_size: int,
             stride: int = 1,
             dilation: int = 1,
             group_size: int = 0,
             inference_mode: bool = False,
             use_se: bool = False,
             use_act: bool = True,
             use_scale_branch: bool = True,
             num_conv_branches: int = 1,
             act_layer: nn.Module = nn.GELU,
     ) -> None:
         """Construct a MobileOneBlock module.

@@ -280,15 +281,16 @@ class ReparamLargeKernelConv(nn.Module):
     """

     def __init__(
             self,
             in_chs: int,
             out_chs: int,
             kernel_size: int,
             stride: int,
             group_size: int,
             small_kernel: Optional[int] = None,
-            inference_mode: bool = False,
-            act_layer: Optional[nn.Module] = None,
+            use_se: bool = False,
+            act_layer: Optional[nn.Module] = None,
+            inference_mode: bool = False,
     ) -> None:
         """Construct a ReparamLargeKernelConv module.

@@ -299,8 +301,8 @@ class ReparamLargeKernelConv(nn.Module):
             stride: Stride size. Default: 1
             group_size: Group size. Default: 1
             small_kernel: Kernel size of small kernel conv branch.
-            inference_mode: If True, instantiates model in inference mode. Default: ``False``
             act_layer: Activation module. Default: ``nn.GELU``
+            inference_mode: If True, instantiates model in inference mode. Default: ``False``
         """
         super(ReparamLargeKernelConv, self).__init__()
         self.stride = stride
@@ -342,6 +344,7 @@ class ReparamLargeKernelConv(nn.Module):
                 groups=self.groups,
                 apply_act=False,
             )
+        self.se = SqueezeExcite(out_chs, rd_ratio=0.25) if use_se else nn.Identity()
         # FIXME output of this act was not used in original impl, likely due to bug
         self.act = act_layer() if act_layer is not None else nn.Identity()

@@ -352,6 +355,7 @@ class ReparamLargeKernelConv(nn.Module):
         out = self.large_conv(x)
         if self.small_conv is not None:
             out = out + self.small_conv(x)
+        out = self.se(out)
         out = self.act(out)
         return out

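For reference, the `self.se` added above is timm's stock SqueezeExcite layer; the checkpoint remapping later in this commit (se.reduce -> se.fc1, se.expand -> se.fc2) maps onto its fc1/fc2 naming. A minimal functional sketch of what SqueezeExcite(out_chs, rd_ratio=0.25) computes (class and variable names here are illustrative, not timm internals):

import torch
import torch.nn as nn

class MinimalSE(nn.Module):
    # Rough equivalent of timm's SqueezeExcite(chs, rd_ratio=0.25):
    # global average pool -> 1x1 reduce -> act -> 1x1 expand -> sigmoid gate.
    def __init__(self, chs: int, rd_ratio: float = 0.25):
        super().__init__()
        rd_chs = int(chs * rd_ratio)
        self.fc1 = nn.Conv2d(chs, rd_chs, kernel_size=1)
        self.act = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(rd_chs, chs, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s = x.mean((2, 3), keepdim=True)     # squeeze: (B, C, 1, 1)
        s = self.fc2(self.act(self.fc1(s)))  # excite
        return x * s.sigmoid()               # channel-wise gating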
@@ -472,12 +476,12 @@ class Attention(nn.Module):
     fused_attn: torch.jit.Final[bool]

     def __init__(
             self,
             dim: int,
             head_dim: int = 32,
             qkv_bias: bool = False,
             attn_drop: float = 0.0,
             proj_drop: float = 0.0,
     ) -> None:
         """Build MHSA module that can handle 3D or 4D input tensors.

@@ -535,14 +539,15 @@ class PatchEmbed(nn.Module):
     """Convolutional patch embedding layer."""

     def __init__(
             self,
             patch_size: int,
             stride: int,
             in_chs: int,
             embed_dim: int,
             act_layer: nn.Module = nn.GELU,
             lkc_use_act: bool = False,
+            use_se: bool = False,
             inference_mode: bool = False,
     ) -> None:
         """Build patch embedding layer.

@@ -562,14 +567,16 @@ class PatchEmbed(nn.Module):
                 stride=stride,
                 group_size=1,
                 small_kernel=3,
-                inference_mode=inference_mode,
+                use_se=use_se,
                 act_layer=act_layer if lkc_use_act else None,  # NOTE original weights didn't use this act
+                inference_mode=inference_mode,
             ),
             MobileOneBlock(
                 in_chs=embed_dim,
                 out_chs=embed_dim,
                 kernel_size=1,
                 stride=1,
+                use_se=False,
                 act_layer=act_layer,
                 inference_mode=inference_mode,
             )
@@ -598,11 +605,11 @@ class RepMixer(nn.Module):
     """

     def __init__(
             self,
             dim,
             kernel_size=3,
             layer_scale_init_value=1e-5,
             inference_mode: bool = False,
     ):
         """Build RepMixer Module.

@@ -648,7 +655,7 @@ class RepMixer(nn.Module):
         if layer_scale_init_value is not None:
             self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
         else:
-            self.layer_scale = nn.Identity
+            self.layer_scale = nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.reparam_conv is not None:
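The one-character fix above is a genuine bug fix: the old code stored the nn.Identity class rather than an instance, so calling self.layer_scale(t) constructed a fresh module (nn.Identity's constructor silently ignores positional args) instead of passing t through. A minimal illustration:

import torch
import torch.nn as nn

t = torch.randn(2, 8)

broken = nn.Identity    # the class object, as before the fix
print(broken(t))        # Identity() -- builds a module, ignores t entirely

fixed = nn.Identity()   # an instance, as after the fix
print(fixed(t).shape)   # torch.Size([2, 8]) -- passes t through unchanged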
@@ -706,12 +713,12 @@ class ConvMlp(nn.Module):
     """Convolutional FFN Module."""

     def __init__(
             self,
             in_chs: int,
             hidden_channels: Optional[int] = None,
             out_chs: Optional[int] = None,
             act_layer: nn.Module = nn.GELU,
             drop: float = 0.0,
     ) -> None:
         """Build convolutional FFN module.

@@ -764,11 +771,11 @@ class RepConditionalPosEnc(nn.Module):
     """

     def __init__(
             self,
             dim: int,
             dim_out: Optional[int] = None,
             spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
             inference_mode=False,
     ) -> None:
         """Build reparameterizable conditional positional encoding

@@ -878,15 +885,15 @@ class RepMixerBlock(nn.Module):
     """

     def __init__(
             self,
             dim: int,
             kernel_size: int = 3,
             mlp_ratio: float = 4.0,
             act_layer: nn.Module = nn.GELU,
             proj_drop: float = 0.0,
             drop_path: float = 0.0,
             layer_scale_init_value: float = 1e-5,
             inference_mode: bool = False,
     ):
         """Build RepMixer Block.

@@ -936,14 +943,14 @@ class AttentionBlock(nn.Module):
     """

     def __init__(
             self,
             dim: int,
             mlp_ratio: float = 4.0,
             act_layer: nn.Module = nn.GELU,
             norm_layer: nn.Module = nn.BatchNorm2d,
             proj_drop: float = 0.0,
             drop_path: float = 0.0,
             layer_scale_init_value: float = 1e-5,
     ):
         """Build Attention Block.

@@ -993,6 +1000,7 @@ class FastVitStage(nn.Module):
             depth: int,
             token_mixer_type: str,
             downsample: bool = True,
+            se_downsample: bool = False,
             down_patch_size: int = 7,
             down_stride: int = 2,
             pos_emb_layer: Optional[nn.Module] = None,
@@ -1030,6 +1038,7 @@ class FastVitStage(nn.Module):
                 stride=down_stride,
                 in_chs=dim,
                 embed_dim=dim_out,
+                use_se=se_downsample,
                 act_layer=act_layer,
                 lkc_use_act=lkc_use_act,
                 inference_mode=inference_mode,
@@ -1090,29 +1099,30 @@ class FastVit(nn.Module):
     """

     def __init__(
             self,
             in_chans: int = 3,
             layers: Tuple[int, ...] = (2, 2, 6, 2),
             token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
             embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
             mlp_ratios: Tuple[float, ...] = (4,) * 4,
             downsamples: Tuple[bool, ...] = (False, True, True, True),
+            se_downsamples: Tuple[bool, ...] = (False, False, False, False),
             repmixer_kernel_size: int = 3,
             num_classes: int = 1000,
             pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
             down_patch_size: int = 7,
             down_stride: int = 2,
             drop_rate: float = 0.0,
             proj_drop_rate: float = 0.0,
             drop_path_rate: float = 0.0,
             layer_scale_init_value: float = 1e-5,
+            lkc_use_act: bool = False,
             fork_feat: bool = False,
             cls_ratio: float = 2.0,
             global_pool: str = 'avg',
             norm_layer: nn.Module = nn.BatchNorm2d,
             act_layer: nn.Module = nn.GELU,
-            lkc_use_act: bool = False,
             inference_mode: bool = False,
     ) -> None:
         super().__init__()
         self.num_classes = 0 if fork_feat else num_classes
@@ -1140,6 +1150,7 @@ class FastVit(nn.Module):
                 dim_out=embed_dims[i],
                 depth=layers[i],
                 downsample=downsample,
+                se_downsample=se_downsamples[i],
                 down_patch_size=down_patch_size,
                 down_stride=down_stride,
                 pos_emb_layer=pos_embs[i],
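Taken together, the new se_downsamples tuple threads from the FastVit constructor into each stage's downsample PatchEmbed, and from there into ReparamLargeKernelConv via use_se. A hypothetical direct construction using only parameters visible in this diff (normally you would use the registered factory functions added below):

# se_downsamples=(False, False, True, True) places SqueezeExcite inside the
# downsample convs of the last two stages only.
model = FastVit(
    layers=(2, 6, 10, 2),
    embed_dims=(64, 128, 256, 512),
    mlp_ratios=(3, 3, 3, 3),
    se_downsamples=(False, False, True, True),
    token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
)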
@@ -1160,6 +1171,7 @@ class FastVit(nn.Module):
                 scale *= 2
             self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)
+        self.num_stages = len(self.stages)
         self.num_features = prev_dim

         # For segmentation and detection, extract intermediate output
@@ -1236,6 +1248,66 @@ class FastVit(nn.Module):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        last_idx = self.num_stages - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+        feat_idx = 0
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            x = self.final_conv(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         # input embedding
         x = self.stem(x)
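A usage sketch for the new API (assuming a timm build that includes this commit; model name and input size are illustrative):

import timm
import torch

model = timm.create_model('fastvit_t8', pretrained=False)
x = torch.randn(1, 3, 256, 256)

# Final features plus intermediates from the last two stages (int = take last n).
final, intermediates = model.forward_intermediates(x, indices=2)
for t in intermediates:
    print(t.shape)  # NCHW feature maps at increasing stride

# Intermediates only, skipping the head-side final_conv entirely.
feats = model.forward_intermediates(x, indices=[0, 1, 2, 3], intermediates_only=True)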
@@ -1297,8 +1369,7 @@ default_cfgs = generate_default_cfgs({

     "fastvit_ma36.apple_in1k": _cfg(
         hf_hub_id='timm/',
-        crop_pct=0.95
-    ),
+        crop_pct=0.95),

     "fastvit_t8.apple_dist_in1k": _cfg(
         hf_hub_id='timm/'),
@@ -1318,15 +1389,111 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
-        crop_pct=0.95
-    ),
+        crop_pct=0.95),

+    "fastvit_mci0.apple_mclip": _cfg(
+        # hf_hub_id='timm/',
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt',
+        crop_pct=0.95,
+        num_classes=512,  # CLIP proj dim
+        mean=(0., 0., 0.), std=(1., 1., 1.)
+    ),
+    "fastvit_mci1.apple_mclip": _cfg(
+        # hf_hub_id='timm/',
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt',
+        crop_pct=0.95,
+        num_classes=512,  # CLIP proj dim
+        mean=(0., 0., 0.), std=(1., 1., 1.)
+    ),
+    "fastvit_mci2.apple_mclip": _cfg(
+        # hf_hub_id='timm/',
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt',
+        crop_pct=0.95,
+        num_classes=512,  # CLIP proj dim
+        mean=(0., 0., 0.), std=(1., 1., 1.)
+    ),
 })


 def _checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
+    if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
+        return state_dict  # non-original checkpoint, no remapping needed
+
     state_dict = state_dict.get('state_dict', state_dict)
+    if 'image_encoder.model.head.proj' in state_dict:
+        # remap MobileCLIP checkpoints
+        prefix = 'image_encoder.model.'
+    else:
+        prefix = ''

     import re
+    import bisect
+
+    # find stage ends by locating downsample layers
+    stage_ends = []
+    for k, v in state_dict.items():
+        match = re.match(r'^(.*?)network\.(\d+)\.proj.*', k)
+        if match:
+            stage_ends.append(int(match.group(2)))
+    stage_ends = list(sorted(set(stage_ends)))

     out_dict = {}
     for k, v in state_dict.items():
+        if prefix:
+            if prefix not in k:
+                continue
+            k = k.replace(prefix, '')
+
         # remap renamed layers
         k = k.replace('patch_embed', 'stem')
         k = k.replace('rbr_conv', 'conv_kxk')
         k = k.replace('rbr_scale', 'conv_scale')
         k = k.replace('rbr_skip', 'identity')
         k = k.replace('conv_exp', 'final_conv')  # to match byobnet, regnet, nfnet
         k = k.replace('lkb_origin', 'large_conv')
         k = k.replace('convffn', 'mlp')
+        k = k.replace('se.reduce', 'se.fc1')
+        k = k.replace('se.expand', 'se.fc2')
         k = re.sub(r'layer_scale_([0-9])', r'layer_scale_\1.gamma', k)
+        if k.endswith('layer_scale'):
+            k = k.replace('layer_scale', 'layer_scale.gamma')
         k = k.replace('dist_head', 'head_dist')
+        if k.startswith('head.'):
+            if k == 'head.proj' and hasattr(model.head, 'fc') and isinstance(model.head.fc, nn.Linear):
+                # if CLIP projection, map to head.fc w/ bias = zeros
+                k = k.replace('head.proj', 'head.fc.weight')
+                v = v.T
+                out_dict['head.fc.bias'] = torch.zeros(v.shape[0])
+            else:
+                k = k.replace('head.', 'head.fc.')

+        # remap flat sequential network to stages
+        match = re.match(r'^network\.(\d+)', k)
+        stage_idx, net_idx = None, None
+        if match:
+            net_idx = int(match.group(1))
+            stage_idx = bisect.bisect_right(stage_ends, net_idx)
+        if stage_idx is not None:
+            net_prefix = f'network.{net_idx}'
+            stage_prefix = f'stages.{stage_idx}'
+            if net_prefix + '.proj' in k:
+                k = k.replace(net_prefix + '.proj', stage_prefix + '.downsample.proj')
+            elif net_prefix + '.pe' in k:
+                k = k.replace(net_prefix + '.pe', stage_prefix + '.pos_emb.pos_enc')
+            else:
+                k = k.replace(net_prefix, stage_prefix + '.blocks')
+
         out_dict[k] = v
     return out_dict


 def _create_fastvit(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
     model = build_model_with_cfg(
         FastVit,
         variant,
         pretrained,
         pretrained_filter_fn=_checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs
     )
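The flat network.N -> stages.K remap above keys off bisect: the flat indices where '.proj' downsample modules appear mark stage boundaries, so each downsample opens the next stage. A standalone illustration with made-up indices (not from a real checkpoint):

import bisect

stage_ends = [1, 3, 5]  # flat indices where '.proj' downsample modules were found

for net_idx in range(7):
    stage_idx = bisect.bisect_right(stage_ends, net_idx)
    print(f'network.{net_idx} -> stages.{stage_idx}')
# network.0 -> stages.0, network.1 -> stages.1, network.2 -> stages.1,
# network.3 -> stages.2, ... each downsample starts the next stage's numbering.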
@@ -1419,3 +1586,45 @@ def fastvit_ma36(pretrained=False, **kwargs):
         token_mixers=("repmixer", "repmixer", "repmixer", "attention")
     )
     return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def fastvit_mci0(pretrained=False, **kwargs):
+    """Instantiate MCi0 model variant."""
+    model_args = dict(
+        layers=(2, 6, 10, 2),
+        embed_dims=(64, 128, 256, 512),
+        mlp_ratios=(3, 3, 3, 3),
+        se_downsamples=(False, False, True, True),
+        pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
+    )
+    return _create_fastvit('fastvit_mci0', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def fastvit_mci1(pretrained=False, **kwargs):
+    """Instantiate MCi1 model variant."""
+    model_args = dict(
+        layers=(4, 12, 20, 4),
+        embed_dims=(64, 128, 256, 512),
+        mlp_ratios=(3, 3, 3, 3),
+        se_downsamples=(False, False, True, True),
+        pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
+    )
+    return _create_fastvit('fastvit_mci1', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def fastvit_mci2(pretrained=False, **kwargs):
+    """Instantiate MCi2 model variant."""
+    model_args = dict(
+        layers=(4, 12, 24, 4),
+        embed_dims=(80, 160, 320, 640),
+        mlp_ratios=(3, 3, 3, 3),
+        se_downsamples=(False, False, True, True),
+        pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
+    )
+    return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs))
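A quick sketch of instantiating the new variants (the MobileCLIP cfgs above set num_classes=512, the CLIP projection width, so the default head is the image-tower projection rather than an ImageNet classifier):

import timm

# MobileCLIP-S0 image tower; pretrained=True would pull the Apple URL above.
model = timm.create_model('fastvit_mci0', pretrained=False)

# Standard timm feature extraction also works with the new defs.
feat_model = timm.create_model('fastvit_mci2', features_only=True, out_indices=(0, 1, 2, 3))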