More fastvit & mobileone updates, ready for weight upload
parent
8474508d07
commit
8470eb1cb5
|
@ -20,6 +20,7 @@ from .efficientnet import *
|
|||
from .efficientvit_mit import *
|
||||
from .efficientvit_msra import *
|
||||
from .eva import *
|
||||
from .fastvit import *
|
||||
from .focalnet import *
|
||||
from .gcvit import *
|
||||
from .ghostnet import *
|
||||
|
|
|
@ -589,7 +589,7 @@ class MobileOneBlock(nn.Module):
|
|||
self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
|
||||
|
||||
# Re-parameterizable conv branches
|
||||
convs = list()
|
||||
convs = []
|
||||
for _ in range(self.num_conv_branches):
|
||||
convs.append(layers.conv_norm_act(
|
||||
in_chs, out_chs, kernel_size=kernel_size,
|
||||
|
@ -624,8 +624,8 @@ class MobileOneBlock(nn.Module):
|
|||
|
||||
# Other branches
|
||||
out = scale_out + identity_out
|
||||
for ix in range(self.num_conv_branches):
|
||||
out += self.conv_kxk[ix](x)
|
||||
for ck in self.conv_kxk:
|
||||
out += ck(x)
|
||||
|
||||
return self.act(self.attn(out))
|
||||
|
||||
|
@ -1643,11 +1643,15 @@ model_cfgs = dict(
|
|||
),
|
||||
)
|
||||
|
||||
# FIXME temporary for mobileone remap
|
||||
from .fastvit import checkpoint_filter_fn
|
||||
|
||||
|
||||
def _create_byobnet(variant, pretrained=False, **kwargs):
    """Build the ByobNet model registered under ``variant``.

    Args:
        variant: Key into ``model_cfgs`` selecting the architecture config;
            also passed to ``build_model_with_cfg`` as the variant name.
        pretrained: Forwarded unchanged to ``build_model_with_cfg``.
        **kwargs: Extra keyword arguments forwarded to ``build_model_with_cfg``.

    Returns:
        Whatever ``build_model_with_cfg`` returns for ``ByobNet``.
    """
    builder_args = dict(
        model_cfg=model_cfgs[variant],
        # FIXME temporary for mobileone remap: reuses the fastvit filter fn
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True),
    )
    return build_model_with_cfg(ByobNet, variant, pretrained, **builder_args, **kwargs)
|
||||
|
||||
|
@ -1805,6 +1809,27 @@ default_cfgs = generate_default_cfgs({
|
|||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),
|
||||
|
||||
'mobileone_s0': _cfg(
|
||||
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar',
|
||||
crop_pct=0.875,
|
||||
),
|
||||
'mobileone_s1': _cfg(
|
||||
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar',
|
||||
crop_pct=0.9,
|
||||
),
|
||||
'mobileone_s2': _cfg(
|
||||
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar',
|
||||
crop_pct=0.9,
|
||||
),
|
||||
'mobileone_s3': _cfg(
|
||||
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar',
|
||||
crop_pct=0.9,
|
||||
),
|
||||
'mobileone_s4': _cfg(
|
||||
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar',
|
||||
crop_pct=0.9,
|
||||
),
|
||||
})
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,8 @@ import torch.nn as nn
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn
|
||||
from ._registry import register_model
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
|
||||
def num_groups(group_size, channels):
|
||||
|
@ -58,7 +59,7 @@ class MobileOneBlock(nn.Module):
|
|||
kernel_size: Size of the convolution kernel.
|
||||
stride: Stride size.
|
||||
dilation: Kernel dilation factor.
|
||||
groups: Group number.
|
||||
group_size: Convolution group size.
|
||||
inference_mode: If True, instantiates model in inference mode.
|
||||
use_se: Whether to use SE-ReLU activations.
|
||||
use_act: Whether to use activation. Default: ``True``
|
||||
|
@ -76,7 +77,7 @@ class MobileOneBlock(nn.Module):
|
|||
self.num_conv_branches = num_conv_branches
|
||||
|
||||
# Check if SE-ReLU is requested
|
||||
self.se = SqueezeExcite(out_chs) if use_se else nn.Identity()
|
||||
self.se = SqueezeExcite(out_chs, rd_divisor=1) if use_se else nn.Identity()
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = create_conv2d(
|
||||
|
@ -100,17 +101,16 @@ class MobileOneBlock(nn.Module):
|
|||
|
||||
# Re-parameterizable conv branches
|
||||
if num_conv_branches > 0:
|
||||
rbr_conv = list()
|
||||
for _ in range(self.num_conv_branches):
|
||||
rbr_conv.append(ConvNormAct(
|
||||
self.rbr_conv = nn.ModuleList([
|
||||
ConvNormAct(
|
||||
self.in_chs,
|
||||
self.out_chs,
|
||||
kernel_size=kernel_size,
|
||||
stride=self.stride,
|
||||
groups=self.groups,
|
||||
apply_act=False,
|
||||
))
|
||||
self.rbr_conv = nn.ModuleList(rbr_conv)
|
||||
) for _ in range(self.num_conv_branches)
|
||||
])
|
||||
else:
|
||||
self.rbr_conv = None
|
||||
|
||||
|
@ -148,8 +148,8 @@ class MobileOneBlock(nn.Module):
|
|||
# Other branches
|
||||
out = scale_out + identity_out
|
||||
if self.rbr_conv is not None:
|
||||
for ix in range(self.num_conv_branches):
|
||||
out += self.rbr_conv[ix](x)
|
||||
for rc in self.rbr_conv:
|
||||
out += rc(x)
|
||||
|
||||
return self.act(self.se(out))
|
||||
|
||||
|
@ -159,8 +159,9 @@ class MobileOneBlock(nn.Module):
|
|||
architecture used at training time to obtain a plain CNN-like structure
|
||||
for inference.
|
||||
"""
|
||||
if self.inference_mode:
|
||||
if self.reparam_conv is not None:
|
||||
return
|
||||
|
||||
kernel, bias = self._get_kernel_bias()
|
||||
self.reparam_conv = create_conv2d(
|
||||
in_channels=self.in_chs,
|
||||
|
@ -177,6 +178,7 @@ class MobileOneBlock(nn.Module):
|
|||
# Delete un-used branches
|
||||
for para in self.parameters():
|
||||
para.detach_()
|
||||
|
||||
self.__delattr__("rbr_conv")
|
||||
self.__delattr__("rbr_scale")
|
||||
if hasattr(self, "rbr_skip"):
|
||||
|
@ -339,7 +341,6 @@ class ReparamLargeKernelConv(nn.Module):
|
|||
self.act = act_layer() if act_layer is not None else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply forward pass."""
|
||||
if self.lkb_reparam is not None:
|
||||
out = self.lkb_reparam(x)
|
||||
else:
|
||||
|
@ -1076,14 +1077,17 @@ def basic_blocks(
|
|||
|
||||
|
||||
class FastVit(nn.Module):
|
||||
fork_feat: torch.jit.Final[bool]
|
||||
|
||||
"""
|
||||
This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layers,
|
||||
token_mixers: Tuple[str, ...],
|
||||
in_chans=3,
|
||||
layers=(2, 2, 6, 2),
|
||||
token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
|
||||
embed_dims=None,
|
||||
mlp_ratios=None,
|
||||
downsamples=None,
|
||||
|
@ -1099,24 +1103,19 @@ class FastVit(nn.Module):
|
|||
use_layer_scale=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
fork_feat=False,
|
||||
init_cfg=None,
|
||||
pretrained=None,
|
||||
cls_ratio=2.0,
|
||||
inference_mode=False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
if not fork_feat:
|
||||
self.num_classes = num_classes
|
||||
self.num_classes = 0 if fork_feat else num_classes
|
||||
self.fork_feat = fork_feat
|
||||
|
||||
if pos_embs is None:
|
||||
pos_embs = [None] * len(layers)
|
||||
|
||||
# Convolutional stem
|
||||
self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode)
|
||||
self.patch_embed = convolutional_stem(
|
||||
in_chans, embed_dims[0], inference_mode)
|
||||
|
||||
# Build the main stages of the network architecture
|
||||
network = []
|
||||
|
@ -1192,14 +1191,9 @@ class FastVit(nn.Module):
|
|||
else nn.Identity()
|
||||
)
|
||||
|
||||
self.apply(self.cls_init_weights)
|
||||
self.init_cfg = copy.deepcopy(init_cfg)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
# load pre-trained model
|
||||
if self.fork_feat and (self.init_cfg is not None or pretrained is not None):
|
||||
self.init_weights()
|
||||
|
||||
def cls_init_weights(self, m: nn.Module) -> None:
|
||||
def _init_weights(self, m: nn.Module) -> None:
|
||||
"""Init. for classification"""
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
|
@ -1224,10 +1218,11 @@ class FastVit(nn.Module):
|
|||
outs = []
|
||||
for idx, block in enumerate(self.network):
|
||||
x = block(x)
|
||||
if self.fork_feat and idx in self.out_indices:
|
||||
norm_layer = getattr(self, f"norm{idx}")
|
||||
x_out = norm_layer(x)
|
||||
outs.append(x_out)
|
||||
if self.fork_feat:
|
||||
if idx in self.out_indices:
|
||||
norm_layer = getattr(self, f"norm{idx}")
|
||||
x_out = norm_layer(x)
|
||||
outs.append(x_out)
|
||||
if self.fork_feat:
|
||||
# output the features of four stages for dense prediction
|
||||
return outs
|
||||
|
@ -1256,7 +1251,7 @@ def _cfg(url="", **kwargs):
|
|||
"num_classes": 1000,
|
||||
"input_size": (3, 256, 256),
|
||||
"pool_size": None,
|
||||
"crop_pct": 0.95,
|
||||
"crop_pct": 0.9,
|
||||
"interpolation": "bicubic",
|
||||
"mean": IMAGENET_DEFAULT_MEAN,
|
||||
"std": IMAGENET_DEFAULT_STD,
|
||||
|
@ -1265,174 +1260,166 @@ def _cfg(url="", **kwargs):
|
|||
}
|
||||
|
||||
|
||||
# Pretrained configs keyed as '<architecture>.<tag>' and wrapped by
# generate_default_cfgs. This replaces the earlier plain dict keyed by
# 'fastvit_t' / 'fastvit_s' / 'fastvit_m'.
default_cfgs = generate_default_cfgs({
    "fastvit_t8.apple_in1k": _cfg(
        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar',
    ),
    "fastvit_t12.apple_in1k": _cfg(
        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t12.pth.tar',
    ),

    "fastvit_s12.apple_in1k": _cfg(
        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_s12.pth.tar',
    ),
    "fastvit_sa12.apple_in1k": _cfg(
        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa12.pth.tar',
    ),
    "fastvit_sa24.apple_in1k": _cfg(
        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa24.pth.tar',
    ),
    "fastvit_sa36.apple_in1k": _cfg(
        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa36.pth.tar',
    ),

    "fastvit_ma36.apple_in1k": _cfg(
        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_ma36.pth.tar',
        crop_pct=0.95,
    ),

    # The '.apple_dist_in1k' entries below are left disabled, exactly as in
    # the incoming change.
    # "fastvit_t8.apple_dist_in1k": _cfg(
    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t8.pth.tar'
    # ),
    # "fastvit_t12.apple_dist_in1k": _cfg(
    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t12.pth.tar'
    # ),
    #
    # "fastvit_s12.apple_dist_in1k": _cfg(
    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_s12.pth.tar'),
    # "fastvit_sa12.apple_dist_in1k": _cfg(
    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa12.pth.tar'),
    # "fastvit_sa24.apple_dist_in1k": _cfg(
    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa24.pth.tar'),
    # "fastvit_sa36.apple_dist_in1k": _cfg(
    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa36.pth.tar'),
    #
    # "fastvit_ma36.apple_dist_in1k": _cfg(
    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_ma36.pth.tar',
    #     crop_pct=0.95
    # ),
})
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
    """Rename checkpoint tensors to ``model``'s state-dict keys by position.

    The i-th tensor of the checkpoint is stored under the i-th key of
    ``model.state_dict()``; checkpoint key names are ignored. Pairing stops
    at the shorter of the two mappings.

    Args:
        state_dict: Checkpoint mapping. If it has a ``'state_dict'`` entry,
            that nested mapping is used instead.
        model: Object exposing ``state_dict()``; supplies target key names
            and reference tensor ranks.

    Returns:
        A new dict mapping model keys to checkpoint tensors.
    """
    # FIXME temporary for remapping
    state_dict = state_dict.get('state_dict', state_dict)
    remapped = {}
    for (target_key, target_val), src_val in zip(model.state_dict().items(), state_dict.values()):
        if target_val.ndim == 4 and src_val.ndim == 2:
            # 2-D source for a 4-D target: append two singleton trailing dims
            src_val = src_val[:, :, None, None]
        remapped[target_key] = src_val
    return remapped
|
||||
|
||||
|
||||
def _create_fastvit(variant, pretrained=False, **kwargs):
    """Build a ``FastVit`` model through ``build_model_with_cfg``.

    Args:
        variant: Variant name passed to ``build_model_with_cfg``.
        pretrained: Forwarded unchanged to ``build_model_with_cfg``.
        **kwargs: Forwarded to ``build_model_with_cfg``. ``out_indices`` is
            removed first and placed in the feature config, defaulting to
            ``(0, 1, 2, 3)``.

    Returns:
        The model returned by ``build_model_with_cfg``.
    """
    feature_cfg = dict(
        flatten_sequential=True,
        out_indices=kwargs.pop('out_indices', (0, 1, 2, 3)),
    )
    return build_model_with_cfg(
        FastVit,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
|
||||
|
||||
|
||||
@register_model
def fastvit_t8(pretrained=False, **kwargs):
    """Instantiate FastViT-T8 model variant.

    Args:
        pretrained: Forwarded to ``_create_fastvit``.
        **kwargs: Merged over the defaults below; caller values win on
            conflicting keys.
    """
    model_args = dict(
        layers=(2, 2, 4, 2),
        embed_dims=(48, 96, 192, 384),
        mlp_ratios=(3, 3, 3, 3),
        downsamples=(True, True, True, True),
        token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
    )
    return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
def fastvit_t12(pretrained=False, **kwargs):
    """Instantiate FastViT-T12 model variant.

    Args:
        pretrained: Forwarded to ``_create_fastvit``.
        **kwargs: Merged over the defaults below; caller values win on
            conflicting keys.
    """
    model_args = dict(
        layers=(2, 2, 6, 2),
        embed_dims=(64, 128, 256, 512),
        mlp_ratios=(3, 3, 3, 3),
        downsamples=(True, True, True, True),
        token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
    )
    return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
def fastvit_s12(pretrained=False, **kwargs):
    """Instantiate FastViT-S12 model variant.

    Args:
        pretrained: Forwarded to ``_create_fastvit``.
        **kwargs: Merged over the defaults below; caller values win on
            conflicting keys.
    """
    model_args = dict(
        layers=(2, 2, 6, 2),
        embed_dims=(64, 128, 256, 512),
        mlp_ratios=(4, 4, 4, 4),
        downsamples=(True, True, True, True),
        token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
    )
    return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
def fastvit_sa12(pretrained=False, **kwargs):
    """Instantiate FastViT-SA12 model variant.

    Args:
        pretrained: Forwarded to ``_create_fastvit``.
        **kwargs: Merged over the defaults below; caller values win on
            conflicting keys.
    """
    model_args = dict(
        layers=(2, 2, 6, 2),
        embed_dims=(64, 128, 256, 512),
        mlp_ratios=(4, 4, 4, 4),
        downsamples=(True, True, True, True),
        # Only the last stage gets a positional embedding and attention mixer
        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
    )
    return _create_fastvit('fastvit_sa12', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
def fastvit_sa24(pretrained=False, **kwargs):
    """Instantiate FastViT-SA24 model variant.

    Args:
        pretrained: Forwarded to ``_create_fastvit``.
        **kwargs: Merged over the defaults below; caller values win on
            conflicting keys.
    """
    model_args = dict(
        layers=(4, 4, 12, 4),
        embed_dims=(64, 128, 256, 512),
        mlp_ratios=(4, 4, 4, 4),
        downsamples=(True, True, True, True),
        # Only the last stage gets a positional embedding and attention mixer
        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
    )
    return _create_fastvit('fastvit_sa24', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
def fastvit_sa36(pretrained=False, **kwargs):
    """Instantiate FastViT-SA36 model variant.

    Args:
        pretrained: Forwarded to ``_create_fastvit``.
        **kwargs: Merged over the defaults below; caller values win on
            conflicting keys.
    """
    # NOTE(review): the superseded body passed layer_scale_init_value=1e-6
    # to FastVit; the new argument set omits it. Confirm this is intentional.
    model_args = dict(
        layers=(6, 6, 18, 6),
        embed_dims=(64, 128, 256, 512),
        mlp_ratios=(4, 4, 4, 4),
        downsamples=(True, True, True, True),
        # Only the last stage gets a positional embedding and attention mixer
        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
    )
    return _create_fastvit('fastvit_sa36', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
@register_model
def fastvit_ma36(pretrained=False, **kwargs):
    """Instantiate FastViT-MA36 model variant.

    Args:
        pretrained: Forwarded to ``_create_fastvit``.
        **kwargs: Merged over the defaults below; caller values win on
            conflicting keys.
    """
    # NOTE(review): the superseded body passed layer_scale_init_value=1e-6
    # to FastVit; the new argument set omits it. Confirm this is intentional.
    model_args = dict(
        layers=(6, 6, 18, 6),
        embed_dims=(76, 152, 304, 608),
        mlp_ratios=(4, 4, 4, 4),
        downsamples=(True, True, True, True),
        # Only the last stage gets a positional embedding and attention mixer
        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
    )
    return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs))
|
Loading…
Reference in New Issue