Fixup ViTamin, add hub weight reference
This commit is contained in:
parent b2c0aeb0ec
commit 1b66ec7cf3
@@ -409,6 +409,7 @@ class VisionTransformer(nn.Module):
qk_norm: bool = False,
init_values: Optional[float] = None,
class_token: bool = True,
pos_embed: str = 'learn',
no_embed_class: bool = False,
reg_tokens: int = 0,
pre_norm: bool = False,
@@ -460,6 +461,7 @@ class VisionTransformer(nn.Module):
super().__init__()
assert global_pool in ('', 'avg', 'token', 'map')
assert class_token or global_pool != 'token'
assert pos_embed in ('', 'none', 'learn')
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
act_layer = get_act_layer(act_layer) or nn.GELU
@@ -494,7 +496,10 @@ class VisionTransformer(nn.Module):
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
if not pos_embed or pos_embed == 'none':
self.pos_embed = None
else:
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
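
The new pos_embed argument makes the learned position embedding optional. A minimal usage sketch with toy sizes (not a real config), assuming a timm build that includes this change:

import torch
from timm.models.vision_transformer import VisionTransformer

# pos_embed='none' leaves self.pos_embed as None, so the position-embedding step is skipped.
model = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3,
    pos_embed='none', class_token=False, global_pool='avg',
)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000]) with the default classifier head
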
if patch_drop_rate > 0:
self.patch_drop = PatchDropout(
@@ -556,7 +561,8 @@ class VisionTransformer(nn.Module):
def init_weights(self, mode: str = '') -> None:
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02)
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(get_init_weights_vit(mode, head_bias), self)
@@ -583,6 +589,8 @@ class VisionTransformer(nn.Module):
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
self.grad_checkpointing = enable
if hasattr(self.patch_embed, 'set_grad_checkpointing'):
self.patch_embed.set_grad_checkpointing(enable)
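
Combined with the matching HybridEmbed change further down, a top-level call now also reaches the CNN stages of a hybrid model. A hedged sketch, assuming a timm build that registers the ViTamin models from this commit:

import timm

# set_grad_checkpointing() on the ViT forwards to patch_embed; for ViTamin the hybrid
# patch embed then flips grad_checkpointing on its MbConvStages backbone.
model = timm.create_model('vitamin_small', pretrained=False)
model.set_grad_checkpointing(True)
print(model.patch_embed.backbone.grad_checkpointing)  # True
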

@torch.jit.ignore
def get_classifier(self) -> nn.Module:
@@ -600,6 +608,9 @@ class VisionTransformer(nn.Module):
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
if self.pos_embed is None:
return x

if self.dynamic_img_size:
B, H, W, C = x.shape
pos_embed = resample_abs_pos_embed(
@@ -1066,10 +1077,13 @@ def checkpoint_filter_fn(
# IJEPA, vit in an 'encoder' submodule
state_dict = state_dict['encoder']
prefix = 'module.'
elif 'visual.trunk.pos_embed' in state_dict:
elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
# OpenCLIP model with timm vision encoder
# FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
prefix = 'visual.trunk.'
if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
# remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
out_dict['head.weight'] = state_dict['visual.head.proj.weight']
out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])

if prefix:
# filter on & remove prefix string from keys
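
The added branch recognises OpenCLIP checkpoints that wrap a timm trunk and lifts an external projection into the classifier. A standalone sketch of that filtering with a hypothetical state_dict (not the repo code):

import torch

state_dict = {
    'visual.trunk.pos_embed': torch.zeros(1, 197, 768),
    'visual.trunk.blocks.0.norm1.weight': torch.ones(768),
    'visual.head.proj.weight': torch.zeros(512, 768),  # projection living outside the trunk
}
prefix = 'visual.trunk.'
# filter on & remove prefix string from keys, as in the hunk above
out_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
# remap the external projection to the model head, with a zero bias
out_dict['head.weight'] = state_dict['visual.head.proj.weight']
out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
print(sorted(out_dict))  # ['blocks.0.norm1.weight', 'head.bias', 'head.weight', 'pos_embed']
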

@@ -38,14 +38,15 @@ class HybridEmbed(nn.Module):

def __init__(
self,
backbone,
img_size=224,
patch_size=1,
feature_size=None,
feature_ratio=None,
in_chans=3,
embed_dim=768,
bias=True,
backbone: nn.Module,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
in_chans: int = 3,
embed_dim: int = 768,
bias: bool = True,
proj: bool = True,
flatten: bool = True,
output_fmt: Optional[str] = None,
strict_img_size: bool = True,
@@ -95,7 +96,18 @@ class HybridEmbed(nn.Module):
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad

self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
if proj:
self.proj = nn.Conv2d(
feature_dim,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
)
else:
assert feature_dim == embed_dim,\
f'The feature dim ({feature_dim} must match embed dim ({embed_dim}) when projection disabled.'
self.proj = nn.Identity()

def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
total_reduction = (
@@ -116,6 +128,13 @@ class HybridEmbed(nn.Module):
else:
return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]

@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable

def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
@@ -157,6 +176,13 @@ class HybridEmbedWithSize(nn.Module):
bias=bias,
)

@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable

def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
x = self.backbone(x)
if isinstance(x, (list, tuple)):

@@ -19,29 +19,22 @@ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
"""

import math
from dataclasses import dataclass, field
from functools import partial
from typing import List, Tuple
from dataclasses import dataclass, replace, field
from typing import Callable, Optional, Union, Tuple, List, Sequence
import math, time
from torch.jit import Final
from typing import Optional, Union, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

from torch.utils.checkpoint import checkpoint
from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_

from timm.layers import to_2tuple
from timm.layers import DropPath
from timm.layers.norm_act import _create_act

from timm.models._manipulate import named_apply, checkpoint_seq
from timm.models._builder import build_model_with_cfg
from timm.models._registry import register_model
from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn
from timm.models.vision_transformer_hybrid import HybridEmbed
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \
make_divisible, DropPath
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import VisionTransformer, checkpoint_filter_fn
from .vision_transformer_hybrid import HybridEmbed

@dataclass
@@ -90,24 +83,19 @@ class Stem(nn.Module):
bias: bool = True,
):
super().__init__()
self.grad_checkpointing=False
norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
self.out_chs = out_chs

self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias)
self.norm1 = norm_act_layer(out_chs)
self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias)

named_apply(_init_conv, self)

def forward(self, x):
if self.grad_checkpointing:
x = checkpoint(self.conv1, x)
x = self.norm1(x)
x = checkpoint(self.conv2, x)
else:
x = self.conv1(x)
x = self.norm1(x)
x = self.conv2(x)

x = self.conv1(x)
x = self.norm1(x)
x = self.conv2(x)
return x

@@ -145,8 +133,9 @@ class StridedConv(nn.Module):
embed_dim=768
):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
self.norm = norm_layer(in_chans) # affine over C

def forward(self, x):
@@ -185,10 +174,10 @@ class MbConvLNBlock(nn.Module):
self.pre_norm = prenorm_act_layer(in_chs, apply_act=False)
self.down = nn.Identity()
self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True)
self.act1 = _create_act(act_layer, inplace=True)
self.act2 = _create_act(act_layer, inplace=True)

self.conv2_kxk = create_conv2d(mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
self.act1 = create_act_layer(act_layer, inplace=True)
self.conv2_kxk = create_conv2d(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
self.act2 = create_act_layer(act_layer, inplace=True)
self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

@@ -228,58 +217,57 @@ class MbConvStages(nn.Module):
):
super().__init__()
self.grad_checkpointing = False

self.stem = Stem(
in_chs=in_chans,
out_chs=cfg.stem_width,
)

stages = []
self.num_stages = len(cfg.embed_dim)
for s, dim in enumerate(cfg.embed_dim[:2]): # stage
blocks = []
stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width
for d in range(cfg.depths[s]):
blocks += [MbConvLNBlock(
in_chs = stage_in_chs if d==0 else dim,
out_chs = dim,
stride = 2 if d == 0 else 1,
# cfg = cfg.conv_cfg,
)]
blocks = nn.Sequential(*blocks)
stages += [blocks]
blocks = [
MbConvLNBlock(
in_chs = stage_in_chs if d==0 else dim,
out_chs = dim,
stride = 2 if d == 0 else 1,
)
for d in range(cfg.depths[s])
]
stages += [nn.Sequential(*blocks)]
self.stages = nn.Sequential(*stages)

self.stages = nn.ModuleList(stages)
self.pool = StridedConv(
stride=2,
in_chans=cfg.embed_dim[1],
embed_dim=cfg.embed_dim[2]
)
stride=2,
in_chans=cfg.embed_dim[1],
embed_dim=cfg.embed_dim[2]
)

def forward(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
for stage in self.stages:
x = checkpoint_seq(stage, x)
x = checkpoint(self.pool, x)
x = checkpoint_seq(self.stages, x)
else:
for stage in self.stages:
x = stage(x)
x = self.pool(x)

x = self.stages(x)
x = self.pool(x)
return x

class GeGluMlp(nn.Module):
def __init__(
self,
in_features,
hidden_features,
act_layer = None,
act_layer = 'gelu',
drop = 0.0,
):
super().__init__()
norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)

self.norm = norm_layer(in_features)
self.act = nn.GELU()
self.w0 = nn.Linear(in_features, hidden_features)
self.act = create_act_layer(act_layer)
self.w1 = nn.Linear(in_features, hidden_features)
self.w2 = nn.Linear(hidden_features, in_features)
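
For reference, a self-contained sketch of how the w0/w1/w2 layers of a gated MLP like this are typically combined (hypothetical standalone module, not the class above):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyGeGlu(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.w0 = nn.Linear(dim, hidden)   # gate branch, passed through the activation
        self.w1 = nn.Linear(dim, hidden)   # value branch
        self.w2 = nn.Linear(hidden, dim)   # projection back to the input width

    def forward(self, x):
        x = self.norm(x)
        x = F.gelu(self.w0(x)) * self.w1(x)
        return self.w2(x)

print(TinyGeGlu(8, 16)(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])
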

@@ -290,118 +278,116 @@ class GeGluMlp(nn.Module):
return x

class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(
self,
backbone,
img_size=224,
patch_size=1,
feature_size=None,
in_chans=3,
embed_dim=1024,
bias=True,
dynamic_img_pad=False,
):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.backbone = backbone
with torch.no_grad():
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)

assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Identity()

def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x

def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')

if 'flexi' in variant:
# FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
# interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
_filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
else:
_filter_fn = checkpoint_filter_fn
def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs):
assert embed_cfg is not None
backbone = MbConvStages(cfg=embed_cfg)
kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set

return build_model_with_cfg(
VisionTransformer,
variant,
pretrained,
pretrained_filter_fn=_filter_fn,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs,
)
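
What _create_vitamin relies on with proj=False: the backbone must already emit embed_dim channels so HybridEmbed can tokenize with an identity projection. A toy sketch with a hypothetical backbone and sizes, assuming a timm build that includes the proj argument from this commit:

import torch
import torch.nn as nn
from timm.models.vision_transformer_hybrid import HybridEmbed

toy_backbone = nn.Conv2d(3, 192, kernel_size=16, stride=16)  # 224x224 -> 14x14 map, 192 channels
embed = HybridEmbed(toy_backbone, img_size=224, patch_size=1, embed_dim=192, proj=False)
tokens = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # torch.Size([1, 196, 192]); channels already match embed_dim, so proj is Identity
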

def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
embed_layer = partial(HybridEmbed, backbone=backbone)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
'first_conv': 'patch_embed.backbone.stem.conv1',
'classifier': 'head',
**kwargs
}

default_cfgs = generate_default_cfgs({
'vitamin_small.datacomp1b_clip_ltt': _cfg(
hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=384),
'vitamin_small.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-S', num_classes=384),
'vitamin_base.datacomp1b_clip_ltt': _cfg(
hf_hub_id='jienengchen/ViTamin-B-LTT', num_classes=768),
'vitamin_base.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-B', num_classes=768),
'vitamin_large.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L-224px', num_classes=1024),
'vitamin_large_256.datacomp1b_clip_l2': _cfg(
hf_hub_id='jienengchen/ViTamin-L2-256px', num_classes=1024,
input_size=(3, 256, 256), crop_pct=1.0),
'vitamin_large_256.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L-256px', num_classes=1024,
input_size=(3, 256, 256), crop_pct=1.0),
'vitamin_large_336.datacomp1b_clip_l2': _cfg(
hf_hub_id='jienengchen/ViTamin-L2-336px', num_classes=1024,
input_size=(3, 336, 336), crop_pct=1.0),
'vitamin_large_336.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L-336px', num_classes=1024,
input_size=(3, 336, 336), crop_pct=1.0),
'vitamin_large_384.datacomp1b_clip_l2': _cfg(
hf_hub_id='jienengchen/ViTamin-L2-384px', num_classes=1024,
input_size=(3, 384, 384), crop_pct=1.0),
'vitamin_large_384.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L-384px', num_classes=1024,
input_size=(3, 384, 384), crop_pct=1.0),
'vitamin_xlarge_256.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-XL-256px', num_classes=1152,
input_size=(3, 256, 256), crop_pct=1.0),
'vitamin_xlarge_336.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-XL-336px', num_classes=1152,
input_size=(3, 336, 336), crop_pct=1.0),
'vitamin_xlarge_384.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-XL-384px', num_classes=1152,
input_size=(3, 384, 384), crop_pct=1.0),
})
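
With these cfgs registered, the hub weight references resolve through the normal timm pretrained path. A hedged usage sketch (needs network access and the huggingface_hub package; the tag is one of the cfg keys above):

import timm
import torch

model = timm.create_model('vitamin_small.datacomp1b_clip_ltt', pretrained=True).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 384]); num_classes=384 here is the CLIP projection width
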

@register_model
def vitamin_small(pretrained=False, **kwargs) -> VisionTransformer:
stage_1_2 = MbConvStages(cfg=VitCfg(
embed_dim=(64, 128, 384),
depths=(2, 4, 1),
stem_width=64,
conv_cfg = VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
embed_cfg = VitCfg(
embed_dim=(64, 128, 384),
depths=(2, 4, 1),
stem_width=64,
conv_cfg = VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
stage3_args = dict(embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid('vitamin_small', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
model_args = dict(
embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg
)
model = _create_vitamin('vitamin_small', pretrained=pretrained, **dict(model_args, **kwargs))
return model

@register_model
def vitamin_base(pretrained=False, **kwargs) -> VisionTransformer:
stage_1_2 = MbConvStages(cfg=VitCfg(
embed_dim=(128, 256, 768),
depths=(2, 4, 1),
stem_width=128,
conv_cfg = VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
embed_cfg = VitCfg(
embed_dim=(128, 256, 768),
depths=(2, 4, 1),
stem_width=128,
conv_cfg = VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
stage3_args = dict(embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid('vitamin_base', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
model_args = dict(
embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_base', pretrained=pretrained, **dict(model_args, **kwargs))
return model

@register_model
def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
stage_1_2 = MbConvStages(cfg=VitCfg(
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
@@ -410,17 +396,18 @@ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
stage3_args = dict(embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
model_args = dict(
embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg,
)
model = _create_vitamin('vitamin_large', pretrained=pretrained, **dict(model_args, **kwargs))
return model

@register_model
def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
@@ -429,17 +416,17 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
model_args = dict(img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
model_args = dict(
img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_large_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model

@register_model
def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
@@ -448,17 +435,18 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
model_args = dict(img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
model_args = dict(
img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg
)
model = _create_vitamin('vitamin_large_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model

@register_model
def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
@@ -467,17 +455,17 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
model_args = dict(
img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_large_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model

@register_model
def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
embed_cfg=VitCfg(
embed_dim=(192, 384, 1152),
depths=(2, 4, 1),
stem_width=192,
@@ -486,17 +474,18 @@ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
model_args = dict(img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
model_args = dict(
img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
model = _create_vitamin(
'vitamin_xlarge_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model

@register_model
def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
embed_cfg = VitCfg(
embed_dim=(192, 384, 1152),
depths=(2, 4, 1),
stem_width=192,
@@ -505,17 +494,17 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
model_args = dict(img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
model_args = dict(
img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_xlarge_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model

@register_model
def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
embed_cfg = VitCfg(
embed_dim=(192, 384, 1152),
depths=(2, 4, 1),
stem_width=192,
@@ -524,9 +513,9 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
norm_eps=1e-6,
),
head_type='1d',
),
)
model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
model = _create_vision_transformer_hybrid(
'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
model_args = dict(
img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_xlarge_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model