Adding EVA02 weights and model defs, move beit based eva_giant to same eva.py file. Cleanup rotary pos, add lang oriented freq bands to be compat with eva design choice. Fix #1738

This commit is contained in:
Ross Wightman 2023-03-27 17:14:58 -07:00
parent 56b90317cd
commit 3863d63516
6 changed files with 1181 additions and 147 deletions

View File

@ -27,7 +27,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible,
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
@ -37,8 +37,8 @@ from .patch_embed import PatchEmbed, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
from .pos_embed_sincos import build_sincos2d_pos_embed, build_fourier_pos_embed, build_rotary_pos_embed, \
FourierEmbed, RotaryEmbedding
from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct

View File

@ -19,6 +19,7 @@ class Mlp(nn.Module):
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.,
use_conv=False,
@ -33,6 +34,7 @@ class Mlp(nn.Module):
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
@ -55,9 +57,11 @@ class GluMlp(nn.Module):
hidden_features=None,
out_features=None,
act_layer=nn.Sigmoid,
norm_layer=None,
bias=True,
drop=0.,
use_conv=False,
gate_last=True,
):
super().__init__()
out_features = out_features or in_features
@ -67,10 +71,12 @@ class GluMlp(nn.Module):
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.chunk_dim = 1 if use_conv else -1
self.gate_last = gate_last # use second half of width for gate
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
@ -82,9 +88,57 @@ class GluMlp(nn.Module):
def forward(self, x):
x = self.fc1(x)
x, gates = x.chunk(2, dim=self.chunk_dim)
x = x * self.act(gates)
x1, x2 = x.chunk(2, dim=self.chunk_dim)
x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class SwiGLU(nn.Module):
""" SwiGLU
NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
better matches some other common impl which makes mapping checkpoints simpler.
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.SiLU,
norm_layer=nn.LayerNorm,
bias=True,
drop=0.,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0])
self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
self.drop = nn.Dropout(drop)
def init_weights(self):
# override init of fc1 w/ gate portion set to weight near zero, bias=1
nn.init.ones_(self.fc1a.bias)
nn.init.normal_(self.fc1a.weight, std=1e-6)
def forward(self, x):
x_gate = self.fc1_g(x)
x = self.fc1_x(x)
x = self.act(x_gate) * x
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
@ -99,6 +153,7 @@ class GatedMlp(nn.Module):
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
gate_layer=None,
bias=True,
drop=0.,
@ -118,6 +173,7 @@ class GatedMlp(nn.Module):
hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
else:
self.gate = nn.Identity()
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
@ -126,6 +182,7 @@ class GatedMlp(nn.Module):
x = self.act(x)
x = self.drop1(x)
x = self.gate(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x

View File

@ -23,15 +23,15 @@ def pixel_freq_bands(
return bands * torch.pi
def inv_freq_bands(
def freq_bands(
num_bands: int,
temperature: float = 100000.,
temperature: float = 10000.,
step: int = 2,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
return inv_freq
bands = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
return bands
def build_sincos2d_pos_embed(
@ -59,12 +59,12 @@ def build_sincos2d_pos_embed(
"""
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
pos_dim = dim // 4
bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
bands = freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
if reverse_coord:
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
grid = torch.stack(
torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
grid = torch.stack(torch.meshgrid(
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
# FIXME add support for unflattened spatial dim?
@ -78,18 +78,49 @@ def build_fourier_pos_embed(
bands: Optional[torch.Tensor] = None,
num_bands: int = 64,
max_res: int = 224,
temperature: float = 10000.,
linear_bands: bool = False,
include_grid: bool = False,
concat_out: bool = True,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
"""
Args:
feat_shape: Feature shape for embedding.
bands: Pre-calculated frequency bands.
num_bands: Number of frequency bands (determines output dim).
max_res: Maximum resolution for pixel based freq.
temperature: Temperature for non-pixel freq.
linear_bands: Linear band spacing for pixel based freq.
include_grid: Include the spatial grid in output.
in_pixels: Output in pixel freq.
ref_feat_shape: Reference feature shape for resize / fine-tune.
dtype: Output dtype.
device: Output device.
Returns:
"""
if bands is None:
if in_pixels:
bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
bands = pixel_freq_bands(
num_bands,
float(max_res),
linear_bands=linear_bands,
dtype=dtype,
device=device,
)
else:
bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
bands = freq_bands(
num_bands,
temperature=temperature,
step=1,
dtype=dtype,
device=device,
)
else:
if device is None:
device = bands.device
@ -97,31 +128,42 @@ def build_fourier_pos_embed(
dtype = bands.dtype
if in_pixels:
grid = torch.stack(torch.meshgrid(
[torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
t = [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]
else:
grid = torch.stack(torch.meshgrid(
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
t = [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]
if ref_feat_shape is not None:
# eva's scheme for resizing rope embeddings (ref shape = pretrain)
t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
grid = torch.stack(torch.meshgrid(t), dim=-1)
grid = grid.unsqueeze(-1)
pos = grid * bands
pos_sin, pos_cos = pos.sin(), pos.cos()
out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
# FIXME torchscript doesn't like multiple return types, probably need to always cat?
if concat_out:
out = torch.cat(out, dim=-1)
out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
return out
class FourierEmbed(nn.Module):
def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
def __init__(
self,
max_res: int = 224,
num_bands: int = 64,
concat_grid=True,
keep_spatial=False,
):
super().__init__()
self.max_res = max_res
self.num_bands = num_bands
self.concat_grid = concat_grid
self.keep_spatial = keep_spatial
self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False)
self.register_buffer(
'bands',
pixel_freq_bands(max_res, num_bands),
persistent=False,
)
def forward(self, x):
B, C = x.shape[:2]
@ -131,7 +173,9 @@ class FourierEmbed(nn.Module):
self.bands,
include_grid=self.concat_grid,
dtype=x.dtype,
device=x.device)
device=x.device,
)
emb = torch.cat(emb, dim=-1)
emb = emb.transpose(-1, -2).flatten(len(feat_shape))
batch_expand = (B,) + (-1,) * (x.ndim - 1)
@ -159,38 +203,57 @@ def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
return [t * cos_emb + rot(t) * sin_emb for t in x]
def apply_rot_embed_split(x: torch.Tensor, emb):
split = emb.shape[-1] // 2
return x * emb[:, :split] + rot(x) * emb[:, split:]
def apply_rot_embed_cat(x: torch.Tensor, emb):
sin_emb, cos_emb = emb.tensor_split(2, -1)
return x * cos_emb + rot(x) * sin_emb
def build_rotary_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
dim: int = 64,
max_freq: float = 224,
max_res: int = 224,
temperature: float = 10000.,
linear_bands: bool = False,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
"""
NOTE: shape arg should include spatial dim only
"""
feat_shape = torch.Size(feat_shape)
Args:
feat_shape: Spatial shape of the target tensor for embedding.
bands: Optional pre-generated frequency bands
dim: Output dimension of embedding tensor.
max_res: Maximum resolution for pixel mode.
temperature: Temperature (inv freq) for non-pixel mode
linear_bands: Linearly (instead of log) spaced bands for pixel mode
in_pixels: Pixel vs language (inv freq) mode.
dtype: Output dtype.
device: Output device.
Returns:
"""
sin_emb, cos_emb = build_fourier_pos_embed(
feat_shape,
bands=bands,
num_bands=dim // 4,
max_res=max_freq,
max_res=max_res,
temperature=temperature,
linear_bands=linear_bands,
concat_out=False,
in_pixels=in_pixels,
ref_feat_shape=ref_feat_shape,
device=device,
dtype=dtype,
)
N = feat_shape.numel()
sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
num_spatial_dim = 1
# this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
for x in feat_shape:
num_spatial_dim *= x
sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
return sin_emb, cos_emb
@ -205,15 +268,164 @@ class RotaryEmbedding(nn.Module):
* https://blog.eleuther.ai/rotary-embeddings/
"""
def __init__(self, dim, max_res=224, linear_bands: bool = False):
def __init__(
self,
dim,
max_res=224,
temperature=10000,
in_pixels=True,
linear_bands: bool = False,
feat_shape: Optional[List[int]] = None,
ref_feat_shape: Optional[List[int]] = None,
):
super().__init__()
self.dim = dim
self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)
self.max_res = max_res
self.temperature = temperature
self.in_pixels = in_pixels
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
def get_embed(self, shape: List[int]):
return build_rotary_pos_embed(shape, self.bands)
if feat_shape is None:
# only cache bands
if in_pixels:
bands = pixel_freq_bands(
dim // 4,
float(max_res),
linear_bands=linear_bands,
)
else:
bands = freq_bands(
dim // 4,
temperature=temperature,
step=1,
)
print(bands)
self.register_buffer(
'bands',
bands,
persistent=False,
)
self.pos_embed_sin = None
self.pos_embed_cos = None
else:
# cache full sin/cos embeddings if shape provided up front
emb_sin, emb_cos = build_rotary_pos_embed(
feat_shape=feat_shape,
dim=dim,
max_res=max_res,
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=self.ref_feat_shape,
)
self.bands = None
self.register_buffer(
'pos_embed_sin',
emb_sin,
persistent=False,
)
self.register_buffer(
'pos_embed_cos',
emb_cos,
persistent=False,
)
def get_embed(self, shape: Optional[List[int]] = None):
if self.bands is not None:
# rebuild embeddings every call, use if target shape changes
assert shape is not None
return build_rotary_pos_embed(
shape,
self.bands,
in_pixels=self.in_pixels,
)
else:
return self.pos_embed_sin, self.pos_embed_cos
def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
sin_emb, cos_emb = self.get_embed(x.shape[2:])
return apply_rot_embed(x, sin_emb, cos_emb)
class RotaryEmbeddingCat(nn.Module):
""" Rotary position embedding w/ concatenatd sin & cos
The following impl/resources were referenced for this impl:
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
* https://blog.eleuther.ai/rotary-embeddings/
"""
def __init__(
self,
dim,
max_res=224,
temperature=10000,
in_pixels=True,
linear_bands: bool = False,
feat_shape: Optional[List[int]] = None,
ref_feat_shape: Optional[List[int]] = None,
):
super().__init__()
self.dim = dim
self.max_res = max_res
self.temperature = temperature
self.in_pixels = in_pixels
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
if feat_shape is None:
# only cache bands
if in_pixels:
bands = pixel_freq_bands(
dim // 4,
float(max_res),
linear_bands=linear_bands,
)
else:
bands = freq_bands(
dim // 4,
temperature=temperature,
step=1,
)
print(bands)
self.register_buffer(
'bands',
bands,
persistent=False,
)
self.embed = None
else:
# cache full sin/cos embeddings if shape provided up front
embeds = build_rotary_pos_embed(
feat_shape=feat_shape,
dim=dim,
max_res=max_res,
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=self.ref_feat_shape,
)
self.bands = None
self.register_buffer(
'pos_embed',
torch.cat(embeds, -1),
persistent=False,
)
def get_embed(self, shape: Optional[List[int]] = None):
if self.bands is not None:
# rebuild embeddings every call, use if target shape changes
assert shape is not None
embeds = build_rotary_pos_embed(
shape,
self.bands,
in_pixels=self.in_pixels,
)
return torch.cat(embeds, -1)
else:
return self.pos_embed
def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
pos_embed = self.get_embed(x.shape[2:])
return apply_rot_embed_cat(x, pos_embed)

View File

@ -17,6 +17,7 @@ from .edgenext import *
from .efficientformer import *
from .efficientformer_v2 import *
from .efficientnet import *
from .eva import *
from .focalnet import *
from .gcvit import *
from .ghostnet import *

View File

@ -21,17 +21,6 @@ archivePrefix={arXiv},
primaryClass={cs.CV}
}
EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
@article{EVA,
title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
Tiejun and Wang, Xinlong and Cao, Yue},
journal={arXiv preprint arXiv:2211.07636},
year={2022}
}
At this point only the 1k fine-tuned classification weights and model configs have been added,
see original source above for pre-training models and procedure.
@ -49,19 +38,18 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
# EVA models Copyright (c) 2022 BAAI-Vision
import math
from functools import partial
from typing import Optional, Tuple
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_
from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import checkpoint_filter_fn
@ -93,8 +81,15 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, attn_drop=0.,
proj_drop=0., window_size=None, attn_head_dim=None):
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
window_size: Optional[Tuple[int, int]] = None,
attn_head_dim: Optional[int] = None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
@ -102,6 +97,7 @@ class Attention(nn.Module):
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = head_dim ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
@ -142,20 +138,37 @@ class Attention(nn.Module):
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.fast_attn:
if self.relative_position_bias_table is not None:
rel_pos_bias = self._get_rel_pos_bias()
if shared_rel_pos_bias is not None:
rel_pos_bias = rel_pos_bias + shared_rel_pos_bias
elif shared_rel_pos_bias is not None:
rel_pos_bias = shared_rel_pos_bias
else:
rel_pos_bias = None
if self.relative_position_bias_table is not None:
attn = attn + self._get_rel_pos_bias()
if shared_rel_pos_bias is not None:
attn = attn + shared_rel_pos_bias
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=rel_pos_bias,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
if self.relative_position_bias_table is not None:
attn = attn + self._get_rel_pos_bias()
if shared_rel_pos_bias is not None:
attn = attn + shared_rel_pos_bias
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@ -164,19 +177,53 @@ class Attention(nn.Module):
class Block(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
window_size=None, attn_head_dim=None):
self,
dim: int,
num_heads: int,
qkv_bias: bool = False,
mlp_ratio: float = 4.,
scale_mlp: bool = False,
swiglu_mlp: bool = False,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
init_values: Optional[float] = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
window_size: Optional[Tuple[int, int]] = None,
attn_head_dim: Optional[int] = None,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
window_size=window_size, attn_head_dim=attn_head_dim)
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
window_size=window_size,
attn_head_dim=attn_head_dim,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if swiglu_mlp:
self.mlp = SwiGLU(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
norm_layer=norm_layer if scale_mlp else None,
drop=proj_drop,
)
else:
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
norm_layer=norm_layer if scale_mlp else None,
drop=proj_drop,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if init_values:
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
@ -186,11 +233,11 @@ class Block(nn.Module):
def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x + self.drop_path1(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
x = x + self.drop_path2(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
return x
@ -216,19 +263,42 @@ class Beit(nn.Module):
"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=None, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
head_init_scale=0.001):
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
qkv_bias: bool = True,
mlp_ratio: float = 4.,
swiglu_mlp: bool = False,
scale_mlp: bool = False,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
norm_layer: Callable = LayerNorm,
init_values: Optional[float] = None,
use_abs_pos_emb: bool = True,
use_rel_pos_bias: bool = False,
use_shared_rel_pos_bias: bool = False,
head_init_scale: float = 0.001,
):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1
self.grad_checkpointing = False
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
@ -237,27 +307,41 @@ class Beit(nn.Module):
self.pos_drop = nn.Dropout(p=drop_rate)
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.grid_size, num_heads=num_heads)
self.rel_pos_bias = RelativePositionBias(
window_size=self.patch_embed.grid_size,
num_heads=num_heads,
)
else:
self.rel_pos_bias = None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None)
dim=embed_dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
mlp_ratio=mlp_ratio,
scale_mlp=scale_mlp,
swiglu_mlp=swiglu_mlp,
proj_drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
)
for i in range(depth)])
use_fc_norm = self.global_pool == 'avg'
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else None
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
# trunc_normal_(self.mask_token, std=.02)
self.fix_init_weight()
if isinstance(self.head, nn.Linear):
trunc_normal_(self.head.weight, std=.02)
@ -328,11 +412,9 @@ class Beit(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
if self.fc_norm is not None:
x = x[:, 1:].mean(dim=1)
x = self.fc_norm(x)
else:
x = x[:, 0]
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
return x if pre_logits else self.head(x)
def forward(self, x):
@ -405,27 +487,6 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
),
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
})
@ -509,30 +570,3 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs):
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def eva_giant_patch14_224(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """
model_kwargs = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
model = _create_beit('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def eva_giant_patch14_336(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """
model_kwargs = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
model = _create_beit('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs)
return model
@register_model
def eva_giant_patch14_560(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """
model_kwargs = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
model = _create_beit('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs)
return model

730
timm/models/eva.py Normal file
View File

@ -0,0 +1,730 @@
""" EVA
EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
@article{EVA,
title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
Tiejun and Wang, Xinlong and Cao, Yue},
journal={arXiv preprint arXiv:2211.07636},
year={2022}
}
EVA-02: A Visual Representation for Neon Genesis - https://arxiv.org/abs/2303.11331
@article{EVA02,
title={EVA-02: A Visual Representation for Neon Genesis},
author={Fang, Yuxin and Sun, Quan and Wang, Xinggang and Huang, Tiejun and Wang, Xinlong and Cao, Yue},
journal={arXiv preprint arXiv:2303.11331},
year={2023}
}
This file contains EVA & EVA02 model implementations evolved from BEiT, additional models in vision_transformer.py.
Modifications by / Copyright 2023 Ross Wightman, original copyrights below
"""
# EVA models Copyright (c) 2022 BAAI-Vision
# EVA02 models Copyright (c) 2023 BAAI-Vision
import math
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, RotaryEmbeddingCat, \
apply_rot_embed_cat, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, to_2tuple
from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model
__all__ = ['Eva']
class EvaAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
qkv_fused: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
attn_head_dim: Optional[int] = None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = head_dim ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if qkv_fused:
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
self.q_proj = self.k_proj = self.v_proj = None
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = self.k_bias = self.v_bias = None
else:
self.q_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
self.v_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
self.qkv = None
self.q_bias = self.k_bias = self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(
self,
x,
rope: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
B, N, C = x.shape
if self.qkv is not None:
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
else:
q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C
k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
if rope is not None:
q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], 2).type_as(v)
k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], 2).type_as(v)
if self.fast_attn:
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
if attn_mask is not None:
attn_mask = attn_mask.to(torch.bool)
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class EvaBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
qkv_bias: bool = True,
qkv_fused: bool = True,
mlp_ratio: float = 4.,
scale_mlp: bool = False,
swiglu_mlp: bool = False,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
init_values: Optional[float] = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
attn_head_dim: Optional[int] = None,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = EvaAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qkv_fused=qkv_fused,
attn_drop=attn_drop,
proj_drop=proj_drop,
attn_head_dim=attn_head_dim,
)
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
hidden_features = int(dim * mlp_ratio)
if swiglu_mlp:
if scale_mlp:
# when norm in SwiGLU used, an impl with separate fc for gate & x is used
self.mlp = SwiGLU(
in_features=dim,
hidden_features=hidden_features,
norm_layer=norm_layer if scale_mlp else None,
drop=proj_drop,
)
else:
# w/o any extra norm, an impl with packed weights is used, matches existing GluMLP
self.mlp = GluMlp(
in_features=dim,
hidden_features=hidden_features * 2,
norm_layer=norm_layer if scale_mlp else None,
act_layer=nn.SiLU,
gate_last=False,
drop=proj_drop,
)
else:
self.mlp = Mlp(
in_features=dim,
hidden_features=hidden_features,
act_layer=act_layer,
norm_layer=norm_layer if scale_mlp else None,
drop=proj_drop,
)
self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
if self.gamma_1 is None:
x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
x = x + self.drop_path2(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class Eva(nn.Module):
""" Eva Vision Transformer w/ Abs & Rotary Pos Embed
This class implements the EVA and EVA02 models that were based on the BEiT ViT variant
* EVA - abs pos embed, global avg pool
* EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale Norm in MLP (ala normformer)
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
qkv_bias: bool = True,
qkv_fused: bool = True,
mlp_ratio: float = 4.,
swiglu_mlp: bool = False,
scale_mlp: bool = False,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
norm_layer: Callable = LayerNorm,
init_values: Optional[float] = None,
use_abs_pos_emb: bool = True,
use_rot_pos_emb: bool = False,
ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
head_init_scale: float = 0.001,
):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1
self.grad_checkpointing = False
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
self.pos_drop = nn.Dropout(p=drop_rate)
if use_rot_pos_emb:
ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
self.rope = RotaryEmbeddingCat(
embed_dim // num_heads,
in_pixels=False,
feat_shape=self.patch_embed.grid_size,
ref_feat_shape=ref_feat_shape,
)
else:
self.rope = None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
EvaBlock(
dim=embed_dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qkv_fused=qkv_fused,
mlp_ratio=mlp_ratio,
scale_mlp=scale_mlp,
swiglu_mlp=swiglu_mlp,
proj_drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
)
for i in range(depth)])
use_fc_norm = self.global_pool == 'avg'
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.fix_init_weight()
if isinstance(self.head, nn.Linear):
trunc_normal_(self.head.weight, std=.02)
self.head.weight.data.mul_(head_init_scale)
self.head.bias.data.mul_(head_init_scale)
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
nwd = {'pos_embed', 'cls_token'}
return nwd
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
)
return matcher
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, rope=rot_pos_embed)
else:
x = blk(x, rope=rot_pos_embed)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(
state_dict,
model,
interpolation='bicubic',
antialias=True,
):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
state_dict = state_dict.get('model_ema', state_dict)
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('module', state_dict)
state_dict = state_dict.get('state_dict', state_dict)
no_qkv = 'blocks.0.attn.q_proj.weight' in state_dict
mim_weights = 'mask_token' in state_dict
for k, v in state_dict.items():
if 'rope' in k:
# fixed embedding no need to load buffer from checkpoint
continue
if 'patch_embed.proj.weight' in k:
_, _, H, W = model.patch_embed.proj.weight.shape
if v.shape[-1] != W or v.shape[-2] != H:
v = resample_patch_embed(
v,
(H, W),
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
# To resize pos embedding when using model at different size from pretrained weights
num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
v = resample_abs_pos_embed(
v,
new_size=model.patch_embed.grid_size,
num_prefix_tokens=num_prefix_tokens,
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
k = k.replace('mlp.ffn_ln', 'mlp.norm')
k = k.replace('mlp.w12', 'mlp.fc1')
k = k.replace('mlp.w1', 'mlp.fc1_g')
k = k.replace('mlp.w2', 'mlp.fc1_x')
k = k.replace('mlp.w3', 'mlp.fc2')
if no_qkv:
k = k.replace('q_bias', 'q_proj.bias')
k = k.replace('v_bias', 'v_proj.bias')
if mim_weights and k in ('mask_token', 'lm_head.weight', 'lm_head.bias', 'norm.weight', 'norm.bias'):
if k == 'norm.weight' or k == 'norm.bias':
# try moving norm -> fc norm on fine-tune, probably a better starting point than new init
k = k.replace('norm', 'fc_norm')
else:
# skip pretrain mask token & head weights
continue
out_dict[k] = v
return out_dict
def _create_eva(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Eva models.')
model = build_model_with_cfg(
Eva, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
hf_hub_id='timm/',
),
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
hf_hub_id='timm/',
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_B_pt_in21k_medft_in21k_ft_in1k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
),
'eva02_large_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_in21k_medft_in21k_ft_in1k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
),
'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_m38m_medft_in21k_ft_in1k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
),
'eva02_tiny_patch14_336.mim_in22k_ft_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt',
input_size=(3, 336, 336), crop_pct=1.0,
),
'eva02_small_patch14_336.mim_in22k_ft_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt',
input_size=(3, 336, 336), crop_pct=1.0,
),
'eva02_base_patch14_448.mim_in22k_ft_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0,
),
'eva02_large_patch14_448.mim_in22k_ft_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in1k/eva02_L_pt_in21k_ft_in1k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0,
),
'eva02_large_patch14_448.mim_m38m_ft_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in1k/eva02_L_pt_m38m_ft_in1k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0,
),
'eva02_base_patch14_448.mim_in22k_ft_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
),
'eva02_large_patch14_448.mim_in22k_ft_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
),
'eva02_large_patch14_448.mim_m38m_ft_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
),
'eva02_tiny_patch14_224.mim_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_Ti_pt_in21k_p14.pt',
num_classes=0,
),
'eva02_small_patch14_224.mim_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_S_pt_in21k_p14.pt',
num_classes=0,
),
'eva02_base_patch14_224.mim_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_B_pt_in21k_p14.pt',
num_classes=0,
),
'eva02_large_patch14_224.mim_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_in21k_p14.pt',
num_classes=0,
),
'eva02_large_patch14_224.mim_m38m': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_m38m_p14.pt',
num_classes=0,
),
})
@register_model
def eva_giant_patch14_224(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """
model_kwargs = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
model = _create_eva('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def eva_giant_patch14_336(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """
model_kwargs = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
model = _create_eva('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs)
return model
@register_model
def eva_giant_patch14_560(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """
model_kwargs = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
model = _create_eva('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs)
return model
@register_model
def eva02_tiny_patch14_224(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=224,
patch_size=14,
embed_dim=192,
depth=12,
num_heads=3,
mlp_ratio=4 * 2 / 3,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_tiny_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_small_patch14_224(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=224,
patch_size=14,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4 * 2 / 3,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_small_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_base_patch14_224(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=224,
patch_size=14,
embed_dim=768,
depth=12,
num_heads=12,
qkv_fused=False,
mlp_ratio=4 * 2 / 3,
scale_mlp=True,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_base_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_large_patch14_224(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=224,
patch_size=14,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4 * 2 / 3,
qkv_fused=False,
scale_mlp=True,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_large_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_tiny_patch14_336(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=336,
patch_size=14,
embed_dim=192,
depth=12,
num_heads=3,
mlp_ratio=4 * 2 / 3,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_tiny_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_small_patch14_336(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=336,
patch_size=14,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4 * 2 / 3,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_small_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_base_patch14_448(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=448,
patch_size=14,
embed_dim=768,
depth=12,
num_heads=12,
qkv_fused=False,
mlp_ratio=4 * 2 / 3,
scale_mlp=True,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_base_patch14_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_large_patch14_448(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=448,
patch_size=14,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4 * 2 / 3,
qkv_fused=False,
scale_mlp=True,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_large_patch14_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model