Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Adding EVA02 weights and model defs, move beit based eva_giant to same eva.py file. Cleanup rotary pos, add lang oriented freq bands to be compat with eva design choice. Fix #1738
This commit is contained in: parent 56b90317cd · commit 3863d63516
@ -27,7 +27,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible,
|
||||
from .inplace_abn import InplaceAbn
|
||||
from .linear import Linear
|
||||
from .mixed_conv2d import MixedConv2d
|
||||
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
|
||||
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, ConvMlp, GlobalResponseNormMlp
|
||||
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
|
||||
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
|
||||
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
|
||||
@ -37,8 +37,8 @@ from .patch_embed import PatchEmbed, resample_patch_embed
|
||||
from .pool2d_same import AvgPool2dSame, create_pool2d
|
||||
from .pos_embed import resample_abs_pos_embed
|
||||
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
|
||||
from .pos_embed_sincos import build_sincos2d_pos_embed, build_fourier_pos_embed, build_rotary_pos_embed, \
|
||||
FourierEmbed, RotaryEmbedding
|
||||
from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
|
||||
build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
|
||||
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
|
||||
from .selective_kernel import SelectiveKernel
|
||||
from .separable_conv import SeparableConv2d, SeparableConvNormAct
|
||||
|
@ -19,6 +19,7 @@ class Mlp(nn.Module):
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=None,
|
||||
bias=True,
|
||||
drop=0.,
|
||||
use_conv=False,
|
||||
@ -33,6 +34,7 @@ class Mlp(nn.Module):
|
||||
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
||||
self.act = act_layer()
|
||||
self.drop1 = nn.Dropout(drop_probs[0])
|
||||
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
@ -55,9 +57,11 @@ class GluMlp(nn.Module):
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.Sigmoid,
|
||||
norm_layer=None,
|
||||
bias=True,
|
||||
drop=0.,
|
||||
use_conv=False,
|
||||
gate_last=True,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
@ -67,10 +71,12 @@ class GluMlp(nn.Module):
|
||||
drop_probs = to_2tuple(drop)
|
||||
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
||||
self.chunk_dim = 1 if use_conv else -1
|
||||
self.gate_last = gate_last # use second half of width for gate
|
||||
|
||||
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
||||
self.act = act_layer()
|
||||
self.drop1 = nn.Dropout(drop_probs[0])
|
||||
self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
@ -82,9 +88,57 @@ class GluMlp(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x, gates = x.chunk(2, dim=self.chunk_dim)
|
||||
x = x * self.act(gates)
|
||||
x1, x2 = x.chunk(2, dim=self.chunk_dim)
|
||||
x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2
|
||||
x = self.drop1(x)
|
||||
x = self.norm(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
""" SwiGLU
|
||||
NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
|
||||
better matches some other common impl which makes mapping checkpoints simpler.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.SiLU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
bias=True,
|
||||
drop=0.,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
bias = to_2tuple(bias)
|
||||
drop_probs = to_2tuple(drop)
|
||||
|
||||
self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0])
|
||||
self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0])
|
||||
self.act = act_layer()
|
||||
self.drop1 = nn.Dropout(drop_probs[0])
|
||||
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def init_weights(self):
|
||||
# override init of fc1 w/ gate portion set to weight near zero, bias=1
|
||||
nn.init.ones_(self.fc1_g.bias)
nn.init.normal_(self.fc1_g.weight, std=1e-6)
|
||||
|
||||
def forward(self, x):
|
||||
x_gate = self.fc1_g(x)
|
||||
x = self.fc1_x(x)
|
||||
x = self.act(x_gate) * x
|
||||
x = self.drop1(x)
|
||||
x = self.norm(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
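# Minimal sketch of the note in the SwiGLU docstring above: the split-fc1 form
# act(W_g x) * (W_x x) matches a packed GluMlp-style fc1 (gate_last=False) whose weight
# is the two branches concatenated, which is why checkpoint remapping is a simple concat.
# Plain nn.Linear layers are used here as an illustration, not the timm modules themselves.
import torch
import torch.nn as nn

dim, hidden = 8, 16
fc1_g = nn.Linear(dim, hidden)  # gate branch
fc1_x = nn.Linear(dim, hidden)  # value branch
act = nn.SiLU()

x = torch.randn(2, dim)
split_out = act(fc1_g(x)) * fc1_x(x)

packed = nn.Linear(dim, 2 * hidden)  # GluMlp-style fused fc1
packed.weight.data = torch.cat([fc1_g.weight.data, fc1_x.weight.data], 0)
packed.bias.data = torch.cat([fc1_g.bias.data, fc1_x.bias.data], 0)
x1, x2 = packed(x).chunk(2, dim=-1)
packed_out = act(x1) * x2  # gate_last=False: first chunk is the gate

assert torch.allclose(split_out, packed_out, atol=1e-6)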
@ -99,6 +153,7 @@ class GatedMlp(nn.Module):
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=None,
|
||||
gate_layer=None,
|
||||
bias=True,
|
||||
drop=0.,
|
||||
@ -118,6 +173,7 @@ class GatedMlp(nn.Module):
|
||||
hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
|
||||
else:
|
||||
self.gate = nn.Identity()
|
||||
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
@ -126,6 +182,7 @@ class GatedMlp(nn.Module):
|
||||
x = self.act(x)
|
||||
x = self.drop1(x)
|
||||
x = self.gate(x)
|
||||
x = self.norm(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
|
@ -23,15 +23,15 @@ def pixel_freq_bands(
|
||||
return bands * torch.pi
|
||||
|
||||
|
||||
def inv_freq_bands(
|
||||
def freq_bands(
|
||||
num_bands: int,
|
||||
temperature: float = 100000.,
|
||||
temperature: float = 10000.,
|
||||
step: int = 2,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
|
||||
return inv_freq
|
||||
bands = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
|
||||
return bands
|
||||
|
||||
|
||||
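# Quick illustration of the language-model style bands produced by freq_bands() above
# (temperature-scaled inverse frequencies, as in standard 1D rotary embeddings); values
# shown for a small num_bands, as a sanity check rather than library code.
import torch

num_bands = 8
bands = 1. / (10000. ** (torch.arange(0, num_bands, 1, dtype=torch.float32) / num_bands))
print(bands)  # starts at 1.0 and decays geometrically toward 10000 ** -(7/8)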
def build_sincos2d_pos_embed(
|
||||
@ -59,12 +59,12 @@ def build_sincos2d_pos_embed(
|
||||
"""
|
||||
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
|
||||
pos_dim = dim // 4
|
||||
bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
|
||||
bands = freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
|
||||
|
||||
if reverse_coord:
|
||||
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
|
||||
grid = torch.stack(
|
||||
torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
|
||||
grid = torch.stack(torch.meshgrid(
|
||||
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
|
||||
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
|
||||
# FIXME add support for unflattened spatial dim?
|
||||
|
||||
@ -78,18 +78,49 @@ def build_fourier_pos_embed(
|
||||
bands: Optional[torch.Tensor] = None,
|
||||
num_bands: int = 64,
|
||||
max_res: int = 224,
|
||||
temperature: float = 10000.,
|
||||
linear_bands: bool = False,
|
||||
include_grid: bool = False,
|
||||
concat_out: bool = True,
|
||||
in_pixels: bool = True,
|
||||
ref_feat_shape: Optional[List[int]] = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
feat_shape: Feature shape for embedding.
|
||||
bands: Pre-calculated frequency bands.
|
||||
num_bands: Number of frequency bands (determines output dim).
|
||||
max_res: Maximum resolution for pixel based freq.
|
||||
temperature: Temperature for non-pixel freq.
|
||||
linear_bands: Linear band spacing for pixel based freq.
|
||||
include_grid: Include the spatial grid in output.
|
||||
in_pixels: Output in pixel freq.
|
||||
ref_feat_shape: Reference feature shape for resize / fine-tune.
|
||||
dtype: Output dtype.
|
||||
device: Output device.
|
||||
|
||||
Returns:
    List of position embedding tensors ([sin, cos], plus the grid if include_grid).
"""
|
||||
if bands is None:
|
||||
if in_pixels:
|
||||
bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
|
||||
bands = pixel_freq_bands(
|
||||
num_bands,
|
||||
float(max_res),
|
||||
linear_bands=linear_bands,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
|
||||
bands = freq_bands(
|
||||
num_bands,
|
||||
temperature=temperature,
|
||||
step=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
if device is None:
|
||||
device = bands.device
|
||||
@ -97,31 +128,42 @@ def build_fourier_pos_embed(
|
||||
dtype = bands.dtype
|
||||
|
||||
if in_pixels:
|
||||
grid = torch.stack(torch.meshgrid(
|
||||
[torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
|
||||
t = [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]
|
||||
else:
|
||||
grid = torch.stack(torch.meshgrid(
|
||||
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
|
||||
t = [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]
|
||||
|
||||
if ref_feat_shape is not None:
|
||||
# eva's scheme for resizing rope embeddings (ref shape = pretrain)
|
||||
t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
|
||||
|
||||
grid = torch.stack(torch.meshgrid(t), dim=-1)
|
||||
grid = grid.unsqueeze(-1)
|
||||
pos = grid * bands
|
||||
|
||||
pos_sin, pos_cos = pos.sin(), pos.cos()
|
||||
out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
|
||||
# FIXME torchscript doesn't like multiple return types, probably need to always cat?
|
||||
if concat_out:
|
||||
out = torch.cat(out, dim=-1)
|
||||
out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
|
||||
return out
|
||||
|
||||
|
||||
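# Sketch of the ref_feat_shape rescaling used above (EVA's scheme): when fine-tuning at a
# larger grid, coordinates are rescaled so the rotary frequencies span the same range they
# covered at the pretrain grid size. Shapes here are illustrative only.
import torch

feat_shape = (32, 32)      # e.g. 448 // 14 at fine-tune resolution
ref_feat_shape = (16, 16)  # e.g. 224 // 14 at pretrain resolution
t = [torch.arange(s, dtype=torch.float32) for s in feat_shape]
t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
print(t[0][:4])  # tensor([0.0000, 0.5000, 1.0000, 1.5000]) -> half steps cover the pretrain range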
class FourierEmbed(nn.Module):
|
||||
|
||||
def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
|
||||
def __init__(
|
||||
self,
|
||||
max_res: int = 224,
|
||||
num_bands: int = 64,
|
||||
concat_grid=True,
|
||||
keep_spatial=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_res = max_res
|
||||
self.num_bands = num_bands
|
||||
self.concat_grid = concat_grid
|
||||
self.keep_spatial = keep_spatial
|
||||
self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False)
|
||||
self.register_buffer(
|
||||
'bands',
|
||||
pixel_freq_bands(max_res, num_bands),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, C = x.shape[:2]
|
||||
@ -131,7 +173,9 @@ class FourierEmbed(nn.Module):
|
||||
self.bands,
|
||||
include_grid=self.concat_grid,
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
device=x.device,
|
||||
)
|
||||
emb = torch.cat(emb, dim=-1)
|
||||
emb = emb.transpose(-1, -2).flatten(len(feat_shape))
|
||||
batch_expand = (B,) + (-1,) * (x.ndim - 1)
|
||||
|
||||
@ -159,38 +203,57 @@ def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
|
||||
return [t * cos_emb + rot(t) * sin_emb for t in x]
|
||||
|
||||
|
||||
def apply_rot_embed_split(x: torch.Tensor, emb):
|
||||
split = emb.shape[-1] // 2
|
||||
return x * emb[:, :split] + rot(x) * emb[:, split:]
|
||||
def apply_rot_embed_cat(x: torch.Tensor, emb):
|
||||
sin_emb, cos_emb = emb.tensor_split(2, -1)
|
||||
return x * cos_emb + rot(x) * sin_emb
|
||||
|
||||
|
||||
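# Self-contained sketch of the concatenated rotary apply above. The rot() helper here is
# written out locally as an assumption mirroring the pairwise rotate timm applies
# ([x0, x1, ...] -> [-x1, x0, ...]); it is an illustration, not the library function.
import torch

def rot(x):
    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)

def apply_rot_embed_cat_sketch(x, emb):
    sin_emb, cos_emb = emb.tensor_split(2, -1)
    return x * cos_emb + rot(x) * sin_emb

N, D = 4, 8
emb = torch.cat([torch.zeros(N, D), torch.ones(N, D)], dim=-1)  # sin=0, cos=1 -> identity rotation
x = torch.randn(2, N, D)
assert torch.allclose(apply_rot_embed_cat_sketch(x, emb), x)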
def build_rotary_pos_embed(
|
||||
feat_shape: List[int],
|
||||
bands: Optional[torch.Tensor] = None,
|
||||
dim: int = 64,
|
||||
max_freq: float = 224,
|
||||
max_res: int = 224,
|
||||
temperature: float = 10000.,
|
||||
linear_bands: bool = False,
|
||||
in_pixels: bool = True,
|
||||
ref_feat_shape: Optional[List[int]] = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
"""
|
||||
NOTE: shape arg should include spatial dim only
|
||||
"""
|
||||
feat_shape = torch.Size(feat_shape)
|
||||
|
||||
Args:
|
||||
feat_shape: Spatial shape of the target tensor for embedding.
|
||||
bands: Optional pre-generated frequency bands
|
||||
dim: Output dimension of embedding tensor.
|
||||
max_res: Maximum resolution for pixel mode.
|
||||
temperature: Temperature (inv freq) for non-pixel mode
|
||||
linear_bands: Linearly (instead of log) spaced bands for pixel mode
|
||||
in_pixels: Pixel vs language (inv freq) mode.
|
||||
dtype: Output dtype.
|
||||
device: Output device.
|
||||
|
||||
Returns:
    Tuple of (sin, cos) embedding tensors.
"""
|
||||
sin_emb, cos_emb = build_fourier_pos_embed(
|
||||
feat_shape,
|
||||
bands=bands,
|
||||
num_bands=dim // 4,
|
||||
max_res=max_freq,
|
||||
max_res=max_res,
|
||||
temperature=temperature,
|
||||
linear_bands=linear_bands,
|
||||
concat_out=False,
|
||||
in_pixels=in_pixels,
|
||||
ref_feat_shape=ref_feat_shape,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
N = feat_shape.numel()
|
||||
sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
|
||||
cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
|
||||
num_spatial_dim = 1
|
||||
# this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
|
||||
for x in feat_shape:
|
||||
num_spatial_dim *= x
|
||||
sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
|
||||
cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
|
||||
return sin_emb, cos_emb
|
||||
|
||||
|
||||
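# Hedged shape check for build_rotary_pos_embed (assuming a timm install that includes this
# change, which exports it from timm.layers): dim is the per-head dim, and each spatial
# position gets a dim-wide sin and cos embedding.
from timm.layers import build_rotary_pos_embed

sin_emb, cos_emb = build_rotary_pos_embed((16, 16), dim=64, in_pixels=False)
print(sin_emb.shape, cos_emb.shape)  # torch.Size([256, 64]) torch.Size([256, 64])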
@ -205,15 +268,164 @@ class RotaryEmbedding(nn.Module):
|
||||
* https://blog.eleuther.ai/rotary-embeddings/
|
||||
"""
|
||||
|
||||
def __init__(self, dim, max_res=224, linear_bands: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_res=224,
|
||||
temperature=10000,
|
||||
in_pixels=True,
|
||||
linear_bands: bool = False,
|
||||
feat_shape: Optional[List[int]] = None,
|
||||
ref_feat_shape: Optional[List[int]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)
|
||||
self.max_res = max_res
|
||||
self.temperature = temperature
|
||||
self.in_pixels = in_pixels
|
||||
self.feat_shape = feat_shape
|
||||
self.ref_feat_shape = ref_feat_shape
|
||||
|
||||
def get_embed(self, shape: List[int]):
|
||||
return build_rotary_pos_embed(shape, self.bands)
|
||||
if feat_shape is None:
|
||||
# only cache bands
|
||||
if in_pixels:
|
||||
bands = pixel_freq_bands(
|
||||
dim // 4,
|
||||
float(max_res),
|
||||
linear_bands=linear_bands,
|
||||
)
|
||||
else:
|
||||
bands = freq_bands(
|
||||
dim // 4,
|
||||
temperature=temperature,
|
||||
step=1,
|
||||
)
|
||||
self.register_buffer(
|
||||
'bands',
|
||||
bands,
|
||||
persistent=False,
|
||||
)
|
||||
self.pos_embed_sin = None
|
||||
self.pos_embed_cos = None
|
||||
else:
|
||||
# cache full sin/cos embeddings if shape provided up front
|
||||
emb_sin, emb_cos = build_rotary_pos_embed(
|
||||
feat_shape=feat_shape,
|
||||
dim=dim,
|
||||
max_res=max_res,
|
||||
linear_bands=linear_bands,
|
||||
in_pixels=in_pixels,
|
||||
ref_feat_shape=self.ref_feat_shape,
|
||||
)
|
||||
self.bands = None
|
||||
self.register_buffer(
|
||||
'pos_embed_sin',
|
||||
emb_sin,
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
'pos_embed_cos',
|
||||
emb_cos,
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def get_embed(self, shape: Optional[List[int]] = None):
|
||||
if self.bands is not None:
|
||||
# rebuild embeddings every call, use if target shape changes
|
||||
assert shape is not None
|
||||
return build_rotary_pos_embed(
|
||||
shape,
|
||||
self.bands,
|
||||
in_pixels=self.in_pixels,
|
||||
)
|
||||
else:
|
||||
return self.pos_embed_sin, self.pos_embed_cos
|
||||
|
||||
def forward(self, x):
|
||||
# assuming channel-first tensor where spatial dims are >= 2
|
||||
sin_emb, cos_emb = self.get_embed(x.shape[2:])
|
||||
return apply_rot_embed(x, sin_emb, cos_emb)
|
||||
|
||||
|
||||
class RotaryEmbeddingCat(nn.Module):
|
||||
""" Rotary position embedding w/ concatenatd sin & cos
|
||||
|
||||
The following impl/resources were referenced for this impl:
|
||||
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
|
||||
* https://blog.eleuther.ai/rotary-embeddings/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_res=224,
|
||||
temperature=10000,
|
||||
in_pixels=True,
|
||||
linear_bands: bool = False,
|
||||
feat_shape: Optional[List[int]] = None,
|
||||
ref_feat_shape: Optional[List[int]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_res = max_res
|
||||
self.temperature = temperature
|
||||
self.in_pixels = in_pixels
|
||||
self.feat_shape = feat_shape
|
||||
self.ref_feat_shape = ref_feat_shape
|
||||
|
||||
if feat_shape is None:
|
||||
# only cache bands
|
||||
if in_pixels:
|
||||
bands = pixel_freq_bands(
|
||||
dim // 4,
|
||||
float(max_res),
|
||||
linear_bands=linear_bands,
|
||||
)
|
||||
else:
|
||||
bands = freq_bands(
|
||||
dim // 4,
|
||||
temperature=temperature,
|
||||
step=1,
|
||||
)
|
||||
self.register_buffer(
|
||||
'bands',
|
||||
bands,
|
||||
persistent=False,
|
||||
)
|
||||
self.embed = None
|
||||
else:
|
||||
# cache full sin/cos embeddings if shape provided up front
|
||||
embeds = build_rotary_pos_embed(
|
||||
feat_shape=feat_shape,
|
||||
dim=dim,
|
||||
max_res=max_res,
|
||||
linear_bands=linear_bands,
|
||||
in_pixels=in_pixels,
|
||||
ref_feat_shape=self.ref_feat_shape,
|
||||
)
|
||||
self.bands = None
|
||||
self.register_buffer(
|
||||
'pos_embed',
|
||||
torch.cat(embeds, -1),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def get_embed(self, shape: Optional[List[int]] = None):
|
||||
if self.bands is not None:
|
||||
# rebuild embeddings every call, use if target shape changes
|
||||
assert shape is not None
|
||||
embeds = build_rotary_pos_embed(
|
||||
shape,
|
||||
self.bands,
|
||||
in_pixels=self.in_pixels,
|
||||
)
|
||||
return torch.cat(embeds, -1)
|
||||
else:
|
||||
return self.pos_embed
|
||||
|
||||
def forward(self, x):
|
||||
# assuming channel-first tensor where spatial dims are >= 2
|
||||
pos_embed = self.get_embed(x.shape[2:])
|
||||
return apply_rot_embed_cat(x, pos_embed)
|
||||
|
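# Hedged usage sketch for RotaryEmbeddingCat (assumes a timm install including this change):
# pass feat_shape to cache the full concatenated sin/cos up front, or omit it to rebuild
# per call when input sizes vary.
import torch
from timm.layers import RotaryEmbeddingCat, apply_rot_embed_cat

rope = RotaryEmbeddingCat(64, in_pixels=False, feat_shape=(16, 16))
emb = rope.get_embed()               # (256, 128): sin and cos, each head_dim wide
q = torch.randn(2, 12, 16 * 16, 64)  # B, num_heads, N, head_dim
q = apply_rot_embed_cat(q, emb)
print(q.shape)                       # torch.Size([2, 12, 256, 64])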
@ -17,6 +17,7 @@ from .edgenext import *
|
||||
from .efficientformer import *
|
||||
from .efficientformer_v2 import *
|
||||
from .efficientnet import *
|
||||
from .eva import *
|
||||
from .focalnet import *
|
||||
from .gcvit import *
|
||||
from .ghostnet import *
|
||||
|
@ -21,17 +21,6 @@ archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
|
||||
EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
|
||||
|
||||
@article{EVA,
|
||||
title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
|
||||
author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
|
||||
Tiejun and Wang, Xinlong and Cao, Yue},
|
||||
journal={arXiv preprint arXiv:2211.07636},
|
||||
year={2022}
|
||||
}
|
||||
|
||||
|
||||
At this point only the 1k fine-tuned classification weights and model configs have been added,
|
||||
see original source above for pre-training models and procedure.
|
||||
|
||||
@ -49,19 +38,18 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
|
||||
# https://github.com/facebookresearch/dino
|
||||
# --------------------------------------------------------'
|
||||
|
||||
# EVA models Copyright (c) 2022 BAAI-Vision
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_
|
||||
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
from .vision_transformer import checkpoint_filter_fn
|
||||
@ -93,8 +81,15 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self, dim, num_heads=8, qkv_bias=False, attn_drop=0.,
|
||||
proj_drop=0., window_size=None, attn_head_dim=None):
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
window_size: Optional[Tuple[int, int]] = None,
|
||||
attn_head_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
@ -102,6 +97,7 @@ class Attention(nn.Module):
|
||||
head_dim = attn_head_dim
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
||||
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
@ -142,20 +138,37 @@ class Attention(nn.Module):
|
||||
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
if self.fast_attn:
|
||||
if self.relative_position_bias_table is not None:
|
||||
rel_pos_bias = self._get_rel_pos_bias()
|
||||
if shared_rel_pos_bias is not None:
|
||||
rel_pos_bias = rel_pos_bias + shared_rel_pos_bias
|
||||
elif shared_rel_pos_bias is not None:
|
||||
rel_pos_bias = shared_rel_pos_bias
|
||||
else:
|
||||
rel_pos_bias = None
|
||||
|
||||
if self.relative_position_bias_table is not None:
|
||||
attn = attn + self._get_rel_pos_bias()
|
||||
if shared_rel_pos_bias is not None:
|
||||
attn = attn + shared_rel_pos_bias
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=rel_pos_bias,
|
||||
dropout_p=self.attn_drop.p,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
if self.relative_position_bias_table is not None:
|
||||
attn = attn + self._get_rel_pos_bias()
|
||||
if shared_rel_pos_bias is not None:
|
||||
attn = attn + shared_rel_pos_bias
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
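# Minimal check of the fast_attn path above: passing the relative position bias as a float
# attn_mask to F.scaled_dot_product_attention matches the manual scale/softmax path.
# Shapes are illustrative only.
import torch
import torch.nn.functional as F

B, H, N, D = 2, 4, 16, 8
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))
bias = torch.randn(H, N, N)  # broadcasts over batch

fast = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
attn = (q * D ** -0.5) @ k.transpose(-2, -1) + bias
manual = attn.softmax(dim=-1) @ v
assert torch.allclose(fast, manual, atol=1e-5)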
@ -164,19 +177,53 @@ class Attention(nn.Module):
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
||||
window_size=None, attn_head_dim=None):
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
qkv_bias: bool = False,
|
||||
mlp_ratio: float = 4.,
|
||||
scale_mlp: bool = False,
|
||||
swiglu_mlp: bool = False,
|
||||
proj_drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
init_values: Optional[float] = None,
|
||||
act_layer: Callable = nn.GELU,
|
||||
norm_layer: Callable = LayerNorm,
|
||||
window_size: Optional[Tuple[int, int]] = None,
|
||||
attn_head_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
|
||||
window_size=window_size, attn_head_dim=attn_head_dim)
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
window_size=window_size,
|
||||
attn_head_dim=attn_head_dim,
|
||||
)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
if swiglu_mlp:
|
||||
self.mlp = SwiGLU(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
norm_layer=norm_layer if scale_mlp else None,
|
||||
drop=proj_drop,
|
||||
)
|
||||
else:
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer if scale_mlp else None,
|
||||
drop=proj_drop,
|
||||
)
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
if init_values:
|
||||
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
|
||||
@ -186,11 +233,11 @@ class Block(nn.Module):
|
||||
|
||||
def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
|
||||
if self.gamma_1 is None:
|
||||
x = x + self.drop_path(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
x = x + self.drop_path1(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
|
||||
x = x + self.drop_path2(self.mlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
|
||||
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||
x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
|
||||
x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
@ -216,19 +263,42 @@ class Beit(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
|
||||
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.,
|
||||
attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
init_values=None, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
|
||||
head_init_scale=0.001):
|
||||
self,
|
||||
img_size: Union[int, Tuple[int, int]] = 224,
|
||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
global_pool: str = 'avg',
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
qkv_bias: bool = True,
|
||||
mlp_ratio: float = 4.,
|
||||
swiglu_mlp: bool = False,
|
||||
scale_mlp: bool = False,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
norm_layer: Callable = LayerNorm,
|
||||
init_values: Optional[float] = None,
|
||||
use_abs_pos_emb: bool = True,
|
||||
use_rel_pos_bias: bool = False,
|
||||
use_shared_rel_pos_bias: bool = False,
|
||||
head_init_scale: float = 0.001,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_prefix_tokens = 1
|
||||
self.grad_checkpointing = False
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
@ -237,27 +307,41 @@ class Beit(nn.Module):
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
if use_shared_rel_pos_bias:
|
||||
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.grid_size, num_heads=num_heads)
|
||||
self.rel_pos_bias = RelativePositionBias(
|
||||
window_size=self.patch_embed.grid_size,
|
||||
num_heads=num_heads,
|
||||
)
|
||||
else:
|
||||
self.rel_pos_bias = None
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None)
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
mlp_ratio=mlp_ratio,
|
||||
scale_mlp=scale_mlp,
|
||||
swiglu_mlp=swiglu_mlp,
|
||||
proj_drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
init_values=init_values,
|
||||
window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
|
||||
)
|
||||
for i in range(depth)])
|
||||
|
||||
use_fc_norm = self.global_pool == 'avg'
|
||||
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
|
||||
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else None
|
||||
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
|
||||
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.apply(self._init_weights)
|
||||
if self.pos_embed is not None:
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
# trunc_normal_(self.mask_token, std=.02)
|
||||
|
||||
self.fix_init_weight()
|
||||
if isinstance(self.head, nn.Linear):
|
||||
trunc_normal_(self.head.weight, std=.02)
|
||||
@ -328,11 +412,9 @@ class Beit(nn.Module):
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.fc_norm is not None:
|
||||
x = x[:, 1:].mean(dim=1)
|
||||
x = self.fc_norm(x)
|
||||
else:
|
||||
x = x[:, 0]
|
||||
if self.global_pool:
|
||||
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||
x = self.fc_norm(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
@ -405,27 +487,6 @@ default_cfgs = generate_default_cfgs({
|
||||
hf_hub_id='timm/',
|
||||
num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
|
||||
),
|
||||
|
||||
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
|
||||
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
|
||||
),
|
||||
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
|
||||
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
|
||||
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
|
||||
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
|
||||
hf_hub_id='timm/',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
||||
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
|
||||
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
|
||||
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
|
||||
hf_hub_id='timm/',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
||||
input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
|
||||
})
|
||||
|
||||
|
||||
@ -509,30 +570,3 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs):
|
||||
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
|
||||
model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva_giant_patch14_224(pretrained=False, **kwargs):
|
||||
""" EVA-g model https://arxiv.org/abs/2211.07636 """
|
||||
model_kwargs = dict(
|
||||
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
|
||||
model = _create_beit('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva_giant_patch14_336(pretrained=False, **kwargs):
|
||||
""" EVA-g model https://arxiv.org/abs/2211.07636 """
|
||||
model_kwargs = dict(
|
||||
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
|
||||
model = _create_beit('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva_giant_patch14_560(pretrained=False, **kwargs):
|
||||
""" EVA-g model https://arxiv.org/abs/2211.07636 """
|
||||
model_kwargs = dict(
|
||||
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
|
||||
model = _create_beit('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
timm/models/eva.py (new file, 730 lines)
@ -0,0 +1,730 @@
|
||||
""" EVA
|
||||
|
||||
EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
|
||||
|
||||
@article{EVA,
|
||||
title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
|
||||
author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
|
||||
Tiejun and Wang, Xinlong and Cao, Yue},
|
||||
journal={arXiv preprint arXiv:2211.07636},
|
||||
year={2022}
|
||||
}
|
||||
|
||||
EVA-02: A Visual Representation for Neon Genesis - https://arxiv.org/abs/2303.11331
|
||||
@article{EVA02,
|
||||
title={EVA-02: A Visual Representation for Neon Genesis},
|
||||
author={Fang, Yuxin and Sun, Quan and Wang, Xinggang and Huang, Tiejun and Wang, Xinlong and Cao, Yue},
|
||||
journal={arXiv preprint arXiv:2303.11331},
|
||||
year={2023}
|
||||
}
|
||||
|
||||
This file contains EVA & EVA02 model implementations evolved from BEiT, additional models in vision_transformer.py.
|
||||
|
||||
Modifications by / Copyright 2023 Ross Wightman, original copyrights below
|
||||
"""
|
||||
# EVA models Copyright (c) 2022 BAAI-Vision
|
||||
# EVA02 models Copyright (c) 2023 BAAI-Vision
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, RotaryEmbeddingCat, \
|
||||
apply_rot_embed_cat, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, to_2tuple
|
||||
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['Eva']
|
||||
|
||||
|
||||
class EvaAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = True,
|
||||
qkv_fused: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
attn_head_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
if attn_head_dim is not None:
|
||||
head_dim = attn_head_dim
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
||||
|
||||
if qkv_fused:
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
self.q_proj = self.k_proj = self.v_proj = None
|
||||
if qkv_bias:
|
||||
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
|
||||
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
else:
|
||||
self.q_bias = self.k_bias = self.v_bias = None
|
||||
else:
|
||||
self.q_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
|
||||
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
|
||||
self.qkv = None
|
||||
self.q_bias = self.k_bias = self.v_bias = None
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(all_head_dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
rope: Optional[torch.Tensor] = None,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
B, N, C = x.shape
|
||||
|
||||
if self.qkv is not None:
|
||||
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
|
||||
else:
|
||||
q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C
|
||||
k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
|
||||
v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
|
||||
|
||||
if rope is not None:
|
||||
q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], 2).type_as(v)
|
||||
k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], 2).type_as(v)
|
||||
|
||||
if self.fast_attn:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.attn_drop.p,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
attn = attn.softmax(dim=-1)
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(torch.bool)
|
||||
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class EvaBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
qkv_bias: bool = True,
|
||||
qkv_fused: bool = True,
|
||||
mlp_ratio: float = 4.,
|
||||
scale_mlp: bool = False,
|
||||
swiglu_mlp: bool = False,
|
||||
proj_drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
init_values: Optional[float] = None,
|
||||
act_layer: Callable = nn.GELU,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
attn_head_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = EvaAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qkv_fused=qkv_fused,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
attn_head_dim=attn_head_dim,
|
||||
)
|
||||
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
hidden_features = int(dim * mlp_ratio)
|
||||
if swiglu_mlp:
|
||||
if scale_mlp:
|
||||
# when norm in SwiGLU used, an impl with separate fc for gate & x is used
|
||||
self.mlp = SwiGLU(
|
||||
in_features=dim,
|
||||
hidden_features=hidden_features,
|
||||
norm_layer=norm_layer if scale_mlp else None,
|
||||
drop=proj_drop,
|
||||
)
|
||||
else:
|
||||
# w/o any extra norm, an impl with packed weights is used, matches existing GluMLP
|
||||
self.mlp = GluMlp(
|
||||
in_features=dim,
|
||||
hidden_features=hidden_features * 2,
|
||||
norm_layer=norm_layer if scale_mlp else None,
|
||||
act_layer=nn.SiLU,
|
||||
gate_last=False,
|
||||
drop=proj_drop,
|
||||
)
|
||||
else:
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=hidden_features,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer if scale_mlp else None,
|
||||
drop=proj_drop,
|
||||
)
|
||||
self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
|
||||
if self.gamma_1 is None:
|
||||
x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
|
||||
x = x + self.drop_path2(self.mlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
|
||||
x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Eva(nn.Module):
|
||||
""" Eva Vision Transformer w/ Abs & Rotary Pos Embed
|
||||
|
||||
This class implements the EVA and EVA02 models that were based on the BEiT ViT variant
|
||||
* EVA - abs pos embed, global avg pool
|
||||
* EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale Norm in MLP (ala normformer)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: Union[int, Tuple[int, int]] = 224,
|
||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
global_pool: str = 'avg',
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
qkv_bias: bool = True,
|
||||
qkv_fused: bool = True,
|
||||
mlp_ratio: float = 4.,
|
||||
swiglu_mlp: bool = False,
|
||||
scale_mlp: bool = False,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
norm_layer: Callable = LayerNorm,
|
||||
init_values: Optional[float] = None,
|
||||
use_abs_pos_emb: bool = True,
|
||||
use_rot_pos_emb: bool = False,
|
||||
ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
|
||||
head_init_scale: float = 0.001,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_prefix_tokens = 1
|
||||
self.grad_checkpointing = False
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
if use_rot_pos_emb:
|
||||
ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
|
||||
self.rope = RotaryEmbeddingCat(
|
||||
embed_dim // num_heads,
|
||||
in_pixels=False,
|
||||
feat_shape=self.patch_embed.grid_size,
|
||||
ref_feat_shape=ref_feat_shape,
|
||||
)
|
||||
else:
|
||||
self.rope = None
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.ModuleList([
|
||||
EvaBlock(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qkv_fused=qkv_fused,
|
||||
mlp_ratio=mlp_ratio,
|
||||
scale_mlp=scale_mlp,
|
||||
swiglu_mlp=swiglu_mlp,
|
||||
proj_drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
init_values=init_values,
|
||||
)
|
||||
for i in range(depth)])
|
||||
|
||||
use_fc_norm = self.global_pool == 'avg'
|
||||
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
|
||||
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
|
||||
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.apply(self._init_weights)
|
||||
if self.pos_embed is not None:
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
|
||||
self.fix_init_weight()
|
||||
if isinstance(self.head, nn.Linear):
|
||||
trunc_normal_(self.head.weight, std=.02)
|
||||
self.head.weight.data.mul_(head_init_scale)
|
||||
self.head.bias.data.mul_(head_init_scale)
|
||||
|
||||
def fix_init_weight(self):
|
||||
def rescale(param, layer_id):
|
||||
param.div_(math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.blocks):
|
||||
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
nwd = {'pos_embed', 'cls_token'}
|
||||
return nwd
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
matcher = dict(
|
||||
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
|
||||
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
|
||||
)
|
||||
return matcher
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
|
||||
|
||||
for blk in self.blocks:
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint(blk, x, rope=rot_pos_embed)
|
||||
else:
|
||||
x = blk(x, rope=rot_pos_embed)
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool:
|
||||
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||
x = self.fc_norm(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
||||
|
||||
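# Hedged usage sketch (assumes a timm build that includes the eva02 model defs added below):
import timm
import torch

model = timm.create_model('eva02_tiny_patch14_224', pretrained=False, num_classes=10)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 10])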
def checkpoint_filter_fn(
|
||||
state_dict,
|
||||
model,
|
||||
interpolation='bicubic',
|
||||
antialias=True,
|
||||
):
|
||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||
out_dict = {}
|
||||
state_dict = state_dict.get('model_ema', state_dict)
|
||||
state_dict = state_dict.get('model', state_dict)
|
||||
state_dict = state_dict.get('module', state_dict)
|
||||
state_dict = state_dict.get('state_dict', state_dict)
|
||||
no_qkv = 'blocks.0.attn.q_proj.weight' in state_dict
|
||||
mim_weights = 'mask_token' in state_dict
|
||||
|
||||
for k, v in state_dict.items():
|
||||
if 'rope' in k:
|
||||
# fixed embedding no need to load buffer from checkpoint
|
||||
continue
|
||||
|
||||
if 'patch_embed.proj.weight' in k:
|
||||
_, _, H, W = model.patch_embed.proj.weight.shape
|
||||
if v.shape[-1] != W or v.shape[-2] != H:
|
||||
v = resample_patch_embed(
|
||||
v,
|
||||
(H, W),
|
||||
interpolation=interpolation,
|
||||
antialias=antialias,
|
||||
verbose=True,
|
||||
)
|
||||
elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
|
||||
# To resize pos embedding when using model at different size from pretrained weights
|
||||
num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
|
||||
v = resample_abs_pos_embed(
|
||||
v,
|
||||
new_size=model.patch_embed.grid_size,
|
||||
num_prefix_tokens=num_prefix_tokens,
|
||||
interpolation=interpolation,
|
||||
antialias=antialias,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
k = k.replace('mlp.ffn_ln', 'mlp.norm')
|
||||
k = k.replace('mlp.w12', 'mlp.fc1')
|
||||
k = k.replace('mlp.w1', 'mlp.fc1_g')
|
||||
k = k.replace('mlp.w2', 'mlp.fc1_x')
|
||||
k = k.replace('mlp.w3', 'mlp.fc2')
|
||||
if no_qkv:
|
||||
k = k.replace('q_bias', 'q_proj.bias')
|
||||
k = k.replace('v_bias', 'v_proj.bias')
|
||||
|
||||
if mim_weights and k in ('mask_token', 'lm_head.weight', 'lm_head.bias', 'norm.weight', 'norm.bias'):
|
||||
if k == 'norm.weight' or k == 'norm.bias':
|
||||
# try moving norm -> fc norm on fine-tune, probably a better starting point than new init
|
||||
k = k.replace('norm', 'fc_norm')
|
||||
else:
|
||||
# skip pretrain mask token & head weights
|
||||
continue
|
||||
|
||||
out_dict[k] = v
|
||||
|
||||
return out_dict
|
||||
|
||||
|
||||
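# Toy illustration of the MLP key remapping in checkpoint_filter_fn above (not a real
# checkpoint): EVA02's w1/w2/w3/ffn_ln names map onto timm's SwiGLU naming. Note that
# 'mlp.w12' must be handled before 'mlp.w1'.
ckpt_keys = ['blocks.0.mlp.w1.weight', 'blocks.0.mlp.w2.weight', 'blocks.0.mlp.w3.weight', 'blocks.0.mlp.ffn_ln.weight']
remapped = []
for k in ckpt_keys:
    k = k.replace('mlp.ffn_ln', 'mlp.norm')
    k = k.replace('mlp.w12', 'mlp.fc1')
    k = k.replace('mlp.w1', 'mlp.fc1_g')
    k = k.replace('mlp.w2', 'mlp.fc1_x')
    k = k.replace('mlp.w3', 'mlp.fc2')
    remapped.append(k)
print(remapped)
# ['blocks.0.mlp.fc1_g.weight', 'blocks.0.mlp.fc1_x.weight', 'blocks.0.mlp.fc2.weight', 'blocks.0.mlp.norm.weight']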
def _create_eva(variant, pretrained=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Eva models.')
|
||||
|
||||
model = build_model_with_cfg(
|
||||
Eva, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
|
||||
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
|
||||
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
|
||||
hf_hub_id='timm/',
|
||||
),
|
||||
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
|
||||
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
|
||||
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
|
||||
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
|
||||
hf_hub_id='timm/',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
||||
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
|
||||
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
|
||||
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
|
||||
hf_hub_id='timm/',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
||||
input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
|
||||
|
||||
'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02',
|
||||
hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_B_pt_in21k_medft_in21k_ft_in1k_p14.pt',
|
||||
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
|
||||
),
|
||||
'eva02_large_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02',
|
||||
hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_in21k_medft_in21k_ft_in1k_p14.pt',
|
||||
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
|
||||
),
|
||||
'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02',
|
||||
hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_m38m_medft_in21k_ft_in1k_p14.pt',
|
||||
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
|
||||
),
|
||||
|
||||
'eva02_tiny_patch14_336.mim_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02',
|
||||
hf_hub_filename='eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt',
|
||||
input_size=(3, 336, 336), crop_pct=1.0,
|
||||
),
|
||||
'eva02_small_patch14_336.mim_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02',
|
||||
hf_hub_filename='eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt',
|
||||
input_size=(3, 336, 336), crop_pct=1.0,
|
||||
),
|
||||
'eva02_base_patch14_448.mim_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02',
|
||||
hf_hub_filename='eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt',
|
||||
input_size=(3, 448, 448), crop_pct=1.0,
|
||||
),
|
||||
'eva02_large_patch14_448.mim_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02',
|
||||
hf_hub_filename='eva02/cls/in1k/eva02_L_pt_in21k_ft_in1k_p14.pt',
|
||||
input_size=(3, 448, 448), crop_pct=1.0,
|
||||
),
|
||||
'eva02_large_patch14_448.mim_m38m_ft_in1k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02',
|
||||
hf_hub_filename='eva02/cls/in1k/eva02_L_pt_m38m_ft_in1k_p14.pt',
|
||||
input_size=(3, 448, 448), crop_pct=1.0,
|
||||
),
|
||||
|
||||
'eva02_base_patch14_448.mim_in22k_ft_in22k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt',
|
||||
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
|
||||
),
|
||||
'eva02_large_patch14_448.mim_in22k_ft_in22k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt',
|
||||
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
|
||||
),
|
||||
'eva02_large_patch14_448.mim_m38m_ft_in22k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt',
|
||||
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
|
||||
),
|
||||
|
||||
'eva02_tiny_patch14_224.mim_in22k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_Ti_pt_in21k_p14.pt',
|
||||
num_classes=0,
|
||||
),
|
||||
'eva02_small_patch14_224.mim_in22k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_S_pt_in21k_p14.pt',
|
||||
num_classes=0,
|
||||
),
|
||||
'eva02_base_patch14_224.mim_in22k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_B_pt_in21k_p14.pt',
|
||||
num_classes=0,
|
||||
),
|
||||
'eva02_large_patch14_224.mim_in22k': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_in21k_p14.pt',
|
||||
num_classes=0,
|
||||
),
|
||||
'eva02_large_patch14_224.mim_m38m': _cfg(
|
||||
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_m38m_p14.pt',
|
||||
num_classes=0,
|
||||
),
|
||||
|
||||
})
|
||||
|
||||
|
||||
@register_model
|
||||
def eva_giant_patch14_224(pretrained=False, **kwargs):
|
||||
""" EVA-g model https://arxiv.org/abs/2211.07636 """
|
||||
model_kwargs = dict(
|
||||
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
|
||||
model = _create_eva('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva_giant_patch14_336(pretrained=False, **kwargs):
|
||||
""" EVA-g model https://arxiv.org/abs/2211.07636 """
|
||||
model_kwargs = dict(
|
||||
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
|
||||
model = _create_eva('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva_giant_patch14_560(pretrained=False, **kwargs):
|
||||
""" EVA-g model https://arxiv.org/abs/2211.07636 """
|
||||
model_kwargs = dict(
|
||||
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
|
||||
model = _create_eva('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva02_tiny_patch14_224(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
embed_dim=192,
|
||||
depth=12,
|
||||
num_heads=3,
|
||||
mlp_ratio=4 * 2 / 3,
|
||||
swiglu_mlp=True,
|
||||
use_rot_pos_emb=True,
|
||||
ref_feat_shape=(16, 16), # 224/14
|
||||
)
|
||||
model = _create_eva('eva02_tiny_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva02_small_patch14_224(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
embed_dim=384,
|
||||
depth=12,
|
||||
num_heads=6,
|
||||
mlp_ratio=4 * 2 / 3,
|
||||
swiglu_mlp=True,
|
||||
use_rot_pos_emb=True,
|
||||
ref_feat_shape=(16, 16), # 224/14
|
||||
)
|
||||
model = _create_eva('eva02_small_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva02_base_patch14_224(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
qkv_fused=False,
|
||||
mlp_ratio=4 * 2 / 3,
|
||||
scale_mlp=True,
|
||||
swiglu_mlp=True,
|
||||
use_rot_pos_emb=True,
|
||||
ref_feat_shape=(16, 16), # 224/14
|
||||
)
|
||||
model = _create_eva('eva02_base_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva02_large_patch14_224(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
embed_dim=1024,
|
||||
depth=24,
|
||||
num_heads=16,
|
||||
mlp_ratio=4 * 2 / 3,
|
||||
qkv_fused=False,
|
||||
scale_mlp=True,
|
||||
swiglu_mlp=True,
|
||||
use_rot_pos_emb=True,
|
||||
ref_feat_shape=(16, 16), # 224/14
|
||||
)
|
||||
model = _create_eva('eva02_large_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva02_tiny_patch14_336(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
img_size=336,
|
||||
patch_size=14,
|
||||
embed_dim=192,
|
||||
depth=12,
|
||||
num_heads=3,
|
||||
mlp_ratio=4 * 2 / 3,
|
||||
swiglu_mlp=True,
|
||||
use_rot_pos_emb=True,
|
||||
ref_feat_shape=(16, 16), # 224/14
|
||||
)
|
||||
model = _create_eva('eva02_tiny_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva02_small_patch14_336(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
img_size=336,
|
||||
patch_size=14,
|
||||
embed_dim=384,
|
||||
depth=12,
|
||||
num_heads=6,
|
||||
mlp_ratio=4 * 2 / 3,
|
||||
swiglu_mlp=True,
|
||||
use_rot_pos_emb=True,
|
||||
ref_feat_shape=(16, 16), # 224/14
|
||||
)
|
||||
model = _create_eva('eva02_small_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva02_base_patch14_448(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
img_size=448,
|
||||
patch_size=14,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
qkv_fused=False,
|
||||
mlp_ratio=4 * 2 / 3,
|
||||
scale_mlp=True,
|
||||
swiglu_mlp=True,
|
||||
use_rot_pos_emb=True,
|
||||
ref_feat_shape=(16, 16), # 224/14
|
||||
)
|
||||
model = _create_eva('eva02_base_patch14_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eva02_large_patch14_448(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
img_size=448,
|
||||
patch_size=14,
|
||||
embed_dim=1024,
|
||||
depth=24,
|
||||
num_heads=16,
|
||||
mlp_ratio=4 * 2 / 3,
|
||||
qkv_fused=False,
|
||||
scale_mlp=True,
|
||||
swiglu_mlp=True,
|
||||
use_rot_pos_emb=True,
|
||||
ref_feat_shape=(16, 16), # 224/14
|
||||
)
|
||||
model = _create_eva('eva02_large_patch14_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
||||
return model
|