Fixup ViTamin, add hub weight reference

Ross Wightman 2024-06-03 17:14:03 -07:00
parent b2c0aeb0ec
commit 1b66ec7cf3
3 changed files with 219 additions and 190 deletions

View File: timm/models/vision_transformer.py

@ -409,6 +409,7 @@ class VisionTransformer(nn.Module):
qk_norm: bool = False,
init_values: Optional[float] = None,
class_token: bool = True,
pos_embed: str = 'learn',
no_embed_class: bool = False,
reg_tokens: int = 0,
pre_norm: bool = False,
@ -460,6 +461,7 @@ class VisionTransformer(nn.Module):
super().__init__()
assert global_pool in ('', 'avg', 'token', 'map')
assert class_token or global_pool != 'token'
assert pos_embed in ('', 'none', 'learn')
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
act_layer = get_act_layer(act_layer) or nn.GELU
@ -494,7 +496,10 @@ class VisionTransformer(nn.Module):
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
if not pos_embed or pos_embed == 'none':
self.pos_embed = None
else:
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
self.patch_drop = PatchDropout(
@ -556,7 +561,8 @@ class VisionTransformer(nn.Module):
def init_weights(self, mode: str = '') -> None:
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02)
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(get_init_weights_vit(mode, head_bias), self)
@ -583,6 +589,8 @@ class VisionTransformer(nn.Module):
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
self.grad_checkpointing = enable
if hasattr(self.patch_embed, 'set_grad_checkpointing'):
self.patch_embed.set_grad_checkpointing(enable)
@torch.jit.ignore
def get_classifier(self) -> nn.Module:
@ -600,6 +608,9 @@ class VisionTransformer(nn.Module):
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
if self.pos_embed is None:
return x
if self.dynamic_img_size:
B, H, W, C = x.shape
pos_embed = resample_abs_pos_embed(
@ -1066,10 +1077,13 @@ def checkpoint_filter_fn(
# IJEPA, vit in an 'encoder' submodule
state_dict = state_dict['encoder']
prefix = 'module.'
elif 'visual.trunk.pos_embed' in state_dict:
elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
# OpenCLIP model with timm vision encoder
# FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
prefix = 'visual.trunk.'
if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
# remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
out_dict['head.weight'] = state_dict['visual.head.proj.weight']
out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
if prefix:
# filter on & remove prefix string from keys
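
For context, the hunks above add a pos_embed constructor argument (accepted values '', 'none', 'learn') and make the learned position embedding optional in init, weight init, and _pos_embed. A minimal sketch of the new no-position-embedding path, assuming the constructor shown above; the small config values are illustrative only:

import torch
from timm.models.vision_transformer import VisionTransformer

# pos_embed='none' skips creating self.pos_embed; the early return in _pos_embed()
# also skips the prefix-token concat, so pair it with class_token=False / avg pooling
# the way the ViTamin configs do.
model = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=384, depth=4, num_heads=6,
    class_token=False, global_pool='avg', pos_embed='none',
)
assert model.pos_embed is None
out = model(torch.randn(1, 3, 224, 224))  # (1, 1000) with the default classifier head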

View File: timm/models/vision_transformer_hybrid.py

@ -38,14 +38,15 @@ class HybridEmbed(nn.Module):
def __init__(
self,
backbone,
img_size=224,
patch_size=1,
feature_size=None,
feature_ratio=None,
in_chans=3,
embed_dim=768,
bias=True,
backbone: nn.Module,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
in_chans: int = 3,
embed_dim: int = 768,
bias: bool = True,
proj: bool = True,
flatten: bool = True,
output_fmt: Optional[str] = None,
strict_img_size: bool = True,
@ -95,7 +96,18 @@ class HybridEmbed(nn.Module):
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
if proj:
self.proj = nn.Conv2d(
feature_dim,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
)
else:
assert feature_dim == embed_dim, \
    f'The feature dim ({feature_dim}) must match embed dim ({embed_dim}) when projection disabled.'
self.proj = nn.Identity()
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
total_reduction = (
@ -116,6 +128,13 @@ class HybridEmbed(nn.Module):
else:
return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
@ -157,6 +176,13 @@ class HybridEmbedWithSize(nn.Module):
bias=bias,
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable
def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
x = self.backbone(x)
if isinstance(x, (list, tuple)):
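
The new proj flag and the set_grad_checkpointing pass-through can be exercised roughly as below. This is a sketch only: the resnet18 feature backbone is an arbitrary choice whose final channel count (512) matches embed_dim, which is exactly what proj=False requires:

import timm
import torch
from timm.models.vision_transformer_hybrid import HybridEmbed

# resnet18's last feature map has 512 channels, so with embed_dim=512 the
# projection conv can be replaced by nn.Identity via proj=False.
backbone = timm.create_model('resnet18', pretrained=False, features_only=True, out_indices=(4,))
embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=512, proj=False)
tokens = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # (1, 49, 512) from the 7x7 feature map
embed.set_grad_checkpointing(True)  # forwarded to the backbone when it supports it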

View File: timm/models/vitamin.py

@ -19,29 +19,22 @@ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
"""
import math
from dataclasses import dataclass, field
from functools import partial
from typing import List, Tuple
from dataclasses import dataclass, replace, field
from typing import Callable, Optional, Union, Tuple, List, Sequence
import math, time
from torch.jit import Final
from typing import Optional, Union, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torch.utils.checkpoint import checkpoint
from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
from timm.layers import to_2tuple
from timm.layers import DropPath
from timm.layers.norm_act import _create_act
from timm.models._manipulate import named_apply, checkpoint_seq
from timm.models._builder import build_model_with_cfg
from timm.models._registry import register_model
from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn
from timm.models.vision_transformer_hybrid import HybridEmbed
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \
make_divisible, DropPath
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import VisionTransformer, checkpoint_filter_fn
from .vision_transformer_hybrid import HybridEmbed
@dataclass
@ -90,24 +83,19 @@ class Stem(nn.Module):
bias: bool = True,
):
super().__init__()
self.grad_checkpointing=False
norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
self.out_chs = out_chs
self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias)
self.norm1 = norm_act_layer(out_chs)
self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias)
named_apply(_init_conv, self)
def forward(self, x):
if self.grad_checkpointing:
x = checkpoint(self.conv1, x)
x = self.norm1(x)
x = checkpoint(self.conv2, x)
else:
x = self.conv1(x)
x = self.norm1(x)
x = self.conv2(x)
x = self.conv1(x)
x = self.norm1(x)
x = self.conv2(x)
return x
@ -145,8 +133,9 @@ class StridedConv(nn.Module):
embed_dim=768
):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
self.norm = norm_layer(in_chans) # affine over C
def forward(self, x):
@ -185,10 +174,10 @@ class MbConvLNBlock(nn.Module):
self.pre_norm = prenorm_act_layer(in_chs, apply_act=False)
self.down = nn.Identity()
self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True)
self.act1 = _create_act(act_layer, inplace=True)
self.act2 = _create_act(act_layer, inplace=True)
self.conv2_kxk = create_conv2d(mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
self.act1 = create_act_layer(act_layer, inplace=True)
self.conv2_kxk = create_conv2d(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
self.act2 = create_act_layer(act_layer, inplace=True)
self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -228,58 +217,57 @@ class MbConvStages(nn.Module):
):
super().__init__()
self.grad_checkpointing = False
self.stem = Stem(
in_chs=in_chans,
out_chs=cfg.stem_width,
)
stages = []
self.num_stages = len(cfg.embed_dim)
for s, dim in enumerate(cfg.embed_dim[:2]): # stage
blocks = []
stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width
for d in range(cfg.depths[s]):
blocks += [MbConvLNBlock(
in_chs = stage_in_chs if d==0 else dim,
out_chs = dim,
stride = 2 if d == 0 else 1,
# cfg = cfg.conv_cfg,
)]
blocks = nn.Sequential(*blocks)
stages += [blocks]
blocks = [
MbConvLNBlock(
in_chs = stage_in_chs if d==0 else dim,
out_chs = dim,
stride = 2 if d == 0 else 1,
)
for d in range(cfg.depths[s])
]
stages += [nn.Sequential(*blocks)]
self.stages = nn.Sequential(*stages)
self.stages = nn.ModuleList(stages)
self.pool = StridedConv(
stride=2,
in_chans=cfg.embed_dim[1],
embed_dim=cfg.embed_dim[2]
)
stride=2,
in_chans=cfg.embed_dim[1],
embed_dim=cfg.embed_dim[2]
)
def forward(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
for stage in self.stages:
x = checkpoint_seq(stage, x)
x = checkpoint(self.pool, x)
x = checkpoint_seq(self.stages, x)
else:
for stage in self.stages:
x = stage(x)
x = self.pool(x)
x = self.stages(x)
x = self.pool(x)
return x
class GeGluMlp(nn.Module):
def __init__(
self,
in_features,
hidden_features,
act_layer = None,
act_layer = 'gelu',
drop = 0.0,
):
super().__init__()
norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
self.norm = norm_layer(in_features)
self.act = nn.GELU()
self.w0 = nn.Linear(in_features, hidden_features)
self.act = create_act_layer(act_layer)
self.w1 = nn.Linear(in_features, hidden_features)
self.w2 = nn.Linear(hidden_features, in_features)
@ -290,118 +278,116 @@ class GeGluMlp(nn.Module):
return x
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(
self,
backbone,
img_size=224,
patch_size=1,
feature_size=None,
in_chans=3,
embed_dim=1024,
bias=True,
dynamic_img_pad=False,
):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.backbone = backbone
with torch.no_grad():
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Identity()
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
if 'flexi' in variant:
# FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
# interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
_filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
else:
_filter_fn = checkpoint_filter_fn
def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs):
assert embed_cfg is not None
backbone = MbConvStages(cfg=embed_cfg)
kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
return build_model_with_cfg(
VisionTransformer,
variant,
pretrained,
pretrained_filter_fn=_filter_fn,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs,
)
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
embed_layer = partial(HybridEmbed, backbone=backbone)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
'first_conv': 'patch_embed.backbone.stem.conv1',
'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
'vitamin_small.datacomp1b_clip_ltt': _cfg(
hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=384),
'vitamin_small.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-S', num_classes=384),
'vitamin_base.datacomp1b_clip_ltt': _cfg(
hf_hub_id='jienengchen/ViTamin-B-LTT', num_classes=768),
'vitamin_base.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-B', num_classes=768),
'vitamin_large.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L-224px', num_classes=1024),
'vitamin_large_256.datacomp1b_clip_l2': _cfg(
hf_hub_id='jienengchen/ViTamin-L2-256px', num_classes=1024,
input_size=(3, 256, 256), crop_pct=1.0),
'vitamin_large_256.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L-256px', num_classes=1024,
input_size=(3, 256, 256), crop_pct=1.0),
'vitamin_large_336.datacomp1b_clip_l2': _cfg(
hf_hub_id='jienengchen/ViTamin-L2-336px', num_classes=1024,
input_size=(3, 336, 336), crop_pct=1.0),
'vitamin_large_336.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L-336px', num_classes=1024,
input_size=(3, 336, 336), crop_pct=1.0),
'vitamin_large_384.datacomp1b_clip_l2': _cfg(
hf_hub_id='jienengchen/ViTamin-L2-384px', num_classes=1024,
input_size=(3, 384, 384), crop_pct=1.0),
'vitamin_large_384.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L-384px', num_classes=1024,
input_size=(3, 384, 384), crop_pct=1.0),
'vitamin_xlarge_256.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-XL-256px', num_classes=1152,
input_size=(3, 256, 256), crop_pct=1.0),
'vitamin_xlarge_336.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-XL-336px', num_classes=1152,
input_size=(3, 336, 336), crop_pct=1.0),
'vitamin_xlarge_384.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-XL-384px', num_classes=1152,
input_size=(3, 384, 384), crop_pct=1.0),
})
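
The hub ids above are what pretrained=True now resolves to. A rough usage sketch; it assumes network access to the Hugging Face Hub and that the hosted checkpoints pass through the OpenCLIP-trunk remapping shown in the vision_transformer.py hunk:

import timm
import torch

# 'vitamin_large_256.datacomp1b_clip' maps to hf_hub_id='jienengchen/ViTamin-L-256px'.
# num_classes=1024 is the CLIP image projection dim; checkpoint_filter_fn copies
# visual.head.proj.weight into head.weight and zero-fills head.bias.
model = timm.create_model('vitamin_large_256.datacomp1b_clip', pretrained=True).eval()
with torch.no_grad():
    embedding = model(torch.randn(1, 3, 256, 256))
print(embedding.shape)  # (1, 1024)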
@register_model
def vitamin_small(pretrained=False, **kwargs) -> VisionTransformer:
stage_1_2 = MbConvStages(cfg=VitCfg(
embed_dim=(64, 128, 384),
depths=(2, 4, 1),
stem_width=64,
conv_cfg = VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
embed_cfg = VitCfg(
embed_dim=(64, 128, 384),
depths=(2, 4, 1),
stem_width=64,
conv_cfg = VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
stage3_args = dict(embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid('vitamin_small', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
model_args = dict(
embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg
)
model = _create_vitamin('vitamin_small', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_base(pretrained=False, **kwargs) -> VisionTransformer:
stage_1_2 = MbConvStages(cfg=VitCfg(
embed_dim=(128, 256, 768),
depths=(2, 4, 1),
stem_width=128,
conv_cfg = VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
embed_cfg = VitCfg(
embed_dim=(128, 256, 768),
depths=(2, 4, 1),
stem_width=128,
conv_cfg = VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
stage3_args = dict(embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid('vitamin_base', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
model_args = dict(
embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_base', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
stage_1_2 = MbConvStages(cfg=VitCfg(
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
@ -410,17 +396,18 @@ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
stage3_args = dict(embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
model_args = dict(
embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg,
)
model = _create_vitamin('vitamin_large', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
@ -429,17 +416,17 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
model_args = dict(img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
model_args = dict(
img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_large_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
@ -448,17 +435,18 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
model_args = dict(img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
model_args = dict(
img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg
)
model = _create_vitamin('vitamin_large_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
@ -467,17 +455,17 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
model_args = dict(
img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_large_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
embed_cfg=VitCfg(
embed_dim=(192, 384, 1152),
depths=(2, 4, 1),
stem_width=192,
@ -486,17 +474,18 @@ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
model_args = dict(img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
model_args = dict(
img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
model = _create_vitamin(
'vitamin_xlarge_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
embed_cfg = VitCfg(
embed_dim=(192, 384, 1152),
depths=(2, 4, 1),
stem_width=192,
@ -505,17 +494,17 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
model_args = dict(img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
model_args = dict(
img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_xlarge_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
embed_cfg = VitCfg(
embed_dim=(192, 384, 1152),
depths=(2, 4, 1),
stem_width=192,
@ -524,9 +513,9 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
model_args = dict(
img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_xlarge_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
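
A short sketch of the refactored entrypoints: each vitamin_* function now builds its MbConvStages from the embed_cfg it defines and wires it in through HybridEmbed(proj=False), so no external backbone argument is needed. Random init shown, nothing is downloaded:

import timm
import torch

# vitamin_small as a feature extractor: num_classes=0 drops the classifier head,
# and the entrypoint's global_pool='avg' returns pooled features.
model = timm.create_model('vitamin_small', pretrained=False, num_classes=0)
feats = model(torch.randn(1, 3, 224, 224))
print(type(model.patch_embed).__name__)  # HybridEmbed wrapping MbConvStages
print(feats.shape)                       # (1, 384) for the small config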