Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Merge pull request #2487 from huggingface/eva_pe_integration
Add EVA ViT based PE (Perceptual Encoder) impl

Commit f14f6507ab
timm/layers/__init__.py

@@ -1,6 +1,7 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .attention import Attention, AttentionRope
 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
@@ -41,6 +42,7 @@ from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct,
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
+from .pool1d import global_pool_nlc
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
 from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
timm/layers/attention.py (new file, 212 lines)

@@ -0,0 +1,212 @@
from typing import Final, Optional, Type

import torch
from torch import nn as nn
from torch.nn import functional as F

from .config import use_fused_attn
from .pos_embed_sincos import apply_rot_embed_cat


class Attention(nn.Module):
    """Standard Multi-head Self Attention module with QKV projection.

    This module implements the standard multi-head attention mechanism used in transformers.
    It supports both the fused attention implementation (scaled_dot_product_attention) for
    efficiency when available, and a manual implementation otherwise. The module includes
    options for QK normalization, attention dropout, and projection dropout.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        """Initialize the Attention module.

        Args:
            dim: Input dimension of the token embeddings
            num_heads: Number of attention heads
            qkv_bias: Whether to use bias in the query, key, value projections
            qk_norm: Whether to apply normalization to query and key vectors
            proj_bias: Whether to use bias in the output projection
            attn_drop: Dropout rate applied to the attention weights
            proj_drop: Dropout rate applied after the output projection
            norm_layer: Normalization layer constructor for QK normalization if enabled
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class AttentionRope(nn.Module):
    """ A Self Attention module with ROPE support.

    Includes options for:
     * QK normalization option
     * Attention output (scale) normalization
     * Fused or unfused QKV projection support
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = True,
            qkv_fused: bool = True,
            num_prefix_tokens: int = 1,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            attn_head_dim: Optional[int] = None,
            norm_layer: Type[nn.Module] = None,
            qk_norm: bool = False,
            scale_norm: bool = False,
    ):
        """Initialize the Attention module.

        Args:
            dim: Input dimension of the token embeddings
            num_heads: Number of attention heads
            qkv_bias: Whether to add a bias term to the query, key, and value projections
            num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
                should not have position embeddings applied
            attn_drop: Dropout rate for attention weights
            proj_drop: Dropout rate for the output projection
            attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
            norm_layer: Normalization layer constructor to use for QK and scale normalization
            qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
            scale_norm: Enable normalization (scaling) of attention output with norm_layer
        """
        super().__init__()
        if scale_norm or qk_norm:
            assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        attn_dim = head_dim * self.num_heads
        self.scale = head_dim ** -0.5
        self.num_prefix_tokens = num_prefix_tokens
        self.fused_attn = use_fused_attn()

        if qkv_fused:
            self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
            self.q_proj = self.k_proj = self.v_proj = None
        else:
            self.qkv = None
            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
            self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
        self.proj = nn.Linear(attn_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
            self,
            x,
            rope: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """Forward pass for the attention module.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
            rope: Rotary position embeddings tensor for position-aware attention
            attn_mask: Optional attention mask to apply during attention computation

        Returns:
            Tensor of shape (batch_size, sequence_length, embedding_dim)
        """
        B, N, C = x.shape

        if self.qkv is not None:
            qkv = self.qkv(x)
            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
        else:
            q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)  # B, num_heads, N, C
            k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
            v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)

        q, k = self.q_norm(q), self.k_norm(k)

        if rope is not None:
            npt = self.num_prefix_tokens
            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))

            if attn_mask is not None:
                attn_mask = attn_mask.to(torch.bool)
                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
            attn = attn.softmax(dim=-1)

            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
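A minimal usage sketch of the two new layers, assuming a timm build that includes this PR; tensor shapes and values below are illustrative only:

import torch
from timm.layers import Attention, AttentionRope

x = torch.randn(2, 197, 768)  # (batch, tokens, dim)

# plain multi-head self-attention, optional QK norm
attn = Attention(dim=768, num_heads=12, qkv_bias=True, qk_norm=True)
print(attn(x).shape)  # torch.Size([2, 197, 768])

# ROPE-aware variant; rope=None simply skips the rotary embedding
attn_rope = AttentionRope(dim=768, num_heads=12, qkv_fused=True, num_prefix_tokens=1)
print(attn_rope(x, rope=None).shape)  # torch.Size([2, 197, 768])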
timm/layers/attention_pool.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Type

 import torch
 import torch.nn as nn
@@ -28,8 +28,8 @@ class AttentionPoolLatent(nn.Module):
             latent_dim: int = None,
             pos_embed: str = '',
             pool_type: str = 'token',
-            norm_layer: Optional[nn.Module] = None,
-            act_layer: Optional[nn.Module] = nn.GELU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            act_layer: Optional[Type[nn.Module]] = nn.GELU,
             drop: float = 0.0,
     ):
         super().__init__()
26
timm/layers/pool1d.py
Normal file
26
timm/layers/pool1d.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def global_pool_nlc(
|
||||||
|
x: torch.Tensor,
|
||||||
|
pool_type: str = 'token',
|
||||||
|
num_prefix_tokens: int = 1,
|
||||||
|
reduce_include_prefix: bool = False,
|
||||||
|
):
|
||||||
|
if not pool_type:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if pool_type == 'token':
|
||||||
|
x = x[:, 0] # class token
|
||||||
|
else:
|
||||||
|
x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
|
||||||
|
if pool_type == 'avg':
|
||||||
|
x = x.mean(dim=1)
|
||||||
|
elif pool_type == 'avgmax':
|
||||||
|
x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
|
||||||
|
elif pool_type == 'max':
|
||||||
|
x = x.amax(dim=1)
|
||||||
|
else:
|
||||||
|
assert not pool_type, f'Unknown pool type {pool_type}'
|
||||||
|
|
||||||
|
return x
|
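A quick sketch of how global_pool_nlc reduces a (batch, tokens, channels) sequence; the tensors and sizes are illustrative:

import torch
from timm.layers import global_pool_nlc  # exported by the __init__.py change above

tokens = torch.randn(2, 1 + 196, 768)  # [CLS] token followed by 196 patch tokens

cls_tok = global_pool_nlc(tokens, pool_type='token')                        # (2, 768), the class token
avg_tok = global_pool_nlc(tokens, pool_type='avg', num_prefix_tokens=1)     # (2, 768), mean over patch tokens only
mix_tok = global_pool_nlc(tokens, pool_type='avgmax', num_prefix_tokens=1)  # (2, 768), 0.5 * (max + mean)
no_pool = global_pool_nlc(tokens, pool_type='')                             # unchanged (2, 197, 768)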
timm/layers/pos_embed_sincos.py

@@ -87,6 +87,8 @@ def build_fourier_pos_embed(
         include_grid: bool = False,
         in_pixels: bool = True,
         ref_feat_shape: Optional[List[int]] = None,
+        grid_offset: float = 0.,
+        grid_indexing: str = 'ij',
         dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
 ) -> List[torch.Tensor]:
@@ -102,6 +104,8 @@ def build_fourier_pos_embed(
         include_grid: Include the spatial grid in output.
         in_pixels: Output in pixel freq.
         ref_feat_shape: Reference feature shape for resize / fine-tune.
+        grid_offset: Constant offset to add to grid for non-pixel freq.
+        grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
         dtype: Output dtype.
         device: Output device.

@@ -130,15 +134,21 @@ def build_fourier_pos_embed(
         dtype = bands.dtype

     if in_pixels:
-        t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape]
+        t = [
+            torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32)
+            for s in feat_shape
+        ]
     else:
-        t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]
+        t = [
+            torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) + grid_offset
+            for s in feat_shape
+        ]

     if ref_feat_shape is not None:
         # eva's scheme for resizing rope embeddings (ref shape = pretrain)
         t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]

-    grid = torch.stack(ndgrid(t), dim=-1)
+    grid = torch.stack(torch.meshgrid(t, indexing=grid_indexing), dim=-1)
     grid = grid.unsqueeze(-1)
     pos = grid * bands

@@ -229,6 +239,8 @@ def build_rotary_pos_embed(
         linear_bands: bool = False,
         in_pixels: bool = True,
         ref_feat_shape: Optional[List[int]] = None,
+        grid_offset: float = 0.,
+        grid_indexing: str = 'ij',
         dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
 ):
@@ -242,6 +254,9 @@ def build_rotary_pos_embed(
         temperature: Temperature (inv freq) for non-pixel mode
         linear_bands: Linearly (instead of log) spaced bands for pixel mode
         in_pixels: Pixel vs language (inv freq) mode.
+        ref_feat_shape: Reference feature shape for resize / fine-tune.
+        grid_offset: Constant offset to add to grid for non-pixel freq.
+        grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
         dtype: Output dtype.
         device: Output device.

@@ -257,6 +272,8 @@ def build_rotary_pos_embed(
         linear_bands=linear_bands,
         in_pixels=in_pixels,
         ref_feat_shape=ref_feat_shape,
+        grid_offset=grid_offset,
+        grid_indexing=grid_indexing,
         device=device,
         dtype=dtype,
     )
@@ -289,6 +306,8 @@ class RotaryEmbedding(nn.Module):
             linear_bands: bool = False,
             feat_shape: Optional[List[int]] = None,
             ref_feat_shape: Optional[List[int]] = None,
+            grid_offset: float = 0.,
+            grid_indexing: str = 'ij',
     ):
         super().__init__()
         self.dim = dim
@@ -297,6 +316,8 @@ class RotaryEmbedding(nn.Module):
         self.in_pixels = in_pixels
         self.feat_shape = feat_shape
         self.ref_feat_shape = ref_feat_shape
+        self.grid_offset = grid_offset
+        self.grid_indexing = grid_indexing

         if feat_shape is None:
             # only cache bands
@@ -328,6 +349,8 @@ class RotaryEmbedding(nn.Module):
                 linear_bands=linear_bands,
                 in_pixels=in_pixels,
                 ref_feat_shape=self.ref_feat_shape,
+                grid_offset=self.grid_offset,
+                grid_indexing=self.grid_indexing,
             )
             self.bands = None
             self.register_buffer(
@@ -349,6 +372,9 @@ class RotaryEmbedding(nn.Module):
                 shape,
                 self.bands,
                 in_pixels=self.in_pixels,
+                ref_feat_shape=self.ref_feat_shape,
+                grid_offset=self.grid_offset,
+                grid_indexing=self.grid_indexing,
             )
         else:
             return self.pos_embed_sin, self.pos_embed_cos
@@ -376,6 +402,8 @@ class RotaryEmbeddingCat(nn.Module):
             linear_bands: bool = False,
             feat_shape: Optional[List[int]] = None,
             ref_feat_shape: Optional[List[int]] = None,
+            grid_offset: float = 0.,
+            grid_indexing: str = 'ij',
     ):
         super().__init__()
         self.dim = dim
@@ -384,6 +412,8 @@ class RotaryEmbeddingCat(nn.Module):
         self.in_pixels = in_pixels
         self.feat_shape = feat_shape
         self.ref_feat_shape = ref_feat_shape
+        self.grid_offset = grid_offset
+        self.grid_indexing = grid_indexing

         if feat_shape is None:
             # only cache bands
@@ -414,6 +444,8 @@ class RotaryEmbeddingCat(nn.Module):
                 linear_bands=linear_bands,
                 in_pixels=in_pixels,
                 ref_feat_shape=self.ref_feat_shape,
+                grid_offset=self.grid_offset,
+                grid_indexing=self.grid_indexing,
             )
             self.bands = None
             self.register_buffer(
@@ -430,6 +462,8 @@ class RotaryEmbeddingCat(nn.Module):
                 self.bands,
                 in_pixels=self.in_pixels,
                 ref_feat_shape=self.ref_feat_shape,
+                grid_offset=self.grid_offset,
+                grid_indexing=self.grid_indexing,
             )
             return torch.cat(embeds, -1)
         elif self.pos_embed is not None:
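The switch from ndgrid(t) to torch.meshgrid(t, indexing=grid_indexing) is what lets the PE checkpoints, whose rope grids were built with cartesian 'xy' ordering, reuse the existing code path. A toy sketch of how the two orderings differ (values are illustrative only):

import torch

coords = [torch.arange(2.), torch.arange(3.)]  # a toy 2 x 3 feature grid

grid_ij = torch.stack(torch.meshgrid(coords, indexing='ij'), dim=-1)  # shape (2, 3, 2), matrix ('ij') order
grid_xy = torch.stack(torch.meshgrid(coords, indexing='xy'), dim=-1)  # shape (3, 2, 2), cartesian ('xy') order

print(grid_ij.shape, grid_xy.shape)
print(grid_ij.flatten(0, 1)[:3])  # first few positions enumerated row-major
print(grid_xy.flatten(0, 1)[:3])  # same grid, laid out in transposed 'xy' order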
timm/models/eva.py

@@ -25,6 +25,7 @@ Modifications by / Copyright 2023 Ross Wightman, original copyrights below
 # EVA models Copyright (c) 2022 BAAI-Vision
 # EVA02 models Copyright (c) 2023 BAAI-Vision
 import math
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union

 import torch
@@ -34,7 +35,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
     apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, \
-    to_2tuple, use_fused_attn
+    global_pool_nlc, to_2tuple, use_fused_attn, AttentionRope, AttentionPoolLatent

 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -45,6 +46,8 @@ __all__ = ['Eva']


 class EvaAttention(nn.Module):
+    """ EVA Attention with ROPE, no k-bias, and fused/unfused qkv options
+    """
     fused_attn: torch.jit.Final[bool]

     def __init__(
@@ -53,55 +56,64 @@ class EvaAttention(nn.Module):
             num_heads: int = 8,
             qkv_bias: bool = True,
             qkv_fused: bool = True,
-            num_prefix_tokens: int = 1,
             qkv_bias_separate: bool = False,
+            num_prefix_tokens: int = 1,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             attn_head_dim: Optional[int] = None,
             norm_layer: Optional[Callable] = None,
+            qk_norm: bool = False,
+            scale_norm: bool = True,
     ):
         """

         Args:
-            dim:
-            num_heads:
-            qkv_bias:
-            qkv_fused:
-            attn_drop:
-            proj_drop:
-            attn_head_dim:
-            norm_layer:
+            dim: Input dimension of the token embeddings
+            num_heads: Number of attention heads
+            qkv_bias: Whether to add a bias term to the query, key, and value projections
+            qkv_fused: Whether qkv projections are fused into one projection or separate
+            qkv_bias_separate: Whether to apply bias to qkv as a separate addition or part of F.linear() call
+            num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
+                should not have position embeddings applied
+            attn_drop: Dropout rate for attention weights
+            proj_drop: Dropout rate for the output projection
+            attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
+            norm_layer: Normalization layer constructor to use for QK and scale normalization
+            qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
+            scale_norm: Enable normalization (scaling) of attention output with norm_layer
         """
         super().__init__()
+        if scale_norm or qk_norm:
+            assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
         self.num_heads = num_heads
         head_dim = dim // num_heads
         if attn_head_dim is not None:
             head_dim = attn_head_dim
-        all_head_dim = head_dim * self.num_heads
+        attn_dim = head_dim * self.num_heads
         self.scale = head_dim ** -0.5
         self.num_prefix_tokens = num_prefix_tokens
         self.fused_attn = use_fused_attn()
         self.qkv_bias_separate = qkv_bias_separate

         if qkv_fused:
-            self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+            self.qkv = nn.Linear(dim, attn_dim * 3, bias=False)
             self.q_proj = self.k_proj = self.v_proj = None
             if qkv_bias:
-                self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
-                self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
-                self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+                self.q_bias = nn.Parameter(torch.zeros(attn_dim))
+                self.register_buffer('k_bias', torch.zeros(attn_dim), persistent=False)
+                self.v_bias = nn.Parameter(torch.zeros(attn_dim))
             else:
                 self.q_bias = self.k_bias = self.v_bias = None
         else:
-            self.q_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
-            self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
-            self.v_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
+            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
+            self.k_proj = nn.Linear(dim, attn_dim, bias=False)
+            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
             self.qkv = None
             self.q_bias = self.k_bias = self.v_bias = None
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.norm = norm_layer(all_head_dim) if norm_layer is not None else nn.Identity()
-        self.proj = nn.Linear(all_head_dim, dim)
+        self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(attn_dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(
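Aside on the "no k-bias" detail above: EvaAttention keeps learnable q_bias and v_bias but registers k_bias as a zero buffer, so the fused qkv projection can still be applied as one F.linear call with a concatenated bias. A minimal standalone sketch of that trick (toy shapes, not the module itself):

import torch
import torch.nn.functional as F

dim = 8
x = torch.randn(2, 5, dim)
qkv_weight = torch.randn(dim * 3, dim)

q_bias = torch.randn(dim)   # learnable in the real module
k_bias = torch.zeros(dim)   # zero buffer -> effectively no bias on keys
v_bias = torch.randn(dim)   # learnable in the real module

qkv = F.linear(x, weight=qkv_weight, bias=torch.cat((q_bias, k_bias, v_bias)))
q, k, v = qkv.reshape(2, 5, 3, dim).unbind(2)  # each (2, 5, dim)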
@@ -110,6 +122,16 @@ class EvaAttention(nn.Module):
             rope: Optional[torch.Tensor] = None,
             attn_mask: Optional[torch.Tensor] = None,
     ):
+        """Forward pass for the attention module.
+
+        Args:
+            x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
+            rope: Rotary position embeddings tensor for position-aware attention
+            attn_mask: Optional attention mask to apply during attention computation
+
+        Returns:
+            Tensor of shape (batch_size, sequence_length, embedding_dim)
+        """
         B, N, C = x.shape

         if self.qkv is not None:
@@ -129,6 +151,8 @@ class EvaAttention(nn.Module):
             k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
             v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)

+        q, k = self.q_norm(q), self.k_norm(k)
+
         if rope is not None:
             npt = self.num_prefix_tokens
             q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
@@ -172,6 +196,7 @@ class EvaBlock(nn.Module):
             scale_mlp: bool = False,
             scale_attn_inner: bool = False,
             num_prefix_tokens: int = 1,
+            attn_type: str = 'eva',
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,
@@ -180,28 +205,31 @@ class EvaBlock(nn.Module):
             norm_layer: Callable = LayerNorm,
             attn_head_dim: Optional[int] = None,
     ):
-        """
+        """ Initialize the EVA transformer block.

         Args:
-            dim:
-            num_heads:
-            qkv_bias:
-            qkv_fused:
-            mlp_ratio:
-            swiglu_mlp:
-            scale_mlp:
-            scale_attn_inner:
-            proj_drop:
-            attn_drop:
-            drop_path:
-            init_values:
-            act_layer:
-            norm_layer:
-            attn_head_dim:
+            dim: Input dimension of the token embeddings
+            num_heads: Number of attention heads
+            qkv_bias: Whether to use bias terms in query, key, value projections
+            qkv_fused: Whether to use a single projection for query, key, value
+            mlp_ratio: Ratio of MLP hidden dimension to input dimension
+            swiglu_mlp: Whether to use SwiGLU activation in the MLP
+            scale_mlp: Whether to use normalization in the MLP
+            scale_attn_inner: Whether to use normalization within the attention mechanism
+            num_prefix_tokens: Number of tokens at the beginning of the sequence (class tokens, etc.)
+            attn_type: Type of attention module to use ('eva' or 'rope')
+            proj_drop: Dropout rate for projection layers
+            attn_drop: Dropout rate for attention matrix
+            drop_path: Stochastic depth rate
+            init_values: Initial value for LayerScale, None = no LayerScale
+            act_layer: Activation layer constructor
+            norm_layer: Normalization layer constructor
+            attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
         """
         super().__init__()
         self.norm1 = norm_layer(dim)
-        self.attn = EvaAttention(
+        attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
+        self.attn = attn_cls(
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
@@ -210,7 +238,8 @@ class EvaBlock(nn.Module):
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             attn_head_dim=attn_head_dim,
-            norm_layer=norm_layer if scale_attn_inner else None,
+            norm_layer=norm_layer,
+            scale_norm=scale_attn_inner,
         )
         self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -266,6 +295,7 @@ class EvaBlockPostNorm(nn.Module):
             qkv_bias: bool = True,
             qkv_fused: bool = True,
             mlp_ratio: float = 4.,
+            attn_type: str = 'eva',
             swiglu_mlp: bool = False,
             scale_mlp: bool = False,
             scale_attn_inner: bool = False,
@@ -278,27 +308,30 @@ class EvaBlockPostNorm(nn.Module):
             norm_layer: Callable = nn.LayerNorm,
             attn_head_dim: Optional[int] = None,
     ):
-        """
+        """ Initialize the post-norm EVA transformer block.

         Args:
-            dim:
-            num_heads:
-            qkv_bias:
-            qkv_fused:
-            mlp_ratio:
-            swiglu_mlp:
-            scale_mlp:
-            scale_attn_inner:
-            proj_drop:
-            attn_drop:
-            drop_path:
-            init_values:
-            act_layer:
-            norm_layer:
-            attn_head_dim:
+            dim: Input dimension of the token embeddings
+            num_heads: Number of attention heads
+            qkv_bias: Whether to use bias terms in query, key, value projections
+            qkv_fused: Whether to use a single projection for query, key, value
+            mlp_ratio: Ratio of MLP hidden dimension to input dimension
+            swiglu_mlp: Whether to use SwiGLU activation in the MLP
+            scale_mlp: Whether to use normalization in the MLP
+            scale_attn_inner: Whether to use normalization within the attention mechanism
+            num_prefix_tokens: Number of tokens at the beginning of the sequence (class tokens, etc.)
+            attn_type: Type of attention module to use ('eva' or 'rope')
+            proj_drop: Dropout rate for projection layers
+            attn_drop: Dropout rate for attention matrix
+            drop_path: Stochastic depth rate
+            init_values: Initial value for LayerScale, None = no LayerScale (NOTE: ignored for post-norm block)
+            act_layer: Activation layer constructor
+            norm_layer: Normalization layer constructor
+            attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
         """
         super().__init__()
-        self.attn = EvaAttention(
+        attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
+        self.attn = attn_cls(
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
@@ -307,7 +340,8 @@ class EvaBlockPostNorm(nn.Module):
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             attn_head_dim=attn_head_dim,
-            norm_layer=norm_layer if scale_attn_inner else None,
+            norm_layer=norm_layer,
+            scale_norm=scale_attn_inner,
         )
         self.norm1 = norm_layer(dim)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -373,6 +407,7 @@ class Eva(nn.Module):
             swiglu_mlp: bool = False,
             scale_mlp: bool = False,
             scale_attn_inner: bool = False,
+            attn_type: str = 'eva',
             drop_rate: float = 0.,
             pos_drop_rate: float = 0.,
             patch_drop_rate: float = 0.,
@@ -385,44 +420,64 @@ class Eva(nn.Module):
             num_reg_tokens: int = 0,
             use_abs_pos_emb: bool = True,
             use_rot_pos_emb: bool = False,
+            rope_grid_offset: float = 0.,
+            rope_grid_indexing: str = 'ij',
             use_post_norm: bool = False,
+            use_pre_transformer_norm: bool = False,
+            use_post_transformer_norm: Optional[bool] = None,
+            use_fc_norm: Optional[bool] = None,
+            attn_pool_num_heads: Optional[int] = None,
+            attn_pool_mlp_ratio: Optional[float] = None,
             dynamic_img_size: bool = False,
             dynamic_img_pad: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
     ):
-        """
+        """Initialize the EVA Vision Transformer model.

         Args:
-            img_size:
-            patch_size:
-            in_chans:
-            num_classes:
-            global_pool:
-            embed_dim:
-            depth:
-            num_heads:
-            qkv_bias:
-            qkv_fused:
-            mlp_ratio:
-            swiglu_mlp:
-            scale_mlp:
-            scale_attn_inner:
-            drop_rate:
-            pos_drop_rate:
-            proj_drop_rate:
-            attn_drop_rate:
-            drop_path_rate:
-            norm_layer:
-            init_values:
-            class_token:
-            use_abs_pos_emb:
-            use_rot_pos_emb:
-            use_post_norm:
-            ref_feat_shape:
-            head_init_scale:
+            img_size: Input image size (single int for square, or tuple for rectangular)
+            patch_size: Patch size to divide image into tokens (single int for square, or tuple)
+            in_chans: Number of input image channels
+            num_classes: Number of classes (output dim) for classification head (final projection), 0 for pass-through
+            global_pool: Type of global pooling for final sequence ('avg', 'token', 'map', etc.)
+            embed_dim: Embedding dimension for tokens
+            depth: Number of transformer blocks
+            num_heads: Number of attention heads
+            qkv_bias: Enable bias for query, key, value projections
+            qkv_fused: Use a single projection for query, key, value
+            mlp_ratio: Ratio of mlp hidden dim to embedding dim
+            swiglu_mlp: Use SwiGLU activation in MLP
+            scale_mlp: Apply scaling normalization in MLP (normformer style)
+            scale_attn_inner: Apply scaling normalization inside attention
+            attn_type: Type of attention module to use
+            drop_rate: Dropout rate after final projection and pooling
+            pos_drop_rate: Dropout rate for positional embeddings
+            patch_drop_rate: Rate of dropping patches during training
+            proj_drop_rate: Dropout rate for projections
+            attn_drop_rate: Dropout rate for attention
+            drop_path_rate: Stochastic depth rate
+            norm_layer: Normalization layer constructor
+            init_values: Initial layer-scale values
+            class_token: Use class token
+            num_reg_tokens: Number of additional learnable 'register' tokens to add to the sequence
+            use_abs_pos_emb: Use absolute (learned) positional embeddings
+            use_rot_pos_emb: Use rotary position embeddings
+            rope_grid_offset: Offset for rotary position embedding grid
+            rope_grid_indexing: Indexing mode for rotary position embeddings ('ij' or 'xy')
+            use_post_norm: Use post-norm transformer block type
+            use_pre_transformer_norm: Use normalization layer before transformer blocks
+            use_post_transformer_norm: Use normalization layer after transformer blocks
+            use_fc_norm: Use normalization layer after pooling, before final classifier
+            attn_pool_num_heads: Number of heads in attention pooling
+            attn_pool_mlp_ratio: MLP ratio in attention pooling
+            dynamic_img_size: Support dynamic image sizes in forward pass
+            dynamic_img_pad: Apply dynamic padding for irregular image sizes
+            ref_feat_shape: Reference feature shape for rotary position embedding scale
+            head_init_scale: Initialization scale for classification head weights
        """
         super().__init__()
+        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
@@ -430,6 +485,17 @@ class Eva(nn.Module):
         self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False

+        # resolve norm / pool usage
+        activate_pre_norm = use_pre_transformer_norm
+        if use_fc_norm is not None:
+            activate_fc_norm = use_fc_norm  # pass through if explicit
+        else:
+            activate_fc_norm = global_pool == 'avg'  # default on if avg pool used
+        if use_post_transformer_norm is not None:
+            activate_post_norm = use_post_transformer_norm  # pass through if explicit
+        else:
+            activate_post_norm = not activate_fc_norm  # default on if fc_norm isn't active
+
         embed_args = {}
         if dynamic_img_size:
             # flatten deferred until after pos embed
@@ -440,6 +506,7 @@ class Eva(nn.Module):
             in_chans=in_chans,
             embed_dim=embed_dim,
             dynamic_img_pad=dynamic_img_pad,
+            bias=not use_pre_transformer_norm,
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches
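The three activate_* flags above decide where normalization ends up (before the blocks, after the blocks, or after pooling). Restated as a tiny standalone helper with a few example configurations; the helper name and the style labels in the comments are mine, not part of the PR:

def resolve_eva_norms(global_pool, use_pre_transformer_norm=False,
                      use_post_transformer_norm=None, use_fc_norm=None):
    # mirrors the resolution logic in Eva.__init__ above
    pre_norm = use_pre_transformer_norm
    fc_norm = use_fc_norm if use_fc_norm is not None else (global_pool == 'avg')
    post_norm = use_post_transformer_norm if use_post_transformer_norm is not None else not fc_norm
    return pre_norm, post_norm, fc_norm

print(resolve_eva_norms('avg'))                                  # (False, False, True) - prior EVA default behaviour
print(resolve_eva_norms('map', use_pre_transformer_norm=True))   # (True, True, False)  - PE 'core' style cfgs below
print(resolve_eva_norms('token', use_pre_transformer_norm=True,
                        use_post_transformer_norm=False, use_fc_norm=False))  # (True, False, False) - PE 'lang' style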
@@ -468,10 +535,14 @@ class Eva(nn.Module):
                 in_pixels=False,
                 feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
                 ref_feat_shape=ref_feat_shape,
+                grid_offset=rope_grid_offset,
+                grid_indexing=rope_grid_indexing,
             )
         else:
             self.rope = None

+        self.norm_pre = norm_layer(embed_dim) if activate_pre_norm else nn.Identity()
+
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock
         self.blocks = nn.ModuleList([
@@ -484,6 +555,7 @@ class Eva(nn.Module):
                 swiglu_mlp=swiglu_mlp,
                 scale_mlp=scale_mlp,
                 scale_attn_inner=scale_attn_inner,
+                attn_type=attn_type,
                 num_prefix_tokens=self.num_prefix_tokens,
                 proj_drop=proj_drop_rate,
                 attn_drop=attn_drop_rate,
@@ -495,9 +567,21 @@ class Eva(nn.Module):
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]

-        use_fc_norm = self.global_pool == 'avg'
-        self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
-        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+        self.norm = norm_layer(embed_dim) if activate_post_norm else nn.Identity()
+
+        if global_pool == 'map':
+            attn_pool_num_heads = attn_pool_num_heads or num_heads
+            attn_pool_mlp_ratio = attn_pool_mlp_ratio or mlp_ratio
+            self.attn_pool = AttentionPoolLatent(
+                self.embed_dim,
+                num_heads=attn_pool_num_heads,
+                mlp_ratio=attn_pool_mlp_ratio,
+                norm_layer=norm_layer,
+                act_layer=nn.GELU,
+            )
+        else:
+            self.attn_pool = None
+        self.fc_norm = norm_layer(embed_dim) if activate_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
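With global_pool='map', the PE models pool the token sequence through the same AttentionPoolLatent layer timm already uses for 'map' pooling elsewhere: a learned latent query cross-attends over all tokens. A rough usage sketch with made-up sizes, mirroring the constructor call above:

import torch
import torch.nn as nn
from timm.layers import AttentionPoolLatent

tokens = torch.randn(2, 197, 768)          # (B, N, C) output of the transformer blocks
pool = AttentionPoolLatent(
    768,                                   # in_features == embed_dim
    num_heads=8,
    mlp_ratio=4.,
    norm_layer=nn.LayerNorm,
)
pooled = pool(tokens)                      # (2, 768), one pooled embedding per image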
@@ -626,6 +710,7 @@ class Eva(nn.Module):
         B, _, height, width = x.shape
         x = self.patch_embed(x)
         x, rot_pos_embed = self._pos_embed(x)
+        x = self.norm_pre(x)
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             blocks = self.blocks
         else:
@@ -668,13 +753,23 @@ class Eva(nn.Module):
         if prune_norm:
             self.norm = nn.Identity()
         if prune_head:
+            self.attn_pool = None
             self.fc_norm = nn.Identity()
             self.reset_classifier(0, '')
         return take_indices

+    def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
+        if self.attn_pool is not None:
+            x = self.attn_pool(x)
+            return x
+        pool_type = self.global_pool if pool_type is None else pool_type
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        return x
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         x, rot_pos_embed = self._pos_embed(x)
+        x = self.norm_pre(x)
         for blk in self.blocks:
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 x = checkpoint(blk, x, rope=rot_pos_embed)
@@ -684,8 +779,7 @@ class Eva(nn.Module):
         return x

     def forward_head(self, x, pre_logits: bool = False):
-        if self.global_pool:
-            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+        x = self.pool(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
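End to end, the head path is now forward_features -> pool (attn_pool or global_pool_nlc) -> fc_norm -> head. A hedged sketch with a deliberately tiny, made-up Eva config so it runs without any checkpoint:

import torch
from timm.models.eva import Eva

model = Eva(img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3,
            num_classes=10, global_pool='avg')
model.eval()

tokens = model.forward_features(torch.randn(1, 3, 224, 224))  # (1, 197, 192) cls + patch tokens
pooled = model.pool(tokens)                                    # (1, 192) via global_pool_nlc('avg'), prefix excluded
logits = model.forward_head(tokens)                            # (1, 10) = pool -> fc_norm -> head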
@@ -696,6 +790,67 @@ class Eva(nn.Module):
         return x


+def _convert_pe(
+        state_dict,
+        model,
+        prefix: str = 'visual.',
+):
+    """ Convert Perception Encoder weights """
+    state_dict = state_dict.get('model', state_dict)
+    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+
+    out_dict = {}
+    swaps = [
+        ('conv1', 'patch_embed.proj'),
+        ('positional_embedding', 'pos_embed'),
+        ('transformer.resblocks.', 'blocks.'),
+        ('ln_pre', 'norm_pre'),
+        ('ln_post', 'norm'),
+        ('ln_', 'norm'),
+        ('ls_1.gamma', 'gamma_1'),
+        ('ls_2.gamma', 'gamma_2'),
+        ('in_proj_', 'qkv.'),
+        ('out_proj', 'proj'),
+        ('mlp.c_fc', 'mlp.fc1'),
+        ('mlp.c_proj', 'mlp.fc2'),
+    ]
+    len_prefix = len(prefix)
+    for k, v in state_dict.items():
+        if prefix:
+            if not k.startswith(prefix):
+                continue
+            k = k[len_prefix:]
+
+        for sp in swaps:
+            k = k.replace(sp[0], sp[1])
+
+        if k.startswith('attn_pool'):
+            k = k.replace('attn_pool.attn', 'attn_pool')
+            k = k.replace('attn_pool.layernorm', 'attn_pool.norm')
+            k = k.replace('attn_pool.probe', 'attn_pool.latent')
+            if k.startswith('attn_pool.qkv'):
+                dim = v.shape[0] // 3
+                if k.endswith('weight'):
+                    out_dict['attn_pool.q.weight'] = v[:dim]
+                    out_dict['attn_pool.kv.weight'] = v[dim:]
+                elif k.endswith('bias'):
+                    out_dict['attn_pool.q.bias'] = v[:dim]
+                    out_dict['attn_pool.kv.bias'] = v[dim:]
+                continue
+        elif k == 'proj':
+            k = 'head.weight'
+            v = v.transpose(0, 1)
+            out_dict['head.bias'] = torch.zeros(v.shape[0])
+        elif k == 'class_embedding':
+            k = 'cls_token'
+            v = v.unsqueeze(0).unsqueeze(1)
+        elif k == 'pos_embed':
+            v = v.unsqueeze(0)
+        out_dict[k] = v
+
+    return out_dict
+
+
 def checkpoint_filter_fn(
         state_dict,
         model,
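Most of _convert_pe is plain string substitution over checkpoint keys. A toy illustration of the swap table applied to a few CLIP-style key names; the keys are illustrative, not read from a real checkpoint:

swaps = [
    ('conv1', 'patch_embed.proj'),
    ('transformer.resblocks.', 'blocks.'),
    ('ln_pre', 'norm_pre'),
    ('in_proj_', 'qkv.'),
    ('out_proj', 'proj'),
    ('mlp.c_fc', 'mlp.fc1'),
]

for key in ('visual.conv1.weight',
            'visual.transformer.resblocks.0.attn.in_proj_weight',
            'visual.transformer.resblocks.0.mlp.c_fc.bias'):
    k = key[len('visual.'):]          # strip the prefix, as _convert_pe does
    for old, new in swaps:
        k = k.replace(old, new)
    print(key, '->', k)
# visual.conv1.weight -> patch_embed.proj.weight
# visual.transformer.resblocks.0.attn.in_proj_weight -> blocks.0.attn.qkv.weight
# visual.transformer.resblocks.0.mlp.c_fc.bias -> blocks.0.mlp.fc1.bias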
@@ -708,6 +863,13 @@ def checkpoint_filter_fn(
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('module', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
+
+    # Loading Meta PE (Perception Encoder) weights
+    if 'visual.conv1.weight' in state_dict:
+        return _convert_pe(state_dict, model)
+    elif 'conv1.weight' in state_dict:
+        return _convert_pe(state_dict, model, prefix='')
+
     # prefix for loading OpenCLIP compatible weights
     if 'visual.trunk.pos_embed' in state_dict:
         prefix = 'visual.trunk.'
@@ -721,10 +883,9 @@ def checkpoint_filter_fn(
     len_prefix = len(prefix)
     for k, v in state_dict.items():
         if prefix:
-            if k.startswith(prefix):
-                k = k[len_prefix:]
-            else:
+            if not k.startswith(prefix):
                 continue
+            k = k[len_prefix:]

         if 'rope' in k:
             # fixed embedding no need to load buffer from checkpoint
@@ -797,6 +958,17 @@ def _cfg(url='', **kwargs):
     }


+def _pe_cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 0, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        'license': 'custom', **kwargs
+    }
+
+
 default_cfgs = generate_default_cfgs({

     # EVA 01 CLIP fine-tuned on imagenet-1k
@@ -984,6 +1156,49 @@ default_cfgs = generate_default_cfgs({
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
+
+    # Perception Encoder weights
+    'vit_pe_core_base_patch16_224': _pe_cfg(
+        #hf_hub_id='facebook/pe_core_base_patch16_224_timm',
+        hf_hub_id='facebook/PE-Core-B16-224',
+        hf_hub_filename='PE-Core-B16-224.pt',
+        input_size=(3, 224, 224),
+        num_classes=1024,  # output proj dim
+    ),
+    'vit_pe_core_large_patch14_336': _pe_cfg(
+        hf_hub_id='facebook/PE-Core-L14-336',
+        hf_hub_filename='PE-Core-L14-336.pt',
+        input_size=(3, 336, 336),
+        num_classes=1024,  # output proj dim
+    ),
+    'vit_pe_core_gigantic_patch14_448': _pe_cfg(
+        #hf_hub_id='timm/',
+        hf_hub_id='facebook/PE-Core-G14-448',
+        hf_hub_filename='PE-Core-G14-448.pt',
+        input_size=(3, 448, 448),
+        num_classes=1280,  # output proj dim
+    ),
+    'vit_pe_lang_large_patch14_448': _pe_cfg(
+        #hf_hub_id='timm/',
+        hf_hub_id='facebook/PE-Lang-L14-448',
+        hf_hub_filename='PE-Lang-L14-448.pt',
+        input_size=(3, 448, 448),
+        num_classes=0,
+    ),
+    'vit_pe_lang_gigantic_patch14_448': _pe_cfg(
+        #hf_hub_id='timm/',
+        hf_hub_id='facebook/PE-Lang-G14-448',
+        hf_hub_filename='PE-Lang-G14-448.pt',
+        input_size=(3, 448, 448),
+        num_classes=0,
+    ),
+    'vit_pe_spatial_gigantic_patch14_448': _pe_cfg(
+        #hf_hub_id='timm/',
+        hf_hub_id='facebook/PE-Spatial-G14-448',
+        hf_hub_filename='PE-Spatial-G14-448.pt',
+        input_size=(3, 448, 448),
+        num_classes=0,
+    ),
 })
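With the configs and entrypoints in place, the PE models go through the normal timm factory. A sketch assuming a timm build containing this PR; pretrained=False is shown so nothing is downloaded (set it to True to pull the converted facebook/PE-Core-B16-224 weights via the hub id and filename above):

import torch
import timm

model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 1024]) - the CLIP-style output projection dim for this config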
@@ -1330,3 +1545,142 @@ def vit_base_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva:
     )
     model = _create_eva('vit_base_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
+
+
+@register_model
+def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
+    model_args = dict(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.0,
+        global_pool='map',
+        attn_type='rope',
+        use_pre_transformer_norm=True,
+        use_rot_pos_emb=True,
+        ref_feat_shape=(14, 14),
+        rope_grid_offset=1.,
+        rope_grid_indexing='xy',
+        attn_pool_num_heads=8,
+        attn_pool_mlp_ratio=4.,
+        norm_layer=partial(LayerNorm, eps=1e-5),
+        #dynamic_img_size=True
+    )
+    return _create_eva('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
+    model_args = dict(
+        patch_size=14,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4.0,
+        global_pool='map',
+        attn_type='rope',
+        use_pre_transformer_norm=True,
+        use_rot_pos_emb=True,
+        ref_feat_shape=(24, 24),
+        rope_grid_offset=1.,
+        rope_grid_indexing='xy',
+        attn_pool_num_heads=8,
+        attn_pool_mlp_ratio=4.,
+        norm_layer=partial(LayerNorm, eps=1e-5),
+        #dynamic_img_size=True,
+    )
+    return _create_eva('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
+    model_args = dict(
+        patch_size=14,
+        embed_dim=1536,
+        depth=50,
+        num_heads=16,
+        mlp_ratio=8960 / 1536,
+        global_pool='map',
+        attn_type='rope',
+        class_token=False,
+        use_pre_transformer_norm=True,
+        use_rot_pos_emb=True,
+        ref_feat_shape=(32, 32),
+        rope_grid_indexing='xy',
+        attn_pool_num_heads=8,
+        attn_pool_mlp_ratio=4.,
+        norm_layer=partial(LayerNorm, eps=1e-5),
+        #dynamic_img_size=True,
+    )
+    return _create_eva('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
+    model_args = dict(
+        patch_size=14,
+        embed_dim=1024,
+        depth=23,
+        num_heads=16,
+        mlp_ratio=4.0,
+        attn_type='rope',
+        class_token=True,
+        use_rot_pos_emb=True,
+        ref_feat_shape=(32, 32),
+        rope_grid_offset=1.,
+        rope_grid_indexing='xy',
+        use_pre_transformer_norm=True,
+        use_post_transformer_norm=False,
+        use_fc_norm=False,  # explicitly disable
+        init_values=0.1,
+        norm_layer=partial(LayerNorm, eps=1e-5),
+        #dynamic_img_size=True,
+    )
+    return _create_eva('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
+    model_args = dict(
+        patch_size=14,
+        embed_dim=1536,
+        depth=47,
+        num_heads=16,
+        mlp_ratio=8960 / 1536,
+        attn_type='rope',
+        class_token=False,
+        use_rot_pos_emb=True,
+        ref_feat_shape=(32, 32),
+        rope_grid_indexing='xy',
+        use_pre_transformer_norm=True,
+        use_post_transformer_norm=False,
+        use_fc_norm=False,  # explicitly disable
+        init_values=0.1,
+        norm_layer=partial(LayerNorm, eps=1e-5),
+        #dynamic_img_size=True,
+    )
+    return _create_eva('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
+    model_args = dict(
+        patch_size=14,
+        embed_dim=1536,
+        depth=50,
+        num_heads=16,
+        mlp_ratio=8960 / 1536,
+        attn_type='rope',
+        class_token=False,
+        use_rot_pos_emb=True,
+        ref_feat_shape=(32, 32),
+        rope_grid_indexing='xy',
+        use_pre_transformer_norm=True,
+        use_post_transformer_norm=False,
+        init_values=0.1,
+        norm_layer=partial(LayerNorm, eps=1e-5),
+        #dynamic_img_size=True,
+    )
+    return _create_eva('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))