More fastvit & mobileone updates, ready for weight upload

Ross Wightman 2023-08-22 14:42:43 -07:00
parent 8474508d07
commit 8470eb1cb5
3 changed files with 180 additions and 167 deletions

timm/models/__init__.py

@@ -20,6 +20,7 @@ from .efficientnet import *
 from .efficientvit_mit import *
 from .efficientvit_msra import *
 from .eva import *
+from .fastvit import *
 from .focalnet import *
 from .gcvit import *
 from .ghostnet import *

timm/models/byobnet.py

@@ -589,7 +589,7 @@ class MobileOneBlock(nn.Module):
         self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
         # Re-parameterizable conv branches
-        convs = list()
+        convs = []
         for _ in range(self.num_conv_branches):
             convs.append(layers.conv_norm_act(
                 in_chs, out_chs, kernel_size=kernel_size,
@@ -624,8 +624,8 @@ class MobileOneBlock(nn.Module):
         # Other branches
         out = scale_out + identity_out
-        for ix in range(self.num_conv_branches):
-            out += self.conv_kxk[ix](x)
+        for ck in self.conv_kxk:
+            out += ck(x)
         return self.act(self.attn(out))
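
Note: the loop now iterates the nn.ModuleList directly instead of indexing by branch count. A minimal sketch of the over-parameterized branch-sum pattern this block implements (illustrative names, not the timm module itself):

import torch
import torch.nn as nn

# Sketch of a MobileOne-style branch sum: parallel convs whose outputs are
# added. Iterating the ModuleList directly (as in the change above) avoids
# index bookkeeping and keeps the module TorchScript-friendly.
class BranchSum(nn.Module):
    def __init__(self, chs: int, num_branches: int = 3):
        super().__init__()
        self.branches = nn.ModuleList(
            [nn.Conv2d(chs, chs, 3, padding=1) for _ in range(num_branches)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(x)
        for branch in self.branches:
            out = out + branch(x)
        return out

print(BranchSum(8)(torch.randn(1, 8, 32, 32)).shape)  # torch.Size([1, 8, 32, 32])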
@@ -1643,11 +1643,15 @@ model_cfgs = dict(
     ),
 )

+# FIXME temporary for mobileone remap
+from .fastvit import checkpoint_filter_fn
+

 def _create_byobnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         ByobNet, variant, pretrained,
         model_cfg=model_cfgs[variant],
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True),
         **kwargs)
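
`pretrained_filter_fn` is the hook `build_model_with_cfg` uses to rewrite a checkpoint's state dict before loading; here the MobileOne weights are routed through the FastViT filter as a stopgap (per the FIXME above). A hedged sketch of the general shape of such a filter; the key rename below is hypothetical, not the actual remap:

# build_model_with_cfg passes the raw checkpoint state_dict (and the model)
# to the filter and loads whatever dict it returns.
def example_filter_fn(state_dict, model):
    state_dict = state_dict.get('state_dict', state_dict)  # unwrap if nested
    return {k.replace('module.', ''): v for k, v in state_dict.items()}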
@@ -1805,6 +1809,27 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),
+
+    'mobileone_s0': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar',
+        crop_pct=0.875,
+    ),
+    'mobileone_s1': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar',
+        crop_pct=0.9,
+    ),
+    'mobileone_s2': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar',
+        crop_pct=0.9,
+    ),
+    'mobileone_s3': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar',
+        crop_pct=0.9,
+    ),
+    'mobileone_s4': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar',
+        crop_pct=0.9,
+    ),
 })
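
With these cfg entries in place, the MobileOne variants load like any other timm model once the weights are published, e.g.:

import timm
import torch

# pretrained=True assumes the cfg URLs above resolve; False gives random init.
model = timm.create_model('mobileone_s0', pretrained=False).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])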

timm/models/fastvit.py

@@ -13,7 +13,8 @@ import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn
-from ._registry import register_model
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs


 def num_groups(group_size, channels):
@@ -58,7 +59,7 @@ class MobileOneBlock(nn.Module):
         kernel_size: Size of the convolution kernel.
         stride: Stride size.
         dilation: Kernel dilation factor.
-        groups: Group number.
+        group_size: Convolution group size.
         inference_mode: If True, instantiates model in inference mode.
         use_se: Whether to use SE-ReLU activations.
         use_act: Whether to use activation. Default: ``True``
@@ -76,7 +77,7 @@ class MobileOneBlock(nn.Module):
         self.num_conv_branches = num_conv_branches

         # Check if SE-ReLU is requested
-        self.se = SqueezeExcite(out_chs) if use_se else nn.Identity()
+        self.se = SqueezeExcite(out_chs, rd_divisor=1) if use_se else nn.Identity()

         if inference_mode:
             self.reparam_conv = create_conv2d(
@@ -100,17 +101,16 @@ class MobileOneBlock(nn.Module):
             # Re-parameterizable conv branches
             if num_conv_branches > 0:
-                rbr_conv = list()
-                for _ in range(self.num_conv_branches):
-                    rbr_conv.append(ConvNormAct(
+                self.rbr_conv = nn.ModuleList([
+                    ConvNormAct(
                         self.in_chs,
                         self.out_chs,
                         kernel_size=kernel_size,
                         stride=self.stride,
                         groups=self.groups,
                         apply_act=False,
-                    ))
-                self.rbr_conv = nn.ModuleList(rbr_conv)
+                    ) for _ in range(self.num_conv_branches)
+                ])
             else:
                 self.rbr_conv = None
@@ -148,8 +148,8 @@ class MobileOneBlock(nn.Module):
         # Other branches
         out = scale_out + identity_out
         if self.rbr_conv is not None:
-            for ix in range(self.num_conv_branches):
-                out += self.rbr_conv[ix](x)
+            for rc in self.rbr_conv:
+                out += rc(x)
         return self.act(self.se(out))
@@ -159,8 +159,9 @@ class MobileOneBlock(nn.Module):
         architecture used at training time to obtain a plain CNN-like structure
         for inference.
         """
-        if self.inference_mode:
+        if self.reparam_conv is not None:
             return
+
         kernel, bias = self._get_kernel_bias()
         self.reparam_conv = create_conv2d(
             in_channels=self.in_chs,
@@ -177,6 +178,7 @@ class MobileOneBlock(nn.Module):
         # Delete un-used branches
         for para in self.parameters():
             para.detach_()
+
         self.__delattr__("rbr_conv")
         self.__delattr__("rbr_scale")
         if hasattr(self, "rbr_skip"):
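
Keying the early-return off `self.reparam_conv` (rather than a stored flag) makes `reparameterize()` idempotent. A small helper sketch for fusing a whole model before deployment, assuming only that fusible blocks expose a `reparameterize()` method:

import torch.nn as nn

def fuse_reparam_blocks(model: nn.Module) -> nn.Module:
    # Safe to call repeatedly: already-fused blocks return early because
    # their reparam_conv is no longer None.
    for m in model.modules():
        if hasattr(m, 'reparameterize'):
            m.reparameterize()
    return model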
@@ -339,7 +341,6 @@ class ReparamLargeKernelConv(nn.Module):
         self.act = act_layer() if act_layer is not None else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Apply forward pass."""
         if self.lkb_reparam is not None:
             out = self.lkb_reparam(x)
         else:
@@ -1076,14 +1077,17 @@ def basic_blocks(


 class FastVit(nn.Module):
+    fork_feat: torch.jit.Final[bool]
+
     """
     This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
     """

     def __init__(
             self,
-            layers,
-            token_mixers: Tuple[str, ...],
+            in_chans=3,
+            layers=(2, 2, 6, 2),
+            token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
             embed_dims=None,
             mlp_ratios=None,
             downsamples=None,
@@ -1099,24 +1103,19 @@ class FastVit(nn.Module):
             use_layer_scale=True,
             layer_scale_init_value=1e-5,
             fork_feat=False,
-            init_cfg=None,
-            pretrained=None,
             cls_ratio=2.0,
             inference_mode=False,
-            **kwargs,
     ) -> None:
         super().__init__()
-        if not fork_feat:
-            self.num_classes = num_classes
+        self.num_classes = 0 if fork_feat else num_classes
         self.fork_feat = fork_feat

         if pos_embs is None:
             pos_embs = [None] * len(layers)

         # Convolutional stem
-        self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode)
+        self.patch_embed = convolutional_stem(
+            in_chans, embed_dims[0], inference_mode)

         # Build the main stages of the network architecture
         network = []
@@ -1192,14 +1191,9 @@ class FastVit(nn.Module):
             else nn.Identity()
         )

-        self.apply(self.cls_init_weights)
-        self.init_cfg = copy.deepcopy(init_cfg)
-
-        # load pre-trained model
-        if self.fork_feat and (self.init_cfg is not None or pretrained is not None):
-            self.init_weights()
-
-    def cls_init_weights(self, m: nn.Module) -> None:
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m: nn.Module) -> None:
         """Init. for classification"""
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=0.02)
@@ -1224,10 +1218,11 @@ class FastVit(nn.Module):
         outs = []
         for idx, block in enumerate(self.network):
             x = block(x)
-            if self.fork_feat and idx in self.out_indices:
-                norm_layer = getattr(self, f"norm{idx}")
-                x_out = norm_layer(x)
-                outs.append(x_out)
+            if self.fork_feat:
+                if idx in self.out_indices:
+                    norm_layer = getattr(self, f"norm{idx}")
+                    x_out = norm_layer(x)
+                    outs.append(x_out)
         if self.fork_feat:
             # output the features of four stages for dense prediction
             return outs
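
The `fork_feat: torch.jit.Final[bool]` annotation added to the class pairs with this restructured branch: TorchScript treats `Final` attributes as constants, so the untaken path can be folded out when scripting. A standalone illustration of the pattern (a toy module, not the timm class):

import torch
import torch.nn as nn

class Toy(nn.Module):
    # Final tells TorchScript this attribute is fixed after __init__, so the
    # scripted graph can drop the dead branch entirely.
    flag: torch.jit.Final[bool]

    def __init__(self, flag: bool):
        super().__init__()
        self.flag = flag

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.flag:
            return x * 2
        return x

scripted = torch.jit.script(Toy(False))
print(scripted(torch.ones(2)))  # tensor([1., 1.])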
@@ -1256,7 +1251,7 @@ def _cfg(url="", **kwargs):
         "num_classes": 1000,
         "input_size": (3, 256, 256),
         "pool_size": None,
-        "crop_pct": 0.95,
+        "crop_pct": 0.9,
         "interpolation": "bicubic",
         "mean": IMAGENET_DEFAULT_MEAN,
         "std": IMAGENET_DEFAULT_STD,
@@ -1265,174 +1260,166 @@ def _cfg(url="", **kwargs):
     }

-default_cfgs = {
-    "fastvit_t": _cfg(crop_pct=0.9),
-    "fastvit_s": _cfg(crop_pct=0.9),
-    "fastvit_m": _cfg(crop_pct=0.95),
-    'fastvit_t8': _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar')
-}
+default_cfgs = generate_default_cfgs({
+    "fastvit_t8.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar'
+    ),
+    "fastvit_t12.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t12.pth.tar'
+    ),
+    "fastvit_s12.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_s12.pth.tar'),
+    "fastvit_sa12.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa12.pth.tar'),
+    "fastvit_sa24.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa24.pth.tar'),
+    "fastvit_sa36.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa36.pth.tar'),
+    "fastvit_ma36.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_ma36.pth.tar',
+        crop_pct=0.95
+    ),
+
+    # "fastvit_t8.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t8.pth.tar'
+    # ),
+    # "fastvit_t12.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t12.pth.tar'
+    # ),
+    #
+    # "fastvit_s12.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_s12.pth.tar'),
+    # "fastvit_sa12.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa12.pth.tar'),
+    # "fastvit_sa24.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa24.pth.tar'),
+    # "fastvit_sa36.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa36.pth.tar'),
+    #
+    # "fastvit_ma36.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_ma36.pth.tar',
+    #     crop_pct=0.95
+    # ),
+})
+
+
+def checkpoint_filter_fn(state_dict, model):
+    # FIXME temporary for remapping
+    state_dict = state_dict.get('state_dict', state_dict)
+    msd = model.state_dict()
+    out_dict = {}
+    for ka, kb, va, vb in zip(msd.keys(), state_dict.keys(), msd.values(), state_dict.values()):
+        if va.ndim == 4 and vb.ndim == 2:
+            vb = vb[:, :, None, None]
+        out_dict[ka] = vb
+    return out_dict
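
The filter above pairs keys positionally via zip (fragile if the two state dicts enumerate parameters in different orders, hence the FIXME) and unsqueezes 2-D tensors where the model expects 4-D conv weights. The unsqueeze is exact: a linear weight viewed as a 1x1 conv kernel computes the same map, e.g.:

import torch
import torch.nn.functional as F

w = torch.randn(8, 4)            # e.g. an nn.Linear(4, 8) weight
w4 = w[:, :, None, None]         # -> (8, 4, 1, 1), a 1x1 conv kernel
x = torch.randn(1, 4, 7, 7)
print(F.conv2d(x, w4).shape)     # torch.Size([1, 8, 7, 7])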

+
+def _create_fastvit(variant, pretrained=False, **kwargs):
+    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
+    model = build_model_with_cfg(
+        FastVit,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs
+    )
+    return model
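
All the variants below funnel through this builder, so they pick up pretrained-weight loading and registry-driven cfg resolution for free:

import timm
import torch

# Build via the registered entrypoint; pretrained=True would pull the
# apple_in1k weights from the cfgs above once uploaded.
model = timm.create_model('fastvit_s12', pretrained=False).eval()
with torch.no_grad():
    print(model(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 1000])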

 @register_model
 def fastvit_t8(pretrained=False, **kwargs):
     """Instantiate FastViT-T8 model variant."""
-    layers = [2, 2, 4, 2]
-    embed_dims = [48, 96, 192, 384]
-    mlp_ratios = [3, 3, 3, 3]
-    downsamples = [True, True, True, True]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer")
-    model = FastVit(
-        layers,
-        token_mixers=token_mixers,
-        embed_dims=embed_dims,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        **kwargs,
+    model_args = dict(
+        layers=(2, 2, 4, 2),
+        embed_dims=(48, 96, 192, 384),
+        mlp_ratios=(3, 3, 3, 3),
+        downsamples=(True, True, True, True),
+        token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
     )
-    model.default_cfg = default_cfgs["fastvit_t"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, **kwargs))

 @register_model
 def fastvit_t12(pretrained=False, **kwargs):
     """Instantiate FastViT-T12 model variant."""
-    layers = [2, 2, 6, 2]
-    embed_dims = [64, 128, 256, 512]
-    mlp_ratios = [3, 3, 3, 3]
-    downsamples = [True, True, True, True]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer")
-    model = FastVit(
-        layers,
-        token_mixers=token_mixers,
-        embed_dims=embed_dims,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        **kwargs,
+    model_args = dict(
+        layers=(2, 2, 6, 2),
+        embed_dims=(64, 128, 256, 512),
+        mlp_ratios=(3, 3, 3, 3),
+        downsamples=(True, True, True, True),
+        token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
     )
-    model.default_cfg = default_cfgs["fastvit_t"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs))

 @register_model
 def fastvit_s12(pretrained=False, **kwargs):
     """Instantiate FastViT-S12 model variant."""
-    layers = [2, 2, 6, 2]
-    embed_dims = [64, 128, 256, 512]
-    mlp_ratios = [4, 4, 4, 4]
-    downsamples = [True, True, True, True]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer")
-    model = FastVit(
-        layers,
-        token_mixers=token_mixers,
-        embed_dims=embed_dims,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        **kwargs,
+    model_args = dict(
+        layers=(2, 2, 6, 2),
+        embed_dims=(64, 128, 256, 512),
+        mlp_ratios=(4, 4, 4, 4),
+        downsamples=(True, True, True, True),
+        token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
     )
-    model.default_cfg = default_cfgs["fastvit_s"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs))

 @register_model
 def fastvit_sa12(pretrained=False, **kwargs):
     """Instantiate FastViT-SA12 model variant."""
-    layers = [2, 2, 6, 2]
-    embed_dims = [64, 128, 256, 512]
-    mlp_ratios = [4, 4, 4, 4]
-    downsamples = [True, True, True, True]
-    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
-    model = FastVit(
-        layers,
-        token_mixers=token_mixers,
-        embed_dims=embed_dims,
-        pos_embs=pos_embs,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        **kwargs,
+    model_args = dict(
+        layers=(2, 2, 6, 2),
+        embed_dims=(64, 128, 256, 512),
+        mlp_ratios=(4, 4, 4, 4),
+        downsamples=(True, True, True, True),
+        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
-    model.default_cfg = default_cfgs["fastvit_s"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_sa12', pretrained=pretrained, **dict(model_args, **kwargs))

 @register_model
 def fastvit_sa24(pretrained=False, **kwargs):
     """Instantiate FastViT-SA24 model variant."""
-    layers = [4, 4, 12, 4]
-    embed_dims = [64, 128, 256, 512]
-    mlp_ratios = [4, 4, 4, 4]
-    downsamples = [True, True, True, True]
-    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
-    model = FastVit(
-        layers,
-        token_mixers=token_mixers,
-        embed_dims=embed_dims,
-        pos_embs=pos_embs,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        **kwargs,
+    model_args = dict(
+        layers=(4, 4, 12, 4),
+        embed_dims=(64, 128, 256, 512),
+        mlp_ratios=(4, 4, 4, 4),
+        downsamples=(True, True, True, True),
+        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
-    model.default_cfg = default_cfgs["fastvit_s"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_sa24', pretrained=pretrained, **dict(model_args, **kwargs))

 @register_model
 def fastvit_sa36(pretrained=False, **kwargs):
     """Instantiate FastViT-SA36 model variant."""
-    layers = [6, 6, 18, 6]
-    embed_dims = [64, 128, 256, 512]
-    mlp_ratios = [4, 4, 4, 4]
-    downsamples = [True, True, True, True]
-    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
-    model = FastVit(
-        layers,
-        embed_dims=embed_dims,
-        token_mixers=token_mixers,
-        pos_embs=pos_embs,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        layer_scale_init_value=1e-6,
-        **kwargs,
+    model_args = dict(
+        layers=(6, 6, 18, 6),
+        embed_dims=(64, 128, 256, 512),
+        mlp_ratios=(4, 4, 4, 4),
+        downsamples=(True, True, True, True),
+        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
-    model.default_cfg = default_cfgs["fastvit_m"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_sa36', pretrained=pretrained, **dict(model_args, **kwargs))

 @register_model
 def fastvit_ma36(pretrained=False, **kwargs):
     """Instantiate FastViT-MA36 model variant."""
-    layers = [6, 6, 18, 6]
-    embed_dims = [76, 152, 304, 608]
-    mlp_ratios = [4, 4, 4, 4]
-    downsamples = [True, True, True, True]
-    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
-    model = FastVit(
-        layers,
-        embed_dims=embed_dims,
-        token_mixers=token_mixers,
-        pos_embs=pos_embs,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        layer_scale_init_value=1e-6,
-        **kwargs,
+    model_args = dict(
+        layers=(6, 6, 18, 6),
+        embed_dims=(76, 152, 304, 608),
+        mlp_ratios=(4, 4, 4, 4),
+        downsamples=(True, True, True, True),
+        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention")
     )
-    model.default_cfg = default_cfgs["fastvit_m"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs))
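
End to end, the intended inference flow once the weights are uploaded would look roughly like the sketch below, fusing the re-parameterizable branches first as outlined earlier:

import timm
import torch

model = timm.create_model('fastvit_sa12', pretrained=False)  # True once weights land
model.eval()
for m in model.modules():          # fuse MobileOne/RepMixer-style branches
    if hasattr(m, 'reparameterize'):
        m.reparameterize()
with torch.no_grad():
    probs = model(torch.randn(1, 3, 256, 256)).softmax(-1)
print(probs.topk(5).indices)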