diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 2dd4754a..f25db6b5 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -409,6 +409,7 @@ class VisionTransformer(nn.Module):
             qk_norm: bool = False,
             init_values: Optional[float] = None,
             class_token: bool = True,
+            pos_embed: str = 'learn',
             no_embed_class: bool = False,
             reg_tokens: int = 0,
             pre_norm: bool = False,
@@ -460,6 +461,7 @@ class VisionTransformer(nn.Module):
         super().__init__()
         assert global_pool in ('', 'avg', 'token', 'map')
         assert class_token or global_pool != 'token'
+        assert pos_embed in ('', 'none', 'learn')
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
         act_layer = get_act_layer(act_layer) or nn.GELU
@@ -494,7 +496,10 @@ class VisionTransformer(nn.Module):
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
         self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
         embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
-        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
+        if not pos_embed or pos_embed == 'none':
+            self.pos_embed = None
+        else:
+            self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
         if patch_drop_rate > 0:
             self.patch_drop = PatchDropout(
@@ -556,7 +561,8 @@ class VisionTransformer(nn.Module):
     def init_weights(self, mode: str = '') -> None:
         assert mode in ('jax', 'jax_nlhb', 'moco', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
-        trunc_normal_(self.pos_embed, std=.02)
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
         if self.cls_token is not None:
             nn.init.normal_(self.cls_token, std=1e-6)
         named_apply(get_init_weights_vit(mode, head_bias), self)
@@ -583,6 +589,8 @@ class VisionTransformer(nn.Module):
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable: bool = True) -> None:
         self.grad_checkpointing = enable
+        if hasattr(self.patch_embed, 'set_grad_checkpointing'):
+            self.patch_embed.set_grad_checkpointing(enable)

     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
@@ -600,6 +608,9 @@ class VisionTransformer(nn.Module):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

     def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
+        if self.pos_embed is None:
+            return x
+
         if self.dynamic_img_size:
             B, H, W, C = x.shape
             pos_embed = resample_abs_pos_embed(
@@ -1066,10 +1077,13 @@ def checkpoint_filter_fn(
         # IJEPA, vit in an 'encoder' submodule
         state_dict = state_dict['encoder']
         prefix = 'module.'
-    elif 'visual.trunk.pos_embed' in state_dict:
+    elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
         # OpenCLIP model with timm vision encoder
-        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         prefix = 'visual.trunk.'
+        if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
+            # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+            out_dict['head.weight'] = state_dict['visual.head.proj.weight']
+            out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])

     if prefix:
         # filter on & remove prefix string from keys
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py
index 25dd9c27..c2dd1e59 100644
--- a/timm/models/vision_transformer_hybrid.py
+++ b/timm/models/vision_transformer_hybrid.py
@@ -38,14 +38,15 @@ class HybridEmbed(nn.Module):

     def __init__(
             self,
-            backbone,
-            img_size=224,
-            patch_size=1,
-            feature_size=None,
-            feature_ratio=None,
-            in_chans=3,
-            embed_dim=768,
-            bias=True,
+            backbone: nn.Module,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 1,
+            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            bias: bool = True,
+            proj: bool = True,
             flatten: bool = True,
             output_fmt: Optional[str] = None,
             strict_img_size: bool = True,
@@ -95,7 +96,18 @@ class HybridEmbed(nn.Module):
         self.strict_img_size = strict_img_size
         self.dynamic_img_pad = dynamic_img_pad

-        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+        if proj:
+            self.proj = nn.Conv2d(
+                feature_dim,
+                embed_dim,
+                kernel_size=patch_size,
+                stride=patch_size,
+                bias=bias,
+            )
+        else:
+            assert feature_dim == embed_dim, \
+                f'The feature dim ({feature_dim}) must match embed dim ({embed_dim}) when projection is disabled.'
+            self.proj = nn.Identity()

     def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
         total_reduction = (
@@ -116,6 +128,13 @@ class HybridEmbed(nn.Module):
         else:
             return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]

+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True):
+        if hasattr(self.backbone, 'set_grad_checkpointing'):
+            self.backbone.set_grad_checkpointing(enable=enable)
+        elif hasattr(self.backbone, 'grad_checkpointing'):
+            self.backbone.grad_checkpointing = enable
+
     def forward(self, x):
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
@@ -157,6 +176,13 @@ class HybridEmbedWithSize(nn.Module):
             bias=bias,
         )

+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True):
+        if hasattr(self.backbone, 'set_grad_checkpointing'):
+            self.backbone.set_grad_checkpointing(enable=enable)
+        elif hasattr(self.backbone, 'grad_checkpointing'):
+            self.backbone.grad_checkpointing = enable
+
     def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py
index f84a59d6..71d3b674 100644
--- a/timm/models/vitamin.py
+++ b/timm/models/vitamin.py
@@ -19,29 +19,22 @@ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
 https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
 """
+import math
+from dataclasses import dataclass, field
 from functools import partial
-from typing import List, Tuple
-from dataclasses import dataclass, replace, field
-from typing import Callable, Optional, Union, Tuple, List, Sequence
-import math, time
-from torch.jit import Final
+from typing import Optional, Union, Tuple
+
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-import timm
-from torch.utils.checkpoint import checkpoint
-from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
-
-from timm.layers import to_2tuple
-from timm.layers import DropPath
-from timm.layers.norm_act import _create_act
-
-from timm.models._manipulate import named_apply, checkpoint_seq
-from timm.models._builder import build_model_with_cfg
-from timm.models._registry import register_model
-from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn
-from timm.models.vision_transformer_hybrid import HybridEmbed
+from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \
+    make_divisible, DropPath
+from ._builder import build_model_with_cfg
+from ._manipulate import named_apply, checkpoint_seq
+from ._registry import register_model, generate_default_cfgs
+from .vision_transformer import VisionTransformer, checkpoint_filter_fn
+from .vision_transformer_hybrid import HybridEmbed


 @dataclass
@@ -90,24 +83,19 @@ class Stem(nn.Module):
             bias: bool = True,
     ):
         super().__init__()
-        self.grad_checkpointing=False
         norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
         self.out_chs = out_chs
+
         self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias)
         self.norm1 = norm_act_layer(out_chs)
         self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias)
+
         named_apply(_init_conv, self)

     def forward(self, x):
-        if self.grad_checkpointing:
-            x = checkpoint(self.conv1, x)
-            x = self.norm1(x)
-            x = checkpoint(self.conv2, x)
-        else:
-            x = self.conv1(x)
-            x = self.norm1(x)
-            x = self.conv2(x)
-
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.conv2(x)
         return x


@@ -145,8 +133,9 @@ class StridedConv(nn.Module):
             embed_dim=768
     ):
         super().__init__()
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
         norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
         self.norm = norm_layer(in_chans)  # affine over C

     def forward(self, x):
@@ -185,10 +174,10 @@ class MbConvLNBlock(nn.Module):
         self.pre_norm = prenorm_act_layer(in_chs, apply_act=False)
         self.down = nn.Identity()
         self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True)
-        self.act1 = _create_act(act_layer, inplace=True)
-        self.act2 = _create_act(act_layer, inplace=True)
-
-        self.conv2_kxk = create_conv2d(mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
+        self.act1 = create_act_layer(act_layer, inplace=True)
+        self.conv2_kxk = create_conv2d(
+            mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
+        self.act2 = create_act_layer(act_layer, inplace=True)
         self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

@@ -228,58 +217,57 @@ class MbConvStages(nn.Module):
     ):
         super().__init__()
         self.grad_checkpointing = False
+
         self.stem = Stem(
             in_chs=in_chans,
             out_chs=cfg.stem_width,
         )
+
         stages = []
         self.num_stages = len(cfg.embed_dim)
         for s, dim in enumerate(cfg.embed_dim[:2]): # stage
-            blocks = []
             stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width
-            for d in range(cfg.depths[s]):
-                blocks += [MbConvLNBlock(
-                    in_chs = stage_in_chs if d==0 else dim,
-                    out_chs = dim,
-                    stride = 2 if d == 0 else 1,
-                    # cfg = cfg.conv_cfg,
-                )]
-            blocks = nn.Sequential(*blocks)
-            stages += [blocks]
+            blocks = [
+                MbConvLNBlock(
+                    in_chs = stage_in_chs if d==0 else dim,
+                    out_chs = dim,
+                    stride = 2 if d == 0 else 1,
+                )
+                for d in range(cfg.depths[s])
+            ]
+            stages += [nn.Sequential(*blocks)]
+        self.stages = nn.Sequential(*stages)

-        self.stages = nn.ModuleList(stages)
         self.pool = StridedConv(
-                stride=2,
-                in_chans=cfg.embed_dim[1],
-                embed_dim=cfg.embed_dim[2]
-            )
+            stride=2,
+            in_chans=cfg.embed_dim[1],
+            embed_dim=cfg.embed_dim[2]
+        )

     def forward(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
-            for stage in self.stages:
-                x = checkpoint_seq(stage, x)
-            x = checkpoint(self.pool, x)
+            x = checkpoint_seq(self.stages, x)
         else:
-            for stage in self.stages:
-                x = stage(x)
-            x = self.pool(x)
-
+            x = self.stages(x)
+        x = self.pool(x)
         return x

+
 class GeGluMlp(nn.Module):
     def __init__(
             self,
             in_features,
             hidden_features,
-            act_layer = None,
+            act_layer = 'gelu',
             drop = 0.0,
     ):
         super().__init__()
         norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
+
         self.norm = norm_layer(in_features)
-        self.act = nn.GELU()
         self.w0 = nn.Linear(in_features, hidden_features)
+        self.act = create_act_layer(act_layer)
         self.w1 = nn.Linear(in_features, hidden_features)
         self.w2 = nn.Linear(hidden_features, in_features)

@@ -290,118 +278,116 @@ class GeGluMlp(nn.Module):
         return x


-class HybridEmbed(nn.Module):
-    """ CNN Feature Map Embedding
-        Extract feature map from CNN, flatten, project to embedding dim.
- """ - def __init__( - self, - backbone, - img_size=224, - patch_size=1, - feature_size=None, - in_chans=3, - embed_dim=1024, - bias=True, - dynamic_img_pad=False, - ): - super().__init__() - assert isinstance(backbone, nn.Module) - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.backbone = backbone - with torch.no_grad(): - training = backbone.training - if training: - backbone.eval() - o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) - if isinstance(o, (list, tuple)): - o = o[-1] # last feature if backbone outputs list/tuple of features - feature_size = o.shape[-2:] - feature_dim = o.shape[1] - backbone.train(training) - - assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 - self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.proj = nn.Identity() - - def forward(self, x): - x = self.backbone(x) - if isinstance(x, (list, tuple)): - x = x[-1] # last feature if backbone outputs list/tuple of features - x = self.proj(x) - x = x.flatten(2).transpose(1, 2) - return x - - -def _create_vision_transformer(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - - if 'flexi' in variant: - # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed - # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. - _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) - else: - _filter_fn = checkpoint_filter_fn +def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs): + assert embed_cfg is not None + backbone = MbConvStages(cfg=embed_cfg) + kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False) + kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set return build_model_with_cfg( VisionTransformer, variant, pretrained, - pretrained_filter_fn=_filter_fn, + pretrained_filter_fn=checkpoint_filter_fn, **kwargs, ) -def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): - embed_layer = partial(HybridEmbed, backbone=backbone) - kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set - return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs) +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD, + 'first_conv': 'patch_embed.backbone.stem.conv1', + 'classifier': 'head', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'vitamin_small.datacomp1b_clip_ltt': _cfg( + hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=384), + 'vitamin_small.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-S', num_classes=384), + 'vitamin_base.datacomp1b_clip_ltt': _cfg( + hf_hub_id='jienengchen/ViTamin-B-LTT', num_classes=768), + 'vitamin_base.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-B', num_classes=768), + 'vitamin_large.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-224px', num_classes=1024), + 'vitamin_large_256.datacomp1b_clip_l2': _cfg( + hf_hub_id='jienengchen/ViTamin-L2-256px', 
num_classes=1024, + input_size=(3, 256, 256), crop_pct=1.0), + 'vitamin_large_256.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-256px', num_classes=1024, + input_size=(3, 256, 256), crop_pct=1.0), + 'vitamin_large_336.datacomp1b_clip_l2': _cfg( + hf_hub_id='jienengchen/ViTamin-L2-336px', num_classes=1024, + input_size=(3, 336, 336), crop_pct=1.0), + 'vitamin_large_336.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-336px', num_classes=1024, + input_size=(3, 336, 336), crop_pct=1.0), + 'vitamin_large_384.datacomp1b_clip_l2': _cfg( + hf_hub_id='jienengchen/ViTamin-L2-384px', num_classes=1024, + input_size=(3, 384, 384), crop_pct=1.0), + 'vitamin_large_384.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-384px', num_classes=1024, + input_size=(3, 384, 384), crop_pct=1.0), + 'vitamin_xlarge_256.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-XL-256px', num_classes=1152, + input_size=(3, 256, 256), crop_pct=1.0), + 'vitamin_xlarge_336.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-XL-336px', num_classes=1152, + input_size=(3, 336, 336), crop_pct=1.0), + 'vitamin_xlarge_384.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-XL-384px', num_classes=1152, + input_size=(3, 384, 384), crop_pct=1.0), +}) @register_model def vitamin_small(pretrained=False, **kwargs) -> VisionTransformer: - stage_1_2 = MbConvStages(cfg=VitCfg( - embed_dim=(64, 128, 384), - depths=(2, 4, 1), - stem_width=64, - conv_cfg = VitConvCfg( - norm_layer='layernorm2d', - norm_eps=1e-6, - ), - head_type='1d', + embed_cfg = VitCfg( + embed_dim=(64, 128, 384), + depths=(2, 4, 1), + stem_width=64, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, ), + head_type='1d', ) - stage3_args = dict(embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid('vitamin_small', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) + model_args = dict( + embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg + ) + model = _create_vitamin('vitamin_small', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_base(pretrained=False, **kwargs) -> VisionTransformer: - stage_1_2 = MbConvStages(cfg=VitCfg( - embed_dim=(128, 256, 768), - depths=(2, 4, 1), - stem_width=128, - conv_cfg = VitConvCfg( - norm_layer='layernorm2d', - norm_eps=1e-6, - ), - head_type='1d', + embed_cfg = VitCfg( + embed_dim=(128, 256, 768), + depths=(2, 4, 1), + stem_width=128, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, ), + head_type='1d', ) - stage3_args = dict(embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid('vitamin_base', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) + model_args = dict( + embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_base', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer: - stage_1_2 = MbConvStages(cfg=VitCfg( + embed_cfg = VitCfg( embed_dim=(160, 320, 1024), depths=(2, 4, 1), stem_width=160, @@ -410,17 +396,18 @@ def vitamin_large(pretrained=False, **kwargs) -> 
VisionTransformer: norm_eps=1e-6, ), head_type='1d', - ), ) - stage3_args = dict(embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid( - 'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) + model_args = dict( + embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg, + ) + model = _create_vitamin('vitamin_large', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer: - backbone = MbConvStages(cfg=VitCfg( + embed_cfg = VitCfg( embed_dim=(160, 320, 1024), depths=(2, 4, 1), stem_width=160, @@ -429,17 +416,17 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer: norm_eps=1e-6, ), head_type='1d', - ), ) - model_args = dict(img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid( - 'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + model_args = dict( + img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_large_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer: - backbone = MbConvStages(cfg=VitCfg( + embed_cfg = VitCfg( embed_dim=(160, 320, 1024), depths=(2, 4, 1), stem_width=160, @@ -448,17 +435,18 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer: norm_eps=1e-6, ), head_type='1d', - ), ) - model_args = dict(img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid( - 'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + model_args = dict( + img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg + ) + model = _create_vitamin('vitamin_large_336', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: - backbone = MbConvStages(cfg=VitCfg( + embed_cfg = VitCfg( embed_dim=(160, 320, 1024), depths=(2, 4, 1), stem_width=160, @@ -467,17 +455,17 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: norm_eps=1e-6, ), head_type='1d', - ), ) - model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid( - 'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + model_args = dict( + img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_large_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer: - backbone = MbConvStages(cfg=VitCfg( + embed_cfg=VitCfg( embed_dim=(192, 384, 1152), depths=(2, 4, 1), 
         stem_width=192,
@@ -486,17 +474,18 @@ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
             norm_eps=1e-6,
         ),
         head_type='1d',
-        ),
     )
-    model_args = dict(img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
-    model = _create_vision_transformer_hybrid(
-        'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
+    model_args = dict(
+        img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
+        class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
+    model = _create_vitamin(
+        'vitamin_xlarge_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
-    backbone = MbConvStages(cfg=VitCfg(
+    embed_cfg = VitCfg(
         embed_dim=(192, 384, 1152),
         depths=(2, 4, 1),
         stem_width=192,
@@ -505,17 +494,17 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
             norm_eps=1e-6,
         ),
         head_type='1d',
-        ),
     )
-    model_args = dict(img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
-    model = _create_vision_transformer_hybrid(
-        'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
+    model_args = dict(
+        img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
+        class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
+    model = _create_vitamin('vitamin_xlarge_336', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
-    backbone = MbConvStages(cfg=VitCfg(
+    embed_cfg = VitCfg(
         embed_dim=(192, 384, 1152),
         depths=(2, 4, 1),
         stem_width=192,
@@ -524,9 +513,9 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
             norm_eps=1e-6,
         ),
         head_type='1d',
-        ),
     )
-    model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
-    model = _create_vision_transformer_hybrid(
-        'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
+    model_args = dict(
+        img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
+        class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
+    model = _create_vitamin('vitamin_xlarge_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
\ No newline at end of file
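
Usage sketch (not part of the patch): the new `pos_embed` constructor argument added to `VisionTransformer` above. This is a minimal example assuming a timm build that contains these changes; `'none'` (or `''`) skips the learned position embedding entirely, which is what the ViTamin xlarge variants request via `pos_embed='none'`.

import torch
from timm.models.vision_transformer import VisionTransformer

# Default behaviour (pos_embed='learn'): a learned absolute position embedding is created.
vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=384, depth=4, num_heads=6)
print(vit.pos_embed.shape)  # torch.Size([1, 197, 384]) with the default class token

# pos_embed='none': no parameter is allocated and _pos_embed() becomes a pass-through.
vit_no_pe = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=384, depth=4, num_heads=6, pos_embed='none')
assert vit_no_pe.pos_embed is None
out = vit_no_pe(torch.randn(1, 3, 224, 224))  # forward still works; shape (1, num_classes)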
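
Also a sketch, not part of the patch: how the new `proj` flag on `HybridEmbed` behaves. The backbone below is a hypothetical stand-in (any `nn.Module` returning a `(B, C, H, W)` feature map works); with `proj=False` the backbone's channel count must already equal `embed_dim`, otherwise the assertion added above fires.

import torch
import torch.nn as nn
from timm.models.vision_transformer_hybrid import HybridEmbed

# Toy stand-in backbone: 3 -> 96 channels, 8x spatial reduction.
backbone = nn.Sequential(
    nn.Conv2d(3, 96, 3, stride=2, padding=1), nn.GELU(),
    nn.Conv2d(96, 96, 3, stride=2, padding=1), nn.GELU(),
    nn.Conv2d(96, 96, 3, stride=2, padding=1),
)

# proj=True (default): a conv projection maps 96 -> embed_dim.
embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=192)

# proj=False: feature dim must equal embed_dim, so nn.Identity() is used instead of a projection.
embed_noproj = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=96, proj=False)

tokens = embed_noproj(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # (1, 28*28, 96) given the 8x reduction above

# The new set_grad_checkpointing() forwards to the backbone only when it supports it;
# a plain nn.Sequential has neither hook, so this call is a safe no-op here.
embed_noproj.set_grad_checkpointing(True)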
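
Finally, a sketch of how the registered ViTamin variants are meant to be instantiated through the timm factory. This assumes the vitamin models above are registered in the installed timm; it builds a randomly initialised model (fetching the listed hf_hub_id weights would additionally require `pretrained=True` and network access).

import torch
import timm

# Registered by vitamin.py above; MbConvStages + HybridEmbed(proj=False) are wired in by _create_vitamin().
model = timm.create_model('vitamin_small', pretrained=False, num_classes=0)  # num_classes=0 -> feature output
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # (1, 384) for vitamin_small: global_pool='avg', embed_dim=384

# Gradient checkpointing now also reaches the conv stem via HybridEmbed.set_grad_checkpointing().
model.set_grad_checkpointing(True)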