Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit b82f9e869c (parent 3af564f2fb): renaming models and update checkpoint to vit-only
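This commit renames the registered PE entry points to ViT-style names: pe_core_b16_224 -> vit_pe_core_base_patch16_224, pe_core_l14_336 -> vit_pe_core_large_patch14_336, pe_core_G14_448 -> vit_pe_core_gigantic_patch14_448, pe_lang_l14_448 -> vit_pe_lang_large_patch14_448, pe_lang_G14_448 -> vit_pe_lang_gigantic_patch14_448, pe_spatial_G14_448 -> vit_pe_spatial_gigantic_patch14_448. A minimal usage sketch of a renamed entry point (only the model name comes from the diff below; the rest is illustrative and assumes a timm build that already includes this module):

import timm

# The renamed variants are ordinary registered timm models, so they are
# constructed by name; pretrained=False builds a randomly initialized model
# without needing the hub checkpoint referenced by the default cfgs.
model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False)
model.eval()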
@@ -14,9 +14,26 @@ from torch.amp import autocast
 from torch.utils.checkpoint import checkpoint
 
 ### Import timm layers
-from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
-    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType, LayerScale
+from timm.layers import (
+    PatchEmbed,
+    Mlp,
+    DropPath,
+    AttentionPoolLatent,
+    RmsNorm,
+    PatchDropout,
+    SwiGLUPacked,
+    SwiGLU,
+    trunc_normal_,
+    lecun_normal_,
+    resample_patch_embed,
+    resample_abs_pos_embed,
+    use_fused_attn,
+    get_act_layer,
+    get_norm_layer,
+    LayerType,
+    LayerScale,
+)
+
 # from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 
@@ -28,20 +45,22 @@ from ._registry import generate_default_cfgs, register_model, register_model_dep
 __all__ = ['PE']
 
 
-####### PE's Rope ########
+######## PE's Rope ########
 
 def exists(val):
     return val is not None
 
 
 def default(val, d):
     return val if exists(val) else d
 
 
 def rotate_half(x):
     x = rearrange(x, "... (d r) -> ... d r", r=2)
     x1, x2 = x.unbind(dim=-1)
     x = torch.stack((-x2, x1), dim=-1)
     return rearrange(x, "... d r -> ... (d r)")
 
 
 @autocast("cuda", enabled=False)
 def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
     dtype = t.dtype
@@ -73,9 +92,7 @@ class RotaryEmbedding(Module):
         self,
         dim,
         custom_freqs: Optional[Tensor] = None,
-        freqs_for: Union[
-            Literal["lang"], Literal["pixel"], Literal["constant"]
-        ] = "lang",
+        freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang",
         theta=10000,
         max_freq=10,
         num_freqs=1,
@@ -99,9 +116,7 @@ class RotaryEmbedding(Module):
         if exists(custom_freqs):
             freqs = custom_freqs
         elif freqs_for == "lang":
-            freqs = 1.0 / (
-                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
-            )
+            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
         elif freqs_for == "pixel":
             freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
         elif freqs_for == "constant":
@@ -154,9 +169,7 @@ class RotaryEmbedding(Module):
         self.register_buffer(key, value, persistent=False)
 
     def get_seq_pos(self, seq_len, device, dtype, offset=0):
-        return (
-            torch.arange(seq_len, device=device, dtype=dtype) + offset
-        ) / self.interpolate_factor
+        return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
 
     def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
         seq_dim = default(seq_dim, self.default_seq_dim)
@@ -184,9 +197,7 @@ class RotaryEmbedding(Module):
         q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
         assert q_len <= k_len
 
-        rotated_q = self.rotate_queries_or_keys(
-            q, seq_dim=seq_dim, offset=k_len - q_len + offset
-        )
+        rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, offset=k_len - q_len + offset)
         rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)
 
         rotated_q = rotated_q.type(q.dtype)
@@ -222,11 +233,7 @@ class RotaryEmbedding(Module):
 
         should_cache = self.cache_if_possible and exists(seq_len)
 
-        if (
-            should_cache
-            and exists(self.cached_scales)
-            and (seq_len + offset) <= self.cached_scales.shape[0]
-        ):
+        if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]:
             return self.cached_scales[offset : (offset + seq_len)]
 
         scale = 1.0
@@ -264,17 +271,10 @@ class RotaryEmbedding(Module):
     @autocast("cuda", enabled=False)
     def forward(self, t: Tensor, seq_len=None, offset=0):
         should_cache = (
-            self.cache_if_possible
-            and not self.learned_freq
-            and exists(seq_len)
-            and self.freqs_for != "pixel"
+            self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel"
         )
 
-        if (
-            should_cache
-            and exists(self.cached_freqs)
-            and (offset + seq_len) <= self.cached_freqs.shape[0]
-        ):
+        if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs.shape[0]:
             return self.cached_freqs[offset : (offset + seq_len)].detach()
 
         freqs = self.freqs
@@ -317,9 +317,7 @@ class Rope2D:
         freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
 
         if self.use_cls_token:
-            freq = torch.cat(
-                [freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0
-            )
+            freq = torch.cat([freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0)
 
         self.freq = freq[None, ...]
 
@@ -332,8 +330,8 @@ class Rope2D:
 
         return q, k
 
-####### PE's Modules ########
 
+######## PE Modules ########
 class AttentionPooling(nn.Module):
     def __init__(
         self,
@@ -349,14 +347,10 @@ class AttentionPooling(nn.Module):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
 
-        assert (
-            self.embed_dim % num_heads == 0
-        ), "embed_dim must be divisible by num_heads"
+        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
 
         self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
-        self.attn = nn.MultiheadAttention(
-            self.embed_dim, self.num_heads, batch_first=True
-        )
+        self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads, batch_first=True)
 
         self.layernorm = norm_layer(embed_dim)
         self.mlp_width = int(embed_dim * mlp_ratio)
@@ -396,9 +390,7 @@ class SelfAttention(nn.Module):
 
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
 
         # To make this compatibile with nn.MultiHeadAttention
         self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
@@ -418,13 +410,7 @@ class SelfAttention(nn.Module):
         proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)
 
         # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
-        proj = (
-            proj.unflatten(-1, (3, embed_dim))
-            .unsqueeze(0)
-            .transpose(0, -2)
-            .squeeze(-2)
-            .contiguous()
-        )
+        proj = proj.unflatten(-1, (3, embed_dim)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
         q, k, v = proj[0], proj[1], proj[2]
 
         # Use "q_" so that we don't accidentally quit in pdb :)
@@ -435,9 +421,7 @@ class SelfAttention(nn.Module):
         if self.rope:
             q, k = self.rope(q, k)
 
-        attn = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
-        )
+        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
         attn = rearrange(attn, "b h s d -> b s (h d)")
 
         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
@@ -462,16 +446,8 @@ class ResidualAttentionBlock(nn.Module):
         else:
             self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
 
-        self.ls_1 = (
-            LayerScale(d_model, ls_init_value)
-            if ls_init_value is not None
-            else nn.Identity()
-        )
-        self.ls_2 = (
-            LayerScale(d_model, ls_init_value)
-            if ls_init_value is not None
-            else nn.Identity()
-        )
+        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
 
         self.ln_1 = norm_layer(d_model)
         self.ln_2 = norm_layer(d_model)
@@ -511,9 +487,7 @@ class ResidualAttentionBlock(nn.Module):
         x: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
     ):
-        x = x + self.drop_path1(
-            self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask))
-        )
+        x = x + self.drop_path1(self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask)))
         x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
         return x
 
@@ -666,7 +640,6 @@ class PE(nn.Module):
 
         self.init_tensors()
 
-
     def init_tensors(self):
         def init_submodule_tensors(module):
             for name, child in module.named_children():
@@ -687,16 +660,11 @@ class PE(nn.Module):
         if self.use_abs_posemb:
             self.posemb_grid_size = self.image_size // self.patch_size
             self.positional_embedding = nn.Parameter(
-                init_scale
-                * torch.randn(
-                    int(self.use_cls_token) + self.posemb_grid_size**2, self.width
-                )
+                init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size**2, self.width)
             )
 
         if self.proj_dim is not None:
-            self.proj = nn.Parameter(
-                init_scale * torch.randn(self.width, self.proj_dim)
-            )
+            self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim))
 
     def truncate(self, layer_idx: int):
         """Delete layers so the last layer is the given layer index."""
@@ -717,13 +685,9 @@ class PE(nn.Module):
             cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
 
         pos_embed = (
-            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
-            .permute(0, 3, 1, 2)
-            .contiguous()
-        )
-        pos_embed = F.interpolate(
-            pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
+            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1).permute(0, 3, 1, 2).contiguous()
         )
+        pos_embed = F.interpolate(pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False)
         pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()
 
         if self.use_cls_token:
@@ -743,13 +707,7 @@ class PE(nn.Module):
         else:
             raise NotImplementedError
 
-    def forward_features(
-        self,
-        x: torch.Tensor,
-        norm: bool = False,
-        layer_idx: int = -1,
-        strip_cls_token: bool = False
-    ):
+    def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False):
         batch, _, h, w = x.shape
         grid_h, grid_w = h // self.patch_size, w // self.patch_size
 
@@ -791,13 +749,7 @@ class PE(nn.Module):
 
 def checkpoint_filter_fn(
     state_dict: Dict[str, torch.Tensor],
-    model: PE,
-    adapt_layer_scale: bool = False,
-    interpolation: str = 'bicubic',
-    antialias: bool = True,
 ) -> Dict[str, torch.Tensor]:
-    """ convert patch embedding weight from manual patchify + linear proj to conv"""
-    import re
     state_dict = state_dict.get('model', state_dict)
     state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
     if any(k.startswith("visual.") for k in state_dict):
@@ -805,6 +757,7 @@ def checkpoint_filter_fn(
     return state_dict
 
 
+######## PE Config ########
 def _cfg(url='', **kwargs):
     return {
         'license': 'apache-2.0',
@@ -813,34 +766,24 @@ def _cfg(url='', **kwargs):
         'fixed_input_size': True,
         'mean': IMAGENET_INCEPTION_MEAN,
         'std': IMAGENET_INCEPTION_STD,
-        **kwargs
+        **kwargs,
     }
 
-default_cfgs = generate_default_cfgs({
-    'pe_core_b16_224': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 224, 224)),
-    'pe_core_l14_336': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 336, 336)),
-    'pe_core_G14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-    'pe_lang_l14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-    'pe_lang_G14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-    'pe_spatial_G14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-})
+default_cfgs = generate_default_cfgs(
+    {
+        'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='timm/', input_size=(3, 224, 224)),
+        'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)),
+        'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+        'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+        'vit_pe_lang_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+        'vit_pe_spatial_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+    }
+)
 
 
 def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
     out_indices = kwargs.pop('out_indices', 3)
 
     return build_model_with_cfg(
         PE,
         variant,
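Since the renamed variants are added to default_cfgs above and registered via @register_model below, they should be discoverable through timm's model registry once this module ships in a timm release. A small illustrative sketch (the wildcard query is an assumption about how one might look them up, not part of this diff):

import timm

# List registered PE models by wildcard; with this commit merged, the six
# vit_pe_* names from default_cfgs above should appear in the output.
print(timm.list_models('vit_pe*'))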
@@ -851,8 +794,9 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
         **kwargs,
     )
 
+
 @register_model
-def pe_core_b16_224(pretrained=False, **kwargs):
+def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
     model_args = dict(
         image_size=224,
         patch_size=16,
@@ -864,10 +808,11 @@ def pe_core_b16_224(pretrained=False, **kwargs):
         use_cls_token=True,
         pool_type='attn',
     )
-    return _create_pe('pe_core_b16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
 
+
 @register_model
-def pe_core_l14_336(pretrained=False, **kwargs):
+def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
     model_args = dict(
         image_size=336,
         patch_size=14,
@@ -879,11 +824,11 @@ def pe_core_l14_336(pretrained=False, **kwargs):
         use_cls_token=True,
         pool_type='attn',
     )
-    return _create_pe('pe_core_l14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
-def pe_core_G14_448(pretrained=False, **kwargs):
+def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
         image_size=448,
         patch_size=14,
@@ -895,27 +840,11 @@ def pe_core_G14_448(pretrained=False, **kwargs):
         use_cls_token=False,
         pool_type='attn',
     )
-    return _create_pe('pe_core_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
 
+
 @register_model
-def pe_lang_G14_448(pretrained=False, **kwargs):
-    model_args = dict(
-        image_size = 448,
-        patch_size = 14,
-        width = 1536,
-        layers = 47,
-        heads = 16,
-        mlp_ratio = 8960 / 1536,
-        output_dim = None,
-        use_cls_token = False,
-        use_ln_post = False,
-        pool_type = 'none',
-        ls_init_value = 0.1,
-    )
-    return _create_pe('pe_lang_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
-
-@register_model
-def pe_lang_l14_448(pretrained=False, **kwargs):
+def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
         image_size=448,
         patch_size=14,
@@ -929,10 +858,29 @@ def pe_lang_l14_448(pretrained=False, **kwargs):
         pool_type='none',
         ls_init_value=0.1,
     )
-    return _create_pe('pe_lang_l14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
 
+
 @register_model
-def pe_spatial_G14_448(pretrained=False, **kwargs):
+def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
+    model_args = dict(
+        image_size=448,
+        patch_size=14,
+        width=1536,
+        layers=47,
+        heads=16,
+        mlp_ratio=8960 / 1536,
+        output_dim=None,
+        use_cls_token=False,
+        use_ln_post=False,
+        pool_type='none',
+        ls_init_value=0.1,
+    )
+    return _create_pe('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
         image_size=448,
         patch_size=14,
@@ -946,4 +894,4 @@ def pe_spatial_G14_448(pretrained=False, **kwargs):
         pool_type='none',
         ls_init_value=0.1,
     )
-    return _create_pe('pe_spatial_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))