Merge pull request #1746 from huggingface/eva02

Adding EVA02 weights and model defs
Ross Wightman 2023-03-31 12:17:00 -07:00 committed by GitHub
commit 7326470514
10 changed files with 1500 additions and 159 deletions


@@ -24,6 +24,30 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
### March 31, 2023
* Add EVA-02 MIM pretrained and fine-tuned weights, push them to the HF hub, and update model cards for all EVA models. First model over 90% top-1! (A short usage sketch is included at the end of this entry.)
| model |top1 |top5 |param_count|img_size|
|----------------------------------------------------|------|------|-----------|--------|
| [eva02_large_patch14_448.mim_m38m_ft_in22k_in1k](https://huggingface.co/timm/eva02_large_patch14_448.mim_m38m_ft_in22k_in1k) |90.054|99.042|305.08 |448 |
| eva02_large_patch14_448.mim_in22k_ft_in22k_in1k |89.946|99.01 |305.08 |448 |
| eva_giant_patch14_560.m30m_ft_in22k_in1k |89.792|98.992|1014.45 |560 |
| eva02_large_patch14_448.mim_in22k_ft_in1k |89.626|98.954|305.08 |448 |
| eva02_large_patch14_448.mim_m38m_ft_in1k |89.57 |98.918|305.08 |448 |
| eva_giant_patch14_336.m30m_ft_in22k_in1k |89.56 |98.956|1013.01 |336 |
| eva_giant_patch14_336.clip_ft_in1k |89.466|98.82 |1013.01 |336 |
| eva_large_patch14_336.in22k_ft_in22k_in1k |89.214|98.854|304.53 |336 |
| eva_giant_patch14_224.clip_ft_in1k |88.882|98.678|1012.56 |224 |
| eva02_base_patch14_448.mim_in22k_ft_in22k_in1k |88.692|98.722|87.12 |448 |
| eva_large_patch14_336.in22k_ft_in1k |88.652|98.722|304.53 |336 |
| eva_large_patch14_196.in22k_ft_in22k_in1k |88.592|98.656|304.14 |196 |
| eva02_base_patch14_448.mim_in22k_ft_in1k |88.23 |98.564|87.12 |448 |
| eva_large_patch14_196.in22k_ft_in1k |87.934|98.504|304.14 |196 |
| eva02_small_patch14_336.mim_in22k_ft_in1k |85.74 |97.614|22.13 |336 |
| eva02_tiny_patch14_336.mim_in22k_ft_in1k |80.658|95.524|5.76 |336 |
* Multi-weight and HF hub for DeiT and MLP-Mixer based models
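A minimal usage sketch for the new weights (model name taken from the table above; assumes a 0.8.x pre-release of `timm` with HF hub access and the usual `timm.data` helpers):

```python
import timm
import torch

# pull one of the EVA-02 fine-tuned checkpoints listed above from the HF hub
model = timm.create_model('eva02_base_patch14_448.mim_in22k_ft_in22k_in1k', pretrained=True).eval()

# resolve the matching eval transform (448x448 input, squash crop, etc.) from the pretrained cfg
cfg = timm.data.resolve_data_config({}, model=model)
transform = timm.data.create_transform(**cfg)

with torch.no_grad():
    logits = model(torch.randn(1, 3, 448, 448))  # stand-in for a transformed image batch
print(logits.shape)  # torch.Size([1, 1000])
```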
### March 22, 2023
* More weights pushed to HF hub along with multi-weight support, including: `regnet.py`, `rexnet.py`, `byobnet.py`, `resnetv2.py`, `swin_transformer.py`, `swin_transformer_v2.py`, `swin_transformer_v2_cr.py`
* Swin Transformer models support feature extraction (NCHW feat maps for `swinv2_cr_*`, and NHWC for all others) and spatial embedding outputs.


@@ -41,7 +41,7 @@ NON_STD_FILTERS = [
    'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
    'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
    'eva_*', 'flexivit*', 'eva02*'
]
NUM_NON_STD = len(NON_STD_FILTERS)
@@ -51,12 +51,12 @@ if 'GITHUB_ACTIONS' in os.environ:
    EXCLUDE_FILTERS = [
        '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
        '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', '*huge*', '*giant*', '*gigantic*',
        '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560']
    NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*']
else:
    EXCLUDE_FILTERS = ['*enormous*']
    NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*']

TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
TARGET_BWD_SIZE = 128
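For reference, these test filters are fnmatch-style wildcards, the same pattern syntax `timm.list_models` accepts, so the newly added models can be enumerated directly. A small illustration (not part of the test file):

```python
import timm

# all registered EVA-02 architectures added by this PR
eva02_models = timm.list_models('eva02*')
print(len(eva02_models), eva02_models[:3])

# restrict to entries that carry pretrained weight configs
print(timm.list_models('eva02*', pretrained=True)[:3])
```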


@@ -27,7 +27,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible,
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
@@ -37,8 +37,8 @@ from .patch_embed import PatchEmbed, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
    build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct


@@ -19,6 +19,7 @@ class Mlp(nn.Module):
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
@@ -33,6 +34,7 @@ class Mlp(nn.Module):
        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])
@@ -55,9 +57,11 @@ class GluMlp(nn.Module):
            hidden_features=None,
            out_features=None,
            act_layer=nn.Sigmoid,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
            gate_last=True,
    ):
        super().__init__()
        out_features = out_features or in_features
@@ -67,10 +71,12 @@ class GluMlp(nn.Module):
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
        self.chunk_dim = 1 if use_conv else -1
        self.gate_last = gate_last  # use second half of width for gate
        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])
@@ -82,9 +88,57 @@ class GluMlp(nn.Module):
    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=self.chunk_dim)
        x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class SwiGLU(nn.Module):
    """ SwiGLU
    NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
    better matches some other common impl which makes mapping checkpoints simpler.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.SiLU,
            norm_layer=nn.LayerNorm,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])
        self.drop = nn.Dropout(drop)

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        nn.init.ones_(self.fc1_g.bias)
        nn.init.normal_(self.fc1_g.weight, std=1e-6)

    def forward(self, x):
        x_gate = self.fc1_g(x)
        x = self.fc1_x(x)
        x = self.act(x_gate) * x
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
@@ -99,6 +153,7 @@ class GatedMlp(nn.Module):
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            gate_layer=None,
            bias=True,
            drop=0.,
@@ -118,6 +173,7 @@ class GatedMlp(nn.Module):
            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
        else:
            self.gate = nn.Identity()
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])
@@ -126,6 +182,7 @@ class GatedMlp(nn.Module):
        x = self.act(x)
        x = self.drop1(x)
        x = self.gate(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
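A quick sketch exercising the new `SwiGLU` layer on its own, assuming a timm build that includes this commit:

```python
import torch
from timm.layers import SwiGLU

# hidden_features is the width of each fc1 branch (gate and value paths)
mlp = SwiGLU(in_features=64, hidden_features=128, out_features=64)
x = torch.randn(2, 16, 64)   # (batch, tokens, dim)
print(mlp(x).shape)          # torch.Size([2, 16, 64])
```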


@@ -23,15 +23,15 @@ def pixel_freq_bands(
    return bands * torch.pi


def freq_bands(
        num_bands: int,
        temperature: float = 10000.,
        step: int = 2,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
) -> torch.Tensor:
    bands = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
    return bands


def build_sincos2d_pos_embed(
@@ -59,12 +59,12 @@ def build_sincos2d_pos_embed(
    """
    assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
    pos_dim = dim // 4
    bands = freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)

    if reverse_coord:
        feat_shape = feat_shape[::-1]  # stack W, H instead of H, W
    grid = torch.stack(torch.meshgrid(
        [torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
    pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
    # FIXME add support for unflattened spatial dim?
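For orientation, the two band generators differ mainly in spacing and scale; a tiny sketch using the names exported to `timm.layers` in this commit (exact values depend on the implementation details):

```python
from timm.layers import freq_bands, pixel_freq_bands

# 'language style' inverse-frequency bands (renamed from inv_freq_bands), default temperature 10000
print(freq_bands(8, temperature=10000., step=1))

# pixel-style bands: log-spaced frequencies derived from max_res, intended for a -1..1 grid
print(pixel_freq_bands(8, 224))
```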
@@ -78,18 +78,49 @@ def build_fourier_pos_embed(
        bands: Optional[torch.Tensor] = None,
        num_bands: int = 64,
        max_res: int = 224,
        temperature: float = 10000.,
        linear_bands: bool = False,
        include_grid: bool = False,
        concat_out: bool = True,
        in_pixels: bool = True,
        ref_feat_shape: Optional[List[int]] = None,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    """
    Args:
        feat_shape: Feature shape for embedding.
        bands: Pre-calculated frequency bands.
        num_bands: Number of frequency bands (determines output dim).
        max_res: Maximum resolution for pixel based freq.
        temperature: Temperature for non-pixel freq.
        linear_bands: Linear band spacing for pixel based freq.
        include_grid: Include the spatial grid in output.
        in_pixels: Output in pixel freq.
        ref_feat_shape: Reference feature shape for resize / fine-tune.
        dtype: Output dtype.
        device: Output device.

    Returns:
    """
    if bands is None:
        if in_pixels:
            bands = pixel_freq_bands(
                num_bands,
                float(max_res),
                linear_bands=linear_bands,
                dtype=dtype,
                device=device,
            )
        else:
            bands = freq_bands(
                num_bands,
                temperature=temperature,
                step=1,
                dtype=dtype,
                device=device,
            )
    else:
        if device is None:
            device = bands.device
@@ -97,31 +128,42 @@ def build_fourier_pos_embed(
            dtype = bands.dtype

    if in_pixels:
        t = [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]
    else:
        t = [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]

    if ref_feat_shape is not None:
        # eva's scheme for resizing rope embeddings (ref shape = pretrain)
        t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
    grid = torch.stack(torch.meshgrid(t), dim=-1)
    grid = grid.unsqueeze(-1)
    pos = grid * bands

    pos_sin, pos_cos = pos.sin(), pos.cos()
    out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
    # FIXME torchscript doesn't like multiple return types, probably need to always cat?
    if concat_out:
        out = torch.cat(out, dim=-1)
    return out


class FourierEmbed(nn.Module):

    def __init__(
            self,
            max_res: int = 224,
            num_bands: int = 64,
            concat_grid=True,
            keep_spatial=False,
    ):
        super().__init__()
        self.max_res = max_res
        self.num_bands = num_bands
        self.concat_grid = concat_grid
        self.keep_spatial = keep_spatial
        self.register_buffer(
            'bands',
            pixel_freq_bands(max_res, num_bands),
            persistent=False,
        )

    def forward(self, x):
        B, C = x.shape[:2]
@@ -131,7 +173,9 @@ class FourierEmbed(nn.Module):
            self.bands,
            include_grid=self.concat_grid,
            dtype=x.dtype,
            device=x.device,
        )
        emb = torch.cat(emb, dim=-1)
        emb = emb.transpose(-1, -2).flatten(len(feat_shape))
        batch_expand = (B,) + (-1,) * (x.ndim - 1)
@@ -159,38 +203,57 @@ def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
    return [t * cos_emb + rot(t) * sin_emb for t in x]


def apply_rot_embed_cat(x: torch.Tensor, emb):
    sin_emb, cos_emb = emb.tensor_split(2, -1)
    return x * cos_emb + rot(x) * sin_emb


def build_rotary_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        dim: int = 64,
        max_res: int = 224,
        temperature: float = 10000.,
        linear_bands: bool = False,
        in_pixels: bool = True,
        ref_feat_shape: Optional[List[int]] = None,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
):
    """
    Args:
        feat_shape: Spatial shape of the target tensor for embedding.
        bands: Optional pre-generated frequency bands
        dim: Output dimension of embedding tensor.
        max_res: Maximum resolution for pixel mode.
        temperature: Temperature (inv freq) for non-pixel mode
        linear_bands: Linearly (instead of log) spaced bands for pixel mode
        in_pixels: Pixel vs language (inv freq) mode.
        dtype: Output dtype.
        device: Output device.

    Returns:
    """
    sin_emb, cos_emb = build_fourier_pos_embed(
        feat_shape,
        bands=bands,
        num_bands=dim // 4,
        max_res=max_res,
        temperature=temperature,
        linear_bands=linear_bands,
        in_pixels=in_pixels,
        ref_feat_shape=ref_feat_shape,
        device=device,
        dtype=dtype,
    )
    num_spatial_dim = 1
    # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
    for x in feat_shape:
        num_spatial_dim *= x
    sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
    cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
    return sin_emb, cos_emb
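A small sketch of the function-level API, including the EVA-style `ref_feat_shape` rescaling used when fine-tuning at a different grid size than pretraining; shapes assume a recent timm build:

```python
import torch
from timm.layers import build_rotary_pos_embed, apply_rot_embed_cat

# sin/cos tables for a 14x14 token grid with 64-dim heads; ref_feat_shape rescales the coordinate
# grid so a model fine-tuned at a new resolution keeps the pretraining frequencies
sin_emb, cos_emb = build_rotary_pos_embed([14, 14], dim=64, ref_feat_shape=[16, 16])
print(sin_emb.shape, cos_emb.shape)  # torch.Size([196, 64]) torch.Size([196, 64])

q = torch.randn(1, 12, 196, 64)  # (batch, heads, tokens, head_dim)
q_rot = apply_rot_embed_cat(q, torch.cat([sin_emb, cos_emb], dim=-1))
print(q_rot.shape)
```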
@@ -205,15 +268,164 @@ class RotaryEmbedding(nn.Module):
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
            self,
            dim,
            max_res=224,
            temperature=10000,
            in_pixels=True,
            linear_bands: bool = False,
            feat_shape: Optional[List[int]] = None,
            ref_feat_shape: Optional[List[int]] = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.in_pixels = in_pixels
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape

        if feat_shape is None:
            # only cache bands
            if in_pixels:
                bands = pixel_freq_bands(
                    dim // 4,
                    float(max_res),
                    linear_bands=linear_bands,
                )
            else:
                bands = freq_bands(
                    dim // 4,
                    temperature=temperature,
                    step=1,
                )
            self.register_buffer(
                'bands',
                bands,
                persistent=False,
            )
            self.pos_embed_sin = None
            self.pos_embed_cos = None
        else:
            # cache full sin/cos embeddings if shape provided up front
            emb_sin, emb_cos = build_rotary_pos_embed(
                feat_shape=feat_shape,
                dim=dim,
                max_res=max_res,
                linear_bands=linear_bands,
                in_pixels=in_pixels,
                ref_feat_shape=self.ref_feat_shape,
            )
            self.bands = None
            self.register_buffer(
                'pos_embed_sin',
                emb_sin,
                persistent=False,
            )
            self.register_buffer(
                'pos_embed_cos',
                emb_cos,
                persistent=False,
            )

    def get_embed(self, shape: Optional[List[int]] = None):
        if self.bands is not None:
            # rebuild embeddings every call, use if target shape changes
            assert shape is not None
            return build_rotary_pos_embed(
                shape,
                self.bands,
                in_pixels=self.in_pixels,
            )
        else:
            return self.pos_embed_sin, self.pos_embed_cos

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        sin_emb, cos_emb = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, sin_emb, cos_emb)


class RotaryEmbeddingCat(nn.Module):
    """ Rotary position embedding w/ concatenated sin & cos

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
            self,
            dim,
            max_res=224,
            temperature=10000,
            in_pixels=True,
            linear_bands: bool = False,
            feat_shape: Optional[List[int]] = None,
            ref_feat_shape: Optional[List[int]] = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.in_pixels = in_pixels
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape

        if feat_shape is None:
            # only cache bands
            if in_pixels:
                bands = pixel_freq_bands(
                    dim // 4,
                    float(max_res),
                    linear_bands=linear_bands,
                )
            else:
                bands = freq_bands(
                    dim // 4,
                    temperature=temperature,
                    step=1,
                )
            self.register_buffer(
                'bands',
                bands,
                persistent=False,
            )
            self.embed = None
        else:
            # cache full sin/cos embeddings if shape provided up front
            embeds = build_rotary_pos_embed(
                feat_shape=feat_shape,
                dim=dim,
                max_res=max_res,
                linear_bands=linear_bands,
                in_pixels=in_pixels,
                ref_feat_shape=self.ref_feat_shape,
            )
            self.bands = None
            self.register_buffer(
                'pos_embed',
                torch.cat(embeds, -1),
                persistent=False,
            )

    def get_embed(self, shape: Optional[List[int]] = None):
        if self.bands is not None:
            # rebuild embeddings every call, use if target shape changes
            assert shape is not None
            embeds = build_rotary_pos_embed(
                shape,
                self.bands,
                in_pixels=self.in_pixels,
            )
            return torch.cat(embeds, -1)
        else:
            return self.pos_embed

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed)
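Typical module-level usage caches the concatenated sin/cos table up front when the feature grid is known; an illustrative sketch, not code from this commit:

```python
import torch
from timm.layers import RotaryEmbeddingCat, apply_rot_embed_cat

# head_dim=64, 24x24 token grid, frequencies referenced to a 16x16 pretraining grid
rope = RotaryEmbeddingCat(64, in_pixels=False, feat_shape=[24, 24], ref_feat_shape=[16, 16])
q = torch.randn(2, 16, 24 * 24, 64)   # (batch, heads, tokens, head_dim)
q = apply_rot_embed_cat(q, rope.get_embed())
print(q.shape)  # torch.Size([2, 16, 576, 64])
```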


@@ -17,6 +17,7 @@ from .edgenext import *
from .efficientformer import *
from .efficientformer_v2 import *
from .efficientnet import *
from .eva import *
from .focalnet import *
from .gcvit import *
from .ghostnet import *


@@ -316,8 +316,16 @@ def generate_readme(model_card: dict, model_name: str):
    readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
    if 'details' in model_card and 'Dataset' in model_card['details']:
        readme_text += 'datasets:\n'
        if isinstance(model_card['details']['Dataset'], (tuple, list)):
            for d in model_card['details']['Dataset']:
                readme_text += f"- {d.lower()}\n"
        else:
            readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
        if 'Pretrain Dataset' in model_card['details']:
            if isinstance(model_card['details']['Pretrain Dataset'], (tuple, list)):
                for d in model_card['details']['Pretrain Dataset']:
                    readme_text += f"- {d.lower()}\n"
            else:
                readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
    readme_text += "---\n"
    readme_text += f"# Model card for {model_name}\n"


@@ -21,17 +21,6 @@ archivePrefix={arXiv},
primaryClass={cs.CV}
}
EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
@article{EVA,
title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
Tiejun and Wang, Xinlong and Cao, Yue},
journal={arXiv preprint arXiv:2211.07636},
year={2022}
}
At this point only the 1k fine-tuned classification weights and model configs have been added, At this point only the 1k fine-tuned classification weights and model configs have been added,
see original source above for pre-training models and procedure.
@@ -49,19 +38,18 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
# EVA models Copyright (c) 2022 BAAI-Vision
import math
from functools import partial
from typing import Callable, Final, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_
from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import checkpoint_filter_fn
@@ -92,9 +80,18 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
class Attention(nn.Module):
    fast_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            window_size: Optional[Tuple[int, int]] = None,
            attn_head_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
@@ -102,6 +99,7 @@ class Attention(nn.Module):
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = head_dim ** -0.5
        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
@@ -142,8 +140,24 @@ class Attention(nn.Module):
        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim

        if self.fast_attn:
            if self.relative_position_bias_table is not None:
                rel_pos_bias = self._get_rel_pos_bias()
                if shared_rel_pos_bias is not None:
                    rel_pos_bias = rel_pos_bias + shared_rel_pos_bias
            elif shared_rel_pos_bias is not None:
                rel_pos_bias = shared_rel_pos_bias
            else:
                rel_pos_bias = None
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=rel_pos_bias,
                dropout_p=self.attn_drop.p,
            )
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))
@@ -154,8 +168,9 @@ class Attention(nn.Module):
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
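Both branches compute the same attention; a minimal check of the fused vs. explicit paths on toy tensors (requires PyTorch >= 2.0 for `scaled_dot_product_attention`):

```python
import torch
import torch.nn.functional as F

q, k, v = [torch.randn(2, 12, 196, 64) for _ in range(3)]  # (batch, heads, tokens, head_dim)
bias = torch.randn(12, 196, 196)                            # e.g. a relative position bias

out_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

attn = (q * (q.shape[-1] ** -0.5)) @ k.transpose(-2, -1) + bias
out_ref = attn.softmax(dim=-1) @ v

print(torch.allclose(out_fused, out_ref, atol=1e-5))  # True, up to numerical tolerance
```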
@@ -164,19 +179,53 @@ class Attention(nn.Module):
class Block(nn.Module):

    def __init__(
            self,
            dim: int,
            num_heads: int,
            qkv_bias: bool = False,
            mlp_ratio: float = 4.,
            scale_mlp: bool = False,
            swiglu_mlp: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            init_values: Optional[float] = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            window_size: Optional[Tuple[int, int]] = None,
            attn_head_dim: Optional[int] = None,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        if swiglu_mlp:
            self.mlp = SwiGLU(
                in_features=dim,
                hidden_features=int(dim * mlp_ratio),
                norm_layer=norm_layer if scale_mlp else None,
                drop=proj_drop,
            )
        else:
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=int(dim * mlp_ratio),
                act_layer=act_layer,
                norm_layer=norm_layer if scale_mlp else None,
                drop=proj_drop,
            )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        if init_values:
            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
@@ -186,11 +235,11 @@ class Block(nn.Module):
    def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
        if self.gamma_1 is None:
            x = x + self.drop_path1(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
            x = x + self.drop_path2(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
            x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
@@ -216,19 +265,42 @@ class Beit(nn.Module):
    """

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            qkv_bias: bool = True,
            mlp_ratio: float = 4.,
            swiglu_mlp: bool = False,
            scale_mlp: bool = False,
            drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            norm_layer: Callable = LayerNorm,
            init_values: Optional[float] = None,
            use_abs_pos_emb: bool = True,
            use_rel_pos_bias: bool = False,
            use_shared_rel_pos_bias: bool = False,
            head_init_scale: float = 0.001,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1
        self.grad_checkpointing = False

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
@@ -237,27 +309,41 @@ class Beit(nn.Module):
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(
                window_size=self.patch_embed.grid_size,
                num_heads=num_heads,
            )
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                mlp_ratio=mlp_ratio,
                scale_mlp=scale_mlp,
                swiglu_mlp=swiglu_mlp,
                proj_drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                init_values=init_values,
                window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
            )
            for i in range(depth)])

        use_fc_norm = self.global_pool == 'avg'
        self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)

        self.fix_init_weight()
        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=.02)
@@ -328,11 +414,9 @@ class Beit(nn.Module):
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
@@ -405,27 +489,6 @@ default_cfgs = generate_default_cfgs({
        hf_hub_id='timm/',
        num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
),
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
})

@@ -509,30 +572,3 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs):
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
    model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
    return model
@register_model
def eva_giant_patch14_224(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """
model_kwargs = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
model = _create_beit('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def eva_giant_patch14_336(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """
model_kwargs = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
model = _create_beit('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs)
return model
@register_model
def eva_giant_patch14_560(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """
model_kwargs = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
model = _create_beit('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs)
return model

timm/models/eva.py (new file, 1003 lines)

File diff suppressed because it is too large.
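With the eva.py diff collapsed here, a quick sanity check that the new model defs register like any other timm architecture (name taken from the README table above; `pretrained=False`, so no download is needed):

```python
import torch
import timm

model = timm.create_model('eva02_tiny_patch14_336', pretrained=False, num_classes=0)
feats = model(torch.randn(1, 3, 336, 336))
print(sum(p.numel() for p in model.parameters()) / 1e6)  # roughly 5.6M params (5.76M with the 1k-class head, per the table)
print(feats.shape)  # (1, 192) pooled features with the classifier head removed
```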


@@ -1216,22 +1216,22 @@ default_cfgs = generate_default_cfgs({
    # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip
    'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg(
        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt',
        hf_hub_id='timm/', license='mit',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 196, 196), crop_pct=1.0),
    'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg(
        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt',
        hf_hub_id='timm/', license='mit',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
    'eva_large_patch14_196.in22k_ft_in1k': _cfg(
        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt',
        hf_hub_id='timm/', license='mit',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 196, 196), crop_pct=1.0),
    'eva_large_patch14_336.in22k_ft_in1k': _cfg(
        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt',
        hf_hub_id='timm/', license='mit',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),