Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00
Rename models and update checkpoints to vit-only naming
parent 3af564f2fb
commit b82f9e869c
@@ -14,10 +14,27 @@ from torch.amp import autocast
 from torch.utils.checkpoint import checkpoint
 
-### Import timm layers
-from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
-    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType, LayerScale
-#from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
+from timm.layers import (
+    PatchEmbed,
+    Mlp,
+    DropPath,
+    AttentionPoolLatent,
+    RmsNorm,
+    PatchDropout,
+    SwiGLUPacked,
+    SwiGLU,
+    trunc_normal_,
+    lecun_normal_,
+    resample_patch_embed,
+    resample_abs_pos_embed,
+    use_fused_attn,
+    get_act_layer,
+    get_norm_layer,
+    LayerType,
+    LayerScale,
+)
+
+# from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 
 from ._builder import build_model_with_cfg
@@ -28,20 +45,22 @@ from ._registry import generate_default_cfgs, register_model, register_model_dep
 __all__ = ['PE']
 
 
-####### PE's Rope ########
 
+######## PE's Rope ########
 def exists(val):
     return val is not None
 
 
 def default(val, d):
     return val if exists(val) else d
 
 
 def rotate_half(x):
     x = rearrange(x, "... (d r) -> ... d r", r=2)
     x1, x2 = x.unbind(dim=-1)
     x = torch.stack((-x2, x1), dim=-1)
     return rearrange(x, "... d r -> ... (d r)")
 
 
 @autocast("cuda", enabled=False)
 def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
     dtype = t.dtype
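(Aside, not part of the diff: a quick standalone check of what the rotate_half helper above computes. The snippet just repeats the function body on a toy tensor; torch and einops are the only assumptions.)

import torch
from einops import rearrange

def rotate_half(x):
    # pair adjacent channels, then map (x1, x2) -> (-x2, x1) within each pair
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")

print(rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0])))
# -> [-2., 1., -4., 3.]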
@@ -73,9 +92,7 @@ class RotaryEmbedding(Module):
         self,
         dim,
         custom_freqs: Optional[Tensor] = None,
-        freqs_for: Union[
-            Literal["lang"], Literal["pixel"], Literal["constant"]
-        ] = "lang",
+        freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang",
         theta=10000,
         max_freq=10,
         num_freqs=1,
@@ -99,9 +116,7 @@ class RotaryEmbedding(Module):
         if exists(custom_freqs):
             freqs = custom_freqs
         elif freqs_for == "lang":
-            freqs = 1.0 / (
-                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
-            )
+            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
         elif freqs_for == "pixel":
             freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
         elif freqs_for == "constant":
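(Aside, not part of the diff: the "lang" branch above is the usual RoPE inverse-frequency schedule. A small numeric check with hypothetical dim=8, theta=10000:)

import torch

dim, theta = 8, 10000
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
print(freqs)
# approximately [1.0, 0.1, 0.01, 0.001]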
@@ -154,9 +169,7 @@ class RotaryEmbedding(Module):
         self.register_buffer(key, value, persistent=False)
 
     def get_seq_pos(self, seq_len, device, dtype, offset=0):
-        return (
-            torch.arange(seq_len, device=device, dtype=dtype) + offset
-        ) / self.interpolate_factor
+        return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
 
     def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
         seq_dim = default(seq_dim, self.default_seq_dim)
@@ -184,9 +197,7 @@ class RotaryEmbedding(Module):
         q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
         assert q_len <= k_len
 
-        rotated_q = self.rotate_queries_or_keys(
-            q, seq_dim=seq_dim, offset=k_len - q_len + offset
-        )
+        rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, offset=k_len - q_len + offset)
         rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)
 
         rotated_q = rotated_q.type(q.dtype)
@@ -222,11 +233,7 @@ class RotaryEmbedding(Module):
 
         should_cache = self.cache_if_possible and exists(seq_len)
 
-        if (
-            should_cache
-            and exists(self.cached_scales)
-            and (seq_len + offset) <= self.cached_scales.shape[0]
-        ):
+        if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]:
             return self.cached_scales[offset : (offset + seq_len)]
 
         scale = 1.0
@@ -264,17 +271,10 @@ class RotaryEmbedding(Module):
     @autocast("cuda", enabled=False)
     def forward(self, t: Tensor, seq_len=None, offset=0):
         should_cache = (
-            self.cache_if_possible
-            and not self.learned_freq
-            and exists(seq_len)
-            and self.freqs_for != "pixel"
+            self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel"
         )
 
-        if (
-            should_cache
-            and exists(self.cached_freqs)
-            and (offset + seq_len) <= self.cached_freqs.shape[0]
-        ):
+        if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs.shape[0]:
             return self.cached_freqs[offset : (offset + seq_len)].detach()
 
         freqs = self.freqs
@@ -289,7 +289,7 @@ class RotaryEmbedding(Module):
 
 
 class Rope2D:
-    """ Helper class to apply RoPE2D as well as interpolate on the fly. """
+    """Helper class to apply RoPE2D as well as interpolate on the fly."""
 
     def __init__(self, dim, use_cls_token=False):
         self.dim = dim
@@ -313,13 +313,11 @@ class Rope2D:
         grid_y_range = torch.arange(grid_h, device=device)
         grid_x_range = torch.arange(grid_w, device=device)
         freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
-        freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
+        freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
         freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
 
         if self.use_cls_token:
-            freq = torch.cat(
-                [freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0
-            )
+            freq = torch.cat([freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0)
 
         self.freq = freq[None, ...]
 
@@ -332,8 +330,8 @@ class Rope2D:
 
         return q, k
 
-####### PE's Modules ########
 
+######## PE Modules ########
 class AttentionPooling(nn.Module):
     def __init__(
         self,
@@ -349,14 +347,10 @@ class AttentionPooling(nn.Module):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
 
-        assert (
-            self.embed_dim % num_heads == 0
-        ), "embed_dim must be divisible by num_heads"
+        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
 
         self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
-        self.attn = nn.MultiheadAttention(
-            self.embed_dim, self.num_heads, batch_first=True
-        )
+        self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads, batch_first=True)
 
         self.layernorm = norm_layer(embed_dim)
         self.mlp_width = int(embed_dim * mlp_ratio)
@@ -396,9 +390,7 @@ class SelfAttention(nn.Module):
 
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
 
         # To make this compatibile with nn.MultiHeadAttention
         self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
@@ -418,13 +410,7 @@ class SelfAttention(nn.Module):
         proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)
 
         # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
-        proj = (
-            proj.unflatten(-1, (3, embed_dim))
-            .unsqueeze(0)
-            .transpose(0, -2)
-            .squeeze(-2)
-            .contiguous()
-        )
+        proj = proj.unflatten(-1, (3, embed_dim)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
         q, k, v = proj[0], proj[1], proj[2]
 
         # Use "q_" so that we don't accidentally quit in pdb :)
@@ -435,9 +421,7 @@ class SelfAttention(nn.Module):
         if self.rope:
             q, k = self.rope(q, k)
 
-        attn = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
-        )
+        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
         attn = rearrange(attn, "b h s d -> b s (h d)")
 
         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
@@ -462,16 +446,8 @@ class ResidualAttentionBlock(nn.Module):
         else:
             self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
 
-        self.ls_1 = (
-            LayerScale(d_model, ls_init_value)
-            if ls_init_value is not None
-            else nn.Identity()
-        )
-        self.ls_2 = (
-            LayerScale(d_model, ls_init_value)
-            if ls_init_value is not None
-            else nn.Identity()
-        )
+        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
 
         self.ln_1 = norm_layer(d_model)
         self.ln_2 = norm_layer(d_model)
@@ -511,9 +487,7 @@ class ResidualAttentionBlock(nn.Module):
         x: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
     ):
-        x = x + self.drop_path1(
-            self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask))
-        )
+        x = x + self.drop_path1(self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask)))
         x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
         return x
 
@@ -558,9 +532,9 @@ class Transformer(nn.Module):
 
     @torch.jit.ignore
    def truncate(self, layer_idx: int):
-        """ Delete layers so the last layer is the given layer index. """
+        """Delete layers so the last layer is the given layer index."""
         self.layers = ((self.layers + layer_idx) % self.layers) + 1
-        self.resblocks = nn.ModuleList(self.resblocks[:self.layers])
+        self.resblocks = nn.ModuleList(self.resblocks[: self.layers])
 
     def forward(
         self,
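(Aside, not part of the diff: the index arithmetic in truncate above maps a possibly negative layer index to the number of blocks kept. A quick check with a hypothetical 12-block transformer:)

layers = 12
for layer_idx in (-1, 5, 11):
    print(layer_idx, ((layers + layer_idx) % layers) + 1)
# -1 -> 12 (keep all blocks), 5 -> 6, 11 -> 12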
@@ -576,7 +550,7 @@ class Transformer(nn.Module):
                 x = checkpoint(r, x, None, None, attn_mask)
             else:
                 x = r(x, attn_mask=attn_mask)
-
+
             if i == stop_idx:
                 break
 
@@ -604,7 +578,7 @@ class PE(nn.Module):
         output_dim: Optional[int] = 1280,
         attn_pooler_heads: int = 8,
         pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
-        num_classes: int = 1000, # no use for now
+        num_classes: int = 1000,  # no use for now
         in_chans: int = 3,
     ):
         super().__init__()
@@ -666,12 +640,11 @@ class PE(nn.Module):
 
         self.init_tensors()
 
-
     def init_tensors(self):
         def init_submodule_tensors(module):
             for name, child in module.named_children():
                 if hasattr(child, "init_tensors"):
-                    #logger.debug(f"Initializing tensors for submodule: {name}")
+                    # logger.debug(f"Initializing tensors for submodule: {name}")
                     child.init_tensors()
                 init_submodule_tensors(child)
 
@@ -687,19 +660,14 @@ class PE(nn.Module):
         if self.use_abs_posemb:
             self.posemb_grid_size = self.image_size // self.patch_size
             self.positional_embedding = nn.Parameter(
-                init_scale
-                * torch.randn(
-                    int(self.use_cls_token) + self.posemb_grid_size**2, self.width
-                )
+                init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size**2, self.width)
             )
 
         if self.proj_dim is not None:
-            self.proj = nn.Parameter(
-                init_scale * torch.randn(self.width, self.proj_dim)
-            )
+            self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim))
 
     def truncate(self, layer_idx: int):
-        """ Delete layers so the last layer is the given layer index. """
+        """Delete layers so the last layer is the given layer index."""
         self.transformer.truncate(layer_idx)
         self.layers = self.transformer.layers
 
@@ -717,13 +685,9 @@ class PE(nn.Module):
             cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
 
         pos_embed = (
-            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
-            .permute(0, 3, 1, 2)
-            .contiguous()
-        )
-        pos_embed = F.interpolate(
-            pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
+            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1).permute(0, 3, 1, 2).contiguous()
         )
+        pos_embed = F.interpolate(pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False)
         pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()
 
         if self.use_cls_token:
@@ -743,13 +707,7 @@ class PE(nn.Module):
         else:
             raise NotImplementedError
 
-    def forward_features(
-        self,
-        x: torch.Tensor,
-        norm: bool = False,
-        layer_idx: int = -1,
-        strip_cls_token: bool = False
-    ):
+    def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False):
         batch, _, h, w = x.shape
         grid_h, grid_w = h // self.patch_size, w // self.patch_size
 
@@ -790,14 +748,8 @@ class PE(nn.Module):
 
-
 def checkpoint_filter_fn(
-        state_dict: Dict[str, torch.Tensor],
         model: PE,
-        adapt_layer_scale: bool = False,
-        interpolation: str = 'bicubic',
-        antialias: bool = True,
+        state_dict: Dict[str, torch.Tensor],
 ) -> Dict[str, torch.Tensor]:
-    """ convert patch embedding weight from manual patchify + linear proj to conv"""
-    import re
     state_dict = state_dict.get('model', state_dict)
     state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
     if any(k.startswith("visual.") for k in state_dict):
@@ -805,42 +757,33 @@ def checkpoint_filter_fn(
         return state_dict
 
 
+######## PE Config ########
 def _cfg(url='', **kwargs):
     return {
         'license': 'apache-2.0',
         'num_classes': 0,
         'interpolation': 'bilinear',
         'fixed_input_size': True,
-        'mean': IMAGENET_INCEPTION_MEAN,
+        'mean': IMAGENET_INCEPTION_MEAN,
         'std': IMAGENET_INCEPTION_STD,
-        **kwargs
+        **kwargs,
     }
 
-default_cfgs = generate_default_cfgs({
-    'pe_core_b16_224': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 224, 224)),
-    'pe_core_l14_336': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 336, 336)),
-    'pe_core_G14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-    'pe_lang_l14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-    'pe_lang_G14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-    'pe_spatial_G14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-})
 
+default_cfgs = generate_default_cfgs(
+    {
+        'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='timm/', input_size=(3, 224, 224)),
+        'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)),
+        'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+        'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+        'vit_pe_lang_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+        'vit_pe_spatial_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+    }
+)
 
 
 def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
     out_indices = kwargs.pop('out_indices', 3)
 
     return build_model_with_cfg(
         PE,
         variant,
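(Aside, not part of the diff: what one entry of the new default_cfgs expands to, using the _cfg helper shown above; illustrative only.)

from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

def _cfg(url='', **kwargs):
    # same body as the helper in the hunk above
    return {
        'license': 'apache-2.0',
        'num_classes': 0,
        'interpolation': 'bilinear',
        'fixed_input_size': True,
        'mean': IMAGENET_INCEPTION_MEAN,
        'std': IMAGENET_INCEPTION_STD,
        **kwargs,
    }

print(_cfg(hf_hub_id='timm/', input_size=(3, 224, 224)))
# the shared keys plus hf_hub_id and input_size, i.e. the vit_pe_core_base_patch16_224 entry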
@@ -851,99 +794,104 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
         **kwargs,
     )
 
-@register_model
-def pe_core_b16_224(pretrained=False, **kwargs):
-    model_args = dict(
-        image_size = 224,
-        patch_size = 16,
-        width = 768,
-        layers = 12,
-        heads = 12,
-        mlp_ratio = 4.0,
-        output_dim = 1024,
-        use_cls_token = True,
-        pool_type = 'attn',
-    )
-    return _create_pe('pe_core_b16_224', pretrained=pretrained, **dict(model_args, **kwargs))
-
 @register_model
-def pe_core_l14_336(pretrained=False, **kwargs):
+def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
     model_args = dict(
-        image_size = 336,
-        patch_size = 14,
-        width = 1024,
-        layers = 24,
-        heads = 16,
-        mlp_ratio = 4.0,
-        output_dim = 1024,
-        use_cls_token = True,
-        pool_type = 'attn',
+        image_size=224,
+        patch_size=16,
+        width=768,
+        layers=12,
+        heads=12,
+        mlp_ratio=4.0,
+        output_dim=1024,
+        use_cls_token=True,
+        pool_type='attn',
     )
-    return _create_pe('pe_core_l14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
-def pe_core_G14_448(pretrained=False, **kwargs):
+def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
     model_args = dict(
-        image_size = 448,
-        patch_size = 14,
-        width = 1536,
-        layers = 50,
-        heads = 16,
-        mlp_ratio = 8960 / 1536,
-        output_dim = 1280,
-        use_cls_token = False,
-        pool_type = 'attn',
+        image_size=336,
+        patch_size=14,
+        width=1024,
+        layers=24,
+        heads=16,
+        mlp_ratio=4.0,
+        output_dim=1024,
+        use_cls_token=True,
+        pool_type='attn',
     )
-    return _create_pe('pe_core_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
-def pe_lang_G14_448(pretrained=False, **kwargs):
+def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
-        image_size = 448,
-        patch_size = 14,
-        width = 1536,
-        layers = 47,
-        heads = 16,
-        mlp_ratio = 8960 / 1536,
-        output_dim = None,
-        use_cls_token = False,
-        use_ln_post = False,
-        pool_type = 'none',
-        ls_init_value = 0.1,
+        image_size=448,
+        patch_size=14,
+        width=1536,
+        layers=50,
+        heads=16,
+        mlp_ratio=8960 / 1536,
+        output_dim=1280,
+        use_cls_token=False,
+        pool_type='attn',
    )
-    return _create_pe('pe_lang_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
-def pe_lang_l14_448(pretrained=False, **kwargs):
+def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
-        image_size = 448,
-        patch_size = 14,
-        width = 1024,
-        layers = 23,
-        heads = 16,
-        mlp_ratio = 4.0,
-        output_dim = None,
-        use_cls_token = True,
-        use_ln_post = False,
-        pool_type = 'none',
-        ls_init_value = 0.1,
+        image_size=448,
+        patch_size=14,
+        width=1024,
+        layers=23,
+        heads=16,
+        mlp_ratio=4.0,
+        output_dim=None,
+        use_cls_token=True,
+        use_ln_post=False,
+        pool_type='none',
+        ls_init_value=0.1,
     )
-    return _create_pe('pe_lang_l14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
-def pe_spatial_G14_448(pretrained=False, **kwargs):
+def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
-        image_size = 448,
-        patch_size = 14,
-        width = 1536,
-        layers = 50,
-        heads = 16,
-        mlp_ratio = 8960 / 1536,
-        output_dim = None,
-        use_cls_token = False,
-        use_ln_post = False,
-        pool_type = 'none',
-        ls_init_value = 0.1,
+        image_size=448,
+        patch_size=14,
+        width=1536,
+        layers=47,
+        heads=16,
+        mlp_ratio=8960 / 1536,
+        output_dim=None,
+        use_cls_token=False,
+        use_ln_post=False,
+        pool_type='none',
+        ls_init_value=0.1,
     )
-    return _create_pe('pe_spatial_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
+    model_args = dict(
+        image_size=448,
+        patch_size=14,
+        width=1536,
+        layers=50,
+        heads=16,
+        mlp_ratio=8960 / 1536,
+        output_dim=None,
+        use_cls_token=False,
+        use_ln_post=False,
+        pool_type='none',
+        ls_init_value=0.1,
+    )
+    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
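(Aside, not part of the diff: a minimal usage sketch under the renamed entry points. It assumes a timm build that already contains these registrations; pretrained weights may not be published for every variant, so pretrained=False is used here.)

import torch
import timm

model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: a (1, embed_dim) feature tensor, since the default cfg sets num_classes=0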