mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Fixup ViTamin, add hub weight reference

commit 1b66ec7cf3
parent b2c0aeb0ec
@@ -409,6 +409,7 @@ class VisionTransformer(nn.Module):
             qk_norm: bool = False,
             init_values: Optional[float] = None,
             class_token: bool = True,
+            pos_embed: str = 'learn',
             no_embed_class: bool = False,
             reg_tokens: int = 0,
             pre_norm: bool = False,
@@ -460,6 +461,7 @@ class VisionTransformer(nn.Module):
         super().__init__()
         assert global_pool in ('', 'avg', 'token', 'map')
         assert class_token or global_pool != 'token'
+        assert pos_embed in ('', 'none', 'learn')
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
         act_layer = get_act_layer(act_layer) or nn.GELU
@@ -494,7 +496,10 @@ class VisionTransformer(nn.Module):
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
         self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
         embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
-        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
+        if not pos_embed or pos_embed == 'none':
+            self.pos_embed = None
+        else:
+            self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
         if patch_drop_rate > 0:
             self.patch_drop = PatchDropout(
@@ -556,7 +561,8 @@ class VisionTransformer(nn.Module):
     def init_weights(self, mode: str = '') -> None:
         assert mode in ('jax', 'jax_nlhb', 'moco', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
-        trunc_normal_(self.pos_embed, std=.02)
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
         if self.cls_token is not None:
             nn.init.normal_(self.cls_token, std=1e-6)
         named_apply(get_init_weights_vit(mode, head_bias), self)
@@ -583,6 +589,8 @@ class VisionTransformer(nn.Module):
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable: bool = True) -> None:
         self.grad_checkpointing = enable
+        if hasattr(self.patch_embed, 'set_grad_checkpointing'):
+            self.patch_embed.set_grad_checkpointing(enable)

     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
@@ -600,6 +608,9 @@ class VisionTransformer(nn.Module):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

     def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
+        if self.pos_embed is None:
+            return x
+
         if self.dynamic_img_size:
             B, H, W, C = x.shape
             pos_embed = resample_abs_pos_embed(
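The hunks above add a `pos_embed` constructor argument and an early exit in `_pos_embed`. A minimal sketch of the new option, assuming this commit's `timm.models.vision_transformer` (the model hyper-parameters below are arbitrary):

import torch
from timm.models.vision_transformer import VisionTransformer

# 'none' skips the learned position embedding entirely; 'learn' (the default) keeps it.
model = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3,
    class_token=False, global_pool='avg', pos_embed='none',
)
assert model.pos_embed is None            # _pos_embed() now returns the tokens unchanged
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)                          # torch.Size([1, 1000]) with the default classifier head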
@@ -1066,10 +1077,13 @@ def checkpoint_filter_fn(
         # IJEPA, vit in an 'encoder' submodule
         state_dict = state_dict['encoder']
         prefix = 'module.'
-    elif 'visual.trunk.pos_embed' in state_dict:
+    elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
         # OpenCLIP model with timm vision encoder
-        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         prefix = 'visual.trunk.'
+        if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
+            # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+            out_dict['head.weight'] = state_dict['visual.head.proj.weight']
+            out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])

     if prefix:
         # filter on & remove prefix string from keys
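The `checkpoint_filter_fn` change above detects OpenCLIP checkpoints whose vision tower is a timm trunk even when no `pos_embed` key exists, and remaps a `visual.head.proj` linear into the timm `head`. An illustrative sketch of that remap on a toy state dict (this is not the timm function itself, just the added logic in isolation):

import torch

state_dict = {
    'visual.trunk.blocks.0.norm1.weight': torch.ones(8),   # trunk detected even without a pos_embed key
    'visual.head.proj.weight': torch.randn(512, 8),         # projection living outside the trunk
}
prefix = 'visual.trunk.'
out_dict = {
    'head.weight': state_dict['visual.head.proj.weight'],
    'head.bias': torch.zeros(state_dict['visual.head.proj.weight'].shape[0]),  # proj carries no bias
}
out_dict.update({k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)})
print(sorted(out_dict))  # ['blocks.0.norm1.weight', 'head.bias', 'head.weight']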
@ -38,14 +38,15 @@ class HybridEmbed(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
backbone,
|
backbone: nn.Module,
|
||||||
img_size=224,
|
img_size: Union[int, Tuple[int, int]] = 224,
|
||||||
patch_size=1,
|
patch_size: Union[int, Tuple[int, int]] = 1,
|
||||||
feature_size=None,
|
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||||
feature_ratio=None,
|
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
|
||||||
in_chans=3,
|
in_chans: int = 3,
|
||||||
embed_dim=768,
|
embed_dim: int = 768,
|
||||||
bias=True,
|
bias: bool = True,
|
||||||
|
proj: bool = True,
|
||||||
flatten: bool = True,
|
flatten: bool = True,
|
||||||
output_fmt: Optional[str] = None,
|
output_fmt: Optional[str] = None,
|
||||||
strict_img_size: bool = True,
|
strict_img_size: bool = True,
|
||||||
@@ -95,7 +96,18 @@ class HybridEmbed(nn.Module):
         self.strict_img_size = strict_img_size
         self.dynamic_img_pad = dynamic_img_pad

-        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+        if proj:
+            self.proj = nn.Conv2d(
+                feature_dim,
+                embed_dim,
+                kernel_size=patch_size,
+                stride=patch_size,
+                bias=bias,
+            )
+        else:
+            assert feature_dim == embed_dim,\
+                f'The feature dim ({feature_dim} must match embed dim ({embed_dim}) when projection disabled.'
+            self.proj = nn.Identity()

     def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
         total_reduction = (
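With the new `proj` flag, `HybridEmbed` can pass the backbone feature map straight through instead of projecting it, provided the backbone channel count already equals `embed_dim` (this is what the `_create_vitamin` helper later in this diff relies on). A hedged sketch with a toy convolutional stand-in for the backbone:

import torch
import torch.nn as nn
from timm.models.vision_transformer_hybrid import HybridEmbed

backbone = nn.Conv2d(3, 768, kernel_size=16, stride=16)   # toy CNN producing a 768-channel feature map
embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=768, proj=False)
tokens = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape)   # torch.Size([1, 196, 768]); self.proj is nn.Identity here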
@@ -116,6 +128,13 @@ class HybridEmbed(nn.Module):
         else:
             return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]

+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True):
+        if hasattr(self.backbone, 'set_grad_checkpointing'):
+            self.backbone.set_grad_checkpointing(enable=enable)
+        elif hasattr(self.backbone, 'grad_checkpointing'):
+            self.backbone.grad_checkpointing = enable
+
     def forward(self, x):
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
@@ -157,6 +176,13 @@ class HybridEmbedWithSize(nn.Module):
             bias=bias,
         )

+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True):
+        if hasattr(self.backbone, 'set_grad_checkpointing'):
+            self.backbone.set_grad_checkpointing(enable=enable)
+        elif hasattr(self.backbone, 'grad_checkpointing'):
+            self.backbone.grad_checkpointing = enable
+
     def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
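Together with the `VisionTransformer.set_grad_checkpointing` hook earlier in this diff, the two methods above let gradient checkpointing propagate from the ViT wrapper into the conv backbone held by `patch_embed`. A hedged sketch, assuming the `vitamin_small` registration added later in this commit:

import timm

model = timm.create_model('vitamin_small', pretrained=False)
model.set_grad_checkpointing(True)                      # forwarded to patch_embed, then to its backbone
print(model.patch_embed.backbone.grad_checkpointing)    # True (MbConvStages exposes the flag)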
@@ -19,29 +19,22 @@
 https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
 """

+import math
+from dataclasses import dataclass, field
 from functools import partial
-from typing import List, Tuple
-from dataclasses import dataclass, replace, field
-from typing import Callable, Optional, Union, Tuple, List, Sequence
-import math, time
-from torch.jit import Final
+from typing import Optional, Union, Tuple

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-import timm

-from torch.utils.checkpoint import checkpoint
-from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
-from timm.layers import to_2tuple
-from timm.layers import DropPath
-from timm.layers.norm_act import _create_act
-from timm.models._manipulate import named_apply, checkpoint_seq
-from timm.models._builder import build_model_with_cfg
-from timm.models._registry import register_model
-from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn
-from timm.models.vision_transformer_hybrid import HybridEmbed
+from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \
+    make_divisible, DropPath
+from ._builder import build_model_with_cfg
+from ._manipulate import named_apply, checkpoint_seq
+from ._registry import register_model, generate_default_cfgs
+from .vision_transformer import VisionTransformer, checkpoint_filter_fn
+from .vision_transformer_hybrid import HybridEmbed


 @dataclass
@@ -90,24 +83,19 @@ class Stem(nn.Module):
             bias: bool = True,
     ):
         super().__init__()
-        self.grad_checkpointing=False
         norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
         self.out_chs = out_chs

         self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias)
         self.norm1 = norm_act_layer(out_chs)
         self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias)

         named_apply(_init_conv, self)

     def forward(self, x):
-        if self.grad_checkpointing:
-            x = checkpoint(self.conv1, x)
-            x = self.norm1(x)
-            x = checkpoint(self.conv2, x)
-        else:
-            x = self.conv1(x)
-            x = self.norm1(x)
-            x = self.conv2(x)
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.conv2(x)
         return x

@@ -145,8 +133,9 @@ class StridedConv(nn.Module):
             embed_dim=768
     ):
         super().__init__()
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
         norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
         self.norm = norm_layer(in_chans) # affine over C

     def forward(self, x):
@@ -185,10 +174,10 @@ class MbConvLNBlock(nn.Module):
         self.pre_norm = prenorm_act_layer(in_chs, apply_act=False)
         self.down = nn.Identity()
         self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True)
-        self.act1 = _create_act(act_layer, inplace=True)
-        self.act2 = _create_act(act_layer, inplace=True)
-        self.conv2_kxk = create_conv2d(mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
+        self.act1 = create_act_layer(act_layer, inplace=True)
+        self.conv2_kxk = create_conv2d(
+            mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
+        self.act2 = create_act_layer(act_layer, inplace=True)
         self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

@@ -228,58 +217,57 @@ class MbConvStages(nn.Module):
     ):
         super().__init__()
         self.grad_checkpointing = False

         self.stem = Stem(
             in_chs=in_chans,
             out_chs=cfg.stem_width,
         )

         stages = []
         self.num_stages = len(cfg.embed_dim)
         for s, dim in enumerate(cfg.embed_dim[:2]): # stage
-            blocks = []
             stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width
-            for d in range(cfg.depths[s]):
-                blocks += [MbConvLNBlock(
+            blocks = [
+                MbConvLNBlock(
                     in_chs = stage_in_chs if d==0 else dim,
                     out_chs = dim,
                     stride = 2 if d == 0 else 1,
-                    # cfg = cfg.conv_cfg,
-                )]
-            blocks = nn.Sequential(*blocks)
-            stages += [blocks]
+                )
+                for d in range(cfg.depths[s])
+            ]
+            stages += [nn.Sequential(*blocks)]
+        self.stages = nn.Sequential(*stages)

-        self.stages = nn.ModuleList(stages)
         self.pool = StridedConv(
             stride=2,
             in_chans=cfg.embed_dim[1],
             embed_dim=cfg.embed_dim[2]
         )

     def forward(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
-            for stage in self.stages:
-                x = checkpoint_seq(stage, x)
-            x = checkpoint(self.pool, x)
+            x = checkpoint_seq(self.stages, x)
         else:
-            for stage in self.stages:
-                x = stage(x)
-            x = self.pool(x)
+            x = self.stages(x)
+        x = self.pool(x)
         return x


 class GeGluMlp(nn.Module):
     def __init__(
             self,
             in_features,
             hidden_features,
-            act_layer = None,
+            act_layer = 'gelu',
             drop = 0.0,
     ):
         super().__init__()
         norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)

         self.norm = norm_layer(in_features)
-        self.act = nn.GELU()
         self.w0 = nn.Linear(in_features, hidden_features)
+        self.act = create_act_layer(act_layer)
         self.w1 = nn.Linear(in_features, hidden_features)
         self.w2 = nn.Linear(hidden_features, in_features)

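The refactor above turns the per-stage block lists into `nn.Sequential` modules so that `checkpoint_seq` can wrap them directly. A hedged sketch of building the stem/stages stand-alone, assuming the names exported by `timm.models.vitamin` in this commit:

import torch
from timm.models.vitamin import MbConvStages, VitCfg, VitConvCfg

cfg = VitCfg(
    embed_dim=(64, 128, 384),
    depths=(2, 4, 1),
    stem_width=64,
    conv_cfg=VitConvCfg(norm_layer='layernorm2d', norm_eps=1e-6),
    head_type='1d',
)
stages = MbConvStages(cfg=cfg)
feats = stages(torch.randn(1, 3, 224, 224))
print(feats.shape)   # torch.Size([1, 384, 14, 14]): stem /2, two stages /2 each, strided pool /2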
@@ -290,118 +278,116 @@ class GeGluMlp(nn.Module):
         return x


-class HybridEmbed(nn.Module):
-    """ CNN Feature Map Embedding
-    Extract feature map from CNN, flatten, project to embedding dim.
-    """
-    def __init__(
-            self,
-            backbone,
-            img_size=224,
-            patch_size=1,
-            feature_size=None,
-            in_chans=3,
-            embed_dim=1024,
-            bias=True,
-            dynamic_img_pad=False,
-    ):
-        super().__init__()
-        assert isinstance(backbone, nn.Module)
-        img_size = to_2tuple(img_size)
-        patch_size = to_2tuple(patch_size)
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.backbone = backbone
-        with torch.no_grad():
-            training = backbone.training
-            if training:
-                backbone.eval()
-            o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
-            if isinstance(o, (list, tuple)):
-                o = o[-1] # last feature if backbone outputs list/tuple of features
-            feature_size = o.shape[-2:]
-            feature_dim = o.shape[1]
-            backbone.train(training)
-
-        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
-        self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
-        self.num_patches = self.grid_size[0] * self.grid_size[1]
-        self.proj = nn.Identity()
-
-    def forward(self, x):
-        x = self.backbone(x)
-        if isinstance(x, (list, tuple)):
-            x = x[-1] # last feature if backbone outputs list/tuple of features
-        x = self.proj(x)
-        x = x.flatten(2).transpose(1, 2)
-        return x
-
-
-def _create_vision_transformer(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
-    if 'flexi' in variant:
-        # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
-        # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
-        _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
-    else:
-        _filter_fn = checkpoint_filter_fn
-
+def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs):
+    assert embed_cfg is not None
+    backbone = MbConvStages(cfg=embed_cfg)
+    kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False)
+    kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
     return build_model_with_cfg(
         VisionTransformer,
         variant,
         pretrained,
-        pretrained_filter_fn=_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs,
     )


-def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
-    embed_layer = partial(HybridEmbed, backbone=backbone)
-    kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
-    return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
+        'first_conv': 'patch_embed.backbone.stem.conv1',
+        'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'vitamin_small.datacomp1b_clip_ltt': _cfg(
+        hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=384),
+    'vitamin_small.datacomp1b_clip': _cfg(
+        hf_hub_id='jienengchen/ViTamin-S', num_classes=384),
+    'vitamin_base.datacomp1b_clip_ltt': _cfg(
+        hf_hub_id='jienengchen/ViTamin-B-LTT', num_classes=768),
+    'vitamin_base.datacomp1b_clip': _cfg(
+        hf_hub_id='jienengchen/ViTamin-B', num_classes=768),
+    'vitamin_large.datacomp1b_clip': _cfg(
+        hf_hub_id='jienengchen/ViTamin-L-224px', num_classes=1024),
+    'vitamin_large_256.datacomp1b_clip_l2': _cfg(
+        hf_hub_id='jienengchen/ViTamin-L2-256px', num_classes=1024,
+        input_size=(3, 256, 256), crop_pct=1.0),
+    'vitamin_large_256.datacomp1b_clip': _cfg(
+        hf_hub_id='jienengchen/ViTamin-L-256px', num_classes=1024,
+        input_size=(3, 256, 256), crop_pct=1.0),
+    'vitamin_large_336.datacomp1b_clip_l2': _cfg(
+        hf_hub_id='jienengchen/ViTamin-L2-336px', num_classes=1024,
+        input_size=(3, 336, 336), crop_pct=1.0),
+    'vitamin_large_336.datacomp1b_clip': _cfg(
+        hf_hub_id='jienengchen/ViTamin-L-336px', num_classes=1024,
+        input_size=(3, 336, 336), crop_pct=1.0),
+    'vitamin_large_384.datacomp1b_clip_l2': _cfg(
+        hf_hub_id='jienengchen/ViTamin-L2-384px', num_classes=1024,
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vitamin_large_384.datacomp1b_clip': _cfg(
+        hf_hub_id='jienengchen/ViTamin-L-384px', num_classes=1024,
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vitamin_xlarge_256.datacomp1b_clip': _cfg(
+        hf_hub_id='jienengchen/ViTamin-XL-256px', num_classes=1152,
+        input_size=(3, 256, 256), crop_pct=1.0),
+    'vitamin_xlarge_336.datacomp1b_clip': _cfg(
+        hf_hub_id='jienengchen/ViTamin-XL-336px', num_classes=1152,
+        input_size=(3, 336, 336), crop_pct=1.0),
+    'vitamin_xlarge_384.datacomp1b_clip': _cfg(
+        hf_hub_id='jienengchen/ViTamin-XL-384px', num_classes=1152,
+        input_size=(3, 384, 384), crop_pct=1.0),
+})


 @register_model
 def vitamin_small(pretrained=False, **kwargs) -> VisionTransformer:
-    stage_1_2 = MbConvStages(cfg=VitCfg(
+    embed_cfg = VitCfg(
         embed_dim=(64, 128, 384),
         depths=(2, 4, 1),
         stem_width=64,
         conv_cfg = VitConvCfg(
             norm_layer='layernorm2d',
             norm_eps=1e-6,
-        ),
-        head_type='1d',
         ),
+        head_type='1d',
     )
-    stage3_args = dict(embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
-    model = _create_vision_transformer_hybrid('vitamin_small', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
+    model_args = dict(
+        embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2.,
+        class_token=False, global_pool='avg', embed_cfg=embed_cfg
+    )
+    model = _create_vitamin('vitamin_small', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def vitamin_base(pretrained=False, **kwargs) -> VisionTransformer:
-    stage_1_2 = MbConvStages(cfg=VitCfg(
+    embed_cfg = VitCfg(
         embed_dim=(128, 256, 768),
         depths=(2, 4, 1),
         stem_width=128,
         conv_cfg = VitConvCfg(
             norm_layer='layernorm2d',
             norm_eps=1e-6,
-        ),
-        head_type='1d',
         ),
+        head_type='1d',
     )
-    stage3_args = dict(embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
-    model = _create_vision_transformer_hybrid('vitamin_base', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
+    model_args = dict(
+        embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2.,
+        class_token=False, global_pool='avg', embed_cfg=embed_cfg)
+    model = _create_vitamin('vitamin_base', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
-    stage_1_2 = MbConvStages(cfg=VitCfg(
+    embed_cfg = VitCfg(
         embed_dim=(160, 320, 1024),
         depths=(2, 4, 1),
         stem_width=160,
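The `default_cfgs` block above is the "hub weight reference" from the commit title: each `.datacomp1b_clip*` tag points at a `jienengchen/ViTamin-*` repository on the Hugging Face Hub, with `num_classes` set to the CLIP projection width. A hedged usage sketch (weights left unloaded here; whether a given hub checkpoint resolves depends on the repo contents):

import torch
import timm

print(timm.list_models('vitamin*', pretrained=True))   # variants with hub-tagged weights
model = timm.create_model('vitamin_small.datacomp1b_clip', pretrained=False, num_classes=0)
feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)   # torch.Size([1, 384]) pooled features (embed_dim) at the 224px default input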
@@ -410,17 +396,18 @@ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
             norm_eps=1e-6,
         ),
         head_type='1d',
-        ),
     )
-    stage3_args = dict(embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
-    model = _create_vision_transformer_hybrid(
-        'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
+    model_args = dict(
+        embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
+        class_token=False, global_pool='avg', embed_cfg=embed_cfg,
+    )
+    model = _create_vitamin('vitamin_large', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
-    backbone = MbConvStages(cfg=VitCfg(
+    embed_cfg = VitCfg(
         embed_dim=(160, 320, 1024),
         depths=(2, 4, 1),
         stem_width=160,
@@ -429,17 +416,17 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
             norm_eps=1e-6,
         ),
         head_type='1d',
-        ),
     )
-    model_args = dict(img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
-    model = _create_vision_transformer_hybrid(
-        'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
+    model_args = dict(
+        img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
+        class_token=False, global_pool='avg', embed_cfg=embed_cfg)
+    model = _create_vitamin('vitamin_large_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
-    backbone = MbConvStages(cfg=VitCfg(
+    embed_cfg = VitCfg(
         embed_dim=(160, 320, 1024),
         depths=(2, 4, 1),
         stem_width=160,
@@ -448,17 +435,18 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
             norm_eps=1e-6,
         ),
         head_type='1d',
-        ),
     )
-    model_args = dict(img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
-    model = _create_vision_transformer_hybrid(
-        'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
+    model_args = dict(
+        img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
+        class_token=False, global_pool='avg', embed_cfg=embed_cfg
+    )
+    model = _create_vitamin('vitamin_large_336', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
-    backbone = MbConvStages(cfg=VitCfg(
+    embed_cfg = VitCfg(
         embed_dim=(160, 320, 1024),
         depths=(2, 4, 1),
         stem_width=160,
@@ -467,17 +455,17 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
             norm_eps=1e-6,
         ),
         head_type='1d',
-        ),
     )
-    model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
-    model = _create_vision_transformer_hybrid(
-        'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
+    model_args = dict(
+        img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
+        class_token=False, global_pool='avg', embed_cfg=embed_cfg)
+    model = _create_vitamin('vitamin_large_384', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
-    backbone = MbConvStages(cfg=VitCfg(
+    embed_cfg=VitCfg(
         embed_dim=(192, 384, 1152),
         depths=(2, 4, 1),
         stem_width=192,
@@ -486,17 +474,18 @@ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
             norm_eps=1e-6,
         ),
         head_type='1d',
-        ),
     )
-    model_args = dict(img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
-    model = _create_vision_transformer_hybrid(
-        'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
+    model_args = dict(
+        img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
+        class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
+    model = _create_vitamin(
+        'vitamin_xlarge_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
-    backbone = MbConvStages(cfg=VitCfg(
+    embed_cfg = VitCfg(
         embed_dim=(192, 384, 1152),
         depths=(2, 4, 1),
         stem_width=192,
@@ -505,17 +494,17 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
             norm_eps=1e-6,
         ),
         head_type='1d',
-        ),
     )
-    model_args = dict(img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
-    model = _create_vision_transformer_hybrid(
-        'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
+    model_args = dict(
+        img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
+        class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
+    model = _create_vitamin('vitamin_xlarge_336', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
-    backbone = MbConvStages(cfg=VitCfg(
+    embed_cfg = VitCfg(
         embed_dim=(192, 384, 1152),
         depths=(2, 4, 1),
         stem_width=192,
@@ -524,9 +513,9 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
             norm_eps=1e-6,
         ),
         head_type='1d',
-        ),
     )
-    model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
-    model = _create_vision_transformer_hybrid(
-        'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
+    model_args = dict(
+        img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
+        class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
+    model = _create_vitamin('vitamin_xlarge_384', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
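A closing hedged sketch tying the xlarge variants back to the `pos_embed` change at the top of this diff: they pass `pos_embed='none'`, so the resulting `VisionTransformer` carries no learned position embedding (note these are large models to instantiate):

import timm

model = timm.create_model('vitamin_xlarge_256', pretrained=False)
print(model.pos_embed)           # None, via the new constructor option
print(type(model.patch_embed))   # HybridEmbed wrapping the MbConvStages backbone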