Add EVA02 weights and model defs, move BEiT-based eva_giant to the same eva.py file. Clean up rotary pos embed, add language-oriented freq bands to be compatible with the EVA design choice. Fix #1738

This commit is contained in:
Ross Wightman 2023-03-27 17:14:58 -07:00
parent 56b90317cd
commit 3863d63516
6 changed files with 1181 additions and 147 deletions
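A quick usage sketch for the models this commit adds or moves (illustrative, not part of the diff; the model/tag name comes from the default cfgs added below, and pretrained=False avoids any assumption about where the weights are hosted):

import torch
import timm

# eva_giant_* defs moved from beit.py into the new eva.py, eva02_* defs are new
model = timm.create_model('eva02_base_patch14_448.mim_in22k_ft_in1k', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 448, 448))  # cfg input_size is (3, 448, 448)
print(out.shape)  # torch.Size([1, 1000])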

timm/layers/__init__.py

@@ -27,7 +27,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible,
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
-from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
@@ -37,8 +37,8 @@ from .patch_embed import PatchEmbed, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
-from .pos_embed_sincos import build_sincos2d_pos_embed, build_fourier_pos_embed, build_rotary_pos_embed, \
-    FourierEmbed, RotaryEmbedding
from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
    build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct

timm/layers/mlp.py

@@ -19,6 +19,7 @@ class Mlp(nn.Module):
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.,
use_conv=False,
@@ -33,6 +34,7 @@ class Mlp(nn.Module):
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
@@ -55,9 +57,11 @@ class GluMlp(nn.Module):
hidden_features=None,
out_features=None,
act_layer=nn.Sigmoid,
norm_layer=None,
bias=True,
drop=0.,
use_conv=False,
gate_last=True,
):
super().__init__()
out_features = out_features or in_features
@@ -67,10 +71,12 @@ class GluMlp(nn.Module):
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.chunk_dim = 1 if use_conv else -1
self.gate_last = gate_last  # use second half of width for gate
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
@@ -82,9 +88,57 @@ class GluMlp(nn.Module):
def forward(self, x):
x = self.fc1(x)
-x, gates = x.chunk(2, dim=self.chunk_dim)
-x = x * self.act(gates)
x1, x2 = x.chunk(2, dim=self.chunk_dim)
x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class SwiGLU(nn.Module):
""" SwiGLU
NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
better matches some other common impl which makes mapping checkpoints simpler.
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.SiLU,
norm_layer=nn.LayerNorm,
bias=True,
drop=0.,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0])
self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
self.drop = nn.Dropout(drop)
def init_weights(self):
# override init of fc1 w/ gate portion set to weight near zero, bias=1
nn.init.ones_(self.fc1_g.bias)
nn.init.normal_(self.fc1_g.weight, std=1e-6)
def forward(self, x):
x_gate = self.fc1_g(x)
x = self.fc1_x(x)
x = self.act(x_gate) * x
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
@@ -99,6 +153,7 @@ class GatedMlp(nn.Module):
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
gate_layer=None,
bias=True,
drop=0.,
@@ -118,6 +173,7 @@ class GatedMlp(nn.Module):
hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
else:
self.gate = nn.Identity()
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
@@ -126,6 +182,7 @@ class GatedMlp(nn.Module):
x = self.act(x)
x = self.drop1(x)
x = self.gate(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x

timm/layers/pos_embed_sincos.py

@@ -23,15 +23,15 @@ def pixel_freq_bands(
return bands * torch.pi
-def inv_freq_bands(
def freq_bands(
num_bands: int,
-temperature: float = 100000.,
temperature: float = 10000.,
step: int = 2,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
-inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
-return inv_freq
bands = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
return bands
def build_sincos2d_pos_embed(
@@ -59,12 +59,12 @@ def build_sincos2d_pos_embed(
"""
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
pos_dim = dim // 4
-bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
bands = freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
if reverse_coord:
feat_shape = feat_shape[::-1]  # stack W, H instead of H, W
-grid = torch.stack(
-torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
grid = torch.stack(torch.meshgrid(
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
# FIXME add support for unflattened spatial dim?
@@ -78,18 +78,49 @@ def build_fourier_pos_embed(
bands: Optional[torch.Tensor] = None,
num_bands: int = 64,
max_res: int = 224,
temperature: float = 10000.,
linear_bands: bool = False,
include_grid: bool = False,
-concat_out: bool = True,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
"""
Args:
feat_shape: Feature shape for embedding.
bands: Pre-calculated frequency bands.
num_bands: Number of frequency bands (determines output dim).
max_res: Maximum resolution for pixel based freq.
temperature: Temperature for non-pixel freq.
linear_bands: Linear band spacing for pixel based freq.
include_grid: Include the spatial grid in output.
in_pixels: Output in pixel freq.
ref_feat_shape: Reference feature shape for resize / fine-tune.
dtype: Output dtype.
device: Output device.
Returns:
"""
if bands is None:
if in_pixels:
-bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
bands = pixel_freq_bands(
num_bands,
float(max_res),
linear_bands=linear_bands,
dtype=dtype,
device=device,
)
else:
-bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
bands = freq_bands(
num_bands,
temperature=temperature,
step=1,
dtype=dtype,
device=device,
)
else:
if device is None:
device = bands.device
@@ -97,31 +128,42 @@ def build_fourier_pos_embed(
dtype = bands.dtype
if in_pixels:
-grid = torch.stack(torch.meshgrid(
-[torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
t = [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]
else:
-grid = torch.stack(torch.meshgrid(
-[torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
t = [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]
if ref_feat_shape is not None:
# eva's scheme for resizing rope embeddings (ref shape = pretrain)
t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
grid = torch.stack(torch.meshgrid(t), dim=-1)
grid = grid.unsqueeze(-1)
pos = grid * bands
pos_sin, pos_cos = pos.sin(), pos.cos()
-out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
-# FIXME torchscript doesn't like multiple return types, probably need to always cat?
-if concat_out:
-out = torch.cat(out, dim=-1)
return out
class FourierEmbed(nn.Module):
-def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
def __init__(
self,
max_res: int = 224,
num_bands: int = 64,
concat_grid=True,
keep_spatial=False,
):
super().__init__()
self.max_res = max_res
self.num_bands = num_bands
self.concat_grid = concat_grid
self.keep_spatial = keep_spatial
-self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False)
self.register_buffer(
'bands',
pixel_freq_bands(max_res, num_bands),
persistent=False,
)
def forward(self, x):
B, C = x.shape[:2]
@@ -131,7 +173,9 @@ class FourierEmbed(nn.Module):
self.bands,
include_grid=self.concat_grid,
dtype=x.dtype,
-device=x.device)
device=x.device,
)
emb = torch.cat(emb, dim=-1)
emb = emb.transpose(-1, -2).flatten(len(feat_shape))
batch_expand = (B,) + (-1,) * (x.ndim - 1)
@@ -159,38 +203,57 @@ def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
return [t * cos_emb + rot(t) * sin_emb for t in x]
-def apply_rot_embed_split(x: torch.Tensor, emb):
-split = emb.shape[-1] // 2
-return x * emb[:, :split] + rot(x) * emb[:, split:]
def apply_rot_embed_cat(x: torch.Tensor, emb):
sin_emb, cos_emb = emb.tensor_split(2, -1)
return x * cos_emb + rot(x) * sin_emb
def build_rotary_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
dim: int = 64,
-max_freq: float = 224,
max_res: int = 224,
temperature: float = 10000.,
linear_bands: bool = False,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
"""
-NOTE: shape arg should include spatial dim only
-"""
-feat_shape = torch.Size(feat_shape)
Args:
feat_shape: Spatial shape of the target tensor for embedding.
bands: Optional pre-generated frequency bands
dim: Output dimension of embedding tensor.
max_res: Maximum resolution for pixel mode.
temperature: Temperature (inv freq) for non-pixel mode
linear_bands: Linearly (instead of log) spaced bands for pixel mode
in_pixels: Pixel vs language (inv freq) mode.
dtype: Output dtype.
device: Output device.
Returns:
"""
sin_emb, cos_emb = build_fourier_pos_embed(
feat_shape,
bands=bands,
num_bands=dim // 4,
-max_res=max_freq,
max_res=max_res,
temperature=temperature,
linear_bands=linear_bands,
-concat_out=False,
in_pixels=in_pixels,
ref_feat_shape=ref_feat_shape,
device=device,
dtype=dtype,
)
-N = feat_shape.numel()
-sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
-cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
num_spatial_dim = 1
# this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
for x in feat_shape:
num_spatial_dim *= x
sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
return sin_emb, cos_emb
@@ -205,15 +268,164 @@ class RotaryEmbedding(nn.Module):
* https://blog.eleuther.ai/rotary-embeddings/
"""
-def __init__(self, dim, max_res=224, linear_bands: bool = False):
def __init__(
self,
dim,
max_res=224,
temperature=10000,
in_pixels=True,
linear_bands: bool = False,
feat_shape: Optional[List[int]] = None,
ref_feat_shape: Optional[List[int]] = None,
):
super().__init__()
self.dim = dim
-self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)
self.max_res = max_res
self.temperature = temperature
self.in_pixels = in_pixels
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
-def get_embed(self, shape: List[int]):
-return build_rotary_pos_embed(shape, self.bands)
if feat_shape is None:
# only cache bands
if in_pixels:
bands = pixel_freq_bands(
dim // 4,
float(max_res),
linear_bands=linear_bands,
)
else:
bands = freq_bands(
dim // 4,
temperature=temperature,
step=1,
)
self.register_buffer(
'bands',
bands,
persistent=False,
)
self.pos_embed_sin = None
self.pos_embed_cos = None
else:
# cache full sin/cos embeddings if shape provided up front
emb_sin, emb_cos = build_rotary_pos_embed(
feat_shape=feat_shape,
dim=dim,
max_res=max_res,
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=self.ref_feat_shape,
)
self.bands = None
self.register_buffer(
'pos_embed_sin',
emb_sin,
persistent=False,
)
self.register_buffer(
'pos_embed_cos',
emb_cos,
persistent=False,
)
def get_embed(self, shape: Optional[List[int]] = None):
if self.bands is not None:
# rebuild embeddings every call, use if target shape changes
assert shape is not None
return build_rotary_pos_embed(
shape,
self.bands,
in_pixels=self.in_pixels,
)
else:
return self.pos_embed_sin, self.pos_embed_cos
def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
sin_emb, cos_emb = self.get_embed(x.shape[2:])
return apply_rot_embed(x, sin_emb, cos_emb)
class RotaryEmbeddingCat(nn.Module):
""" Rotary position embedding w/ concatenatd sin & cos
The following impl/resources were referenced for this impl:
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
* https://blog.eleuther.ai/rotary-embeddings/
"""
def __init__(
self,
dim,
max_res=224,
temperature=10000,
in_pixels=True,
linear_bands: bool = False,
feat_shape: Optional[List[int]] = None,
ref_feat_shape: Optional[List[int]] = None,
):
super().__init__()
self.dim = dim
self.max_res = max_res
self.temperature = temperature
self.in_pixels = in_pixels
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
if feat_shape is None:
# only cache bands
if in_pixels:
bands = pixel_freq_bands(
dim // 4,
float(max_res),
linear_bands=linear_bands,
)
else:
bands = freq_bands(
dim // 4,
temperature=temperature,
step=1,
)
self.register_buffer(
'bands',
bands,
persistent=False,
)
self.embed = None
else:
# cache full sin/cos embeddings if shape provided up front
embeds = build_rotary_pos_embed(
feat_shape=feat_shape,
dim=dim,
max_res=max_res,
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=self.ref_feat_shape,
)
self.bands = None
self.register_buffer(
'pos_embed',
torch.cat(embeds, -1),
persistent=False,
)
def get_embed(self, shape: Optional[List[int]] = None):
if self.bands is not None:
# rebuild embeddings every call, use if target shape changes
assert shape is not None
embeds = build_rotary_pos_embed(
shape,
self.bands,
in_pixels=self.in_pixels,
)
return torch.cat(embeds, -1)
else:
return self.pos_embed
def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
pos_embed = self.get_embed(x.shape[2:])
return apply_rot_embed_cat(x, pos_embed)
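To illustrate the "lang oriented" freq bands from the commit message: with in_pixels=False, RotaryEmbeddingCat builds its bands via freq_bands() (temperature-based inverse frequencies) instead of pixel_freq_bands(), and caches a single concatenated sin/cos tensor. A small sketch, with grid sizes as my own example values mirroring how the Eva model below wires it up:

import torch
from timm.layers import RotaryEmbeddingCat, apply_rot_embed_cat

embed_dim, num_heads = 768, 12
head_dim = embed_dim // num_heads
rope = RotaryEmbeddingCat(
    head_dim,                 # rotary dim = per-head dim
    in_pixels=False,          # use freq_bands() w/ temperature, the EVA02 choice
    feat_shape=(32, 32),      # e.g. a 448 // 14 patch grid, cached up front
    ref_feat_shape=(16, 16),  # pretrain grid (224 // 14) for EVA's rescaling trick
)

q = torch.randn(2, num_heads, 32 * 32, head_dim)
rot = rope.get_embed()            # (1024, 2 * head_dim), sin & cos concatenated
q = apply_rot_embed_cat(q, rot)   # rotate q (and likewise k) inside each block
print(rot.shape, q.shape)         # torch.Size([1024, 128]) torch.Size([2, 12, 1024, 64])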

timm/models/__init__.py

@@ -17,6 +17,7 @@ from .edgenext import *
from .efficientformer import *
from .efficientformer_v2 import *
from .efficientnet import *
from .eva import *
from .focalnet import *
from .gcvit import *
from .ghostnet import *

timm/models/beit.py

@@ -21,17 +21,6 @@ archivePrefix={arXiv},
primaryClass={cs.CV}
}
-EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
-@article{EVA,
-title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
-author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
-Tiejun and Wang, Xinlong and Cao, Yue},
-journal={arXiv preprint arXiv:2211.07636},
-year={2022}
-}
At this point only the 1k fine-tuned classification weights and model configs have been added,
see original source above for pre-training models and procedure.
@@ -49,19 +38,18 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
-# EVA models Copyright (c) 2022 BAAI-Vision
import math
from functools import partial
-from typing import Optional, Tuple
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_
from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import checkpoint_filter_fn
@@ -93,8 +81,15 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
class Attention(nn.Module):
def __init__(
-self, dim, num_heads=8, qkv_bias=False, attn_drop=0.,
-proj_drop=0., window_size=None, attn_head_dim=None):
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
window_size: Optional[Tuple[int, int]] = None,
attn_head_dim: Optional[int] = None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
@@ -102,6 +97,7 @@ class Attention(nn.Module):
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = head_dim ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
@@ -142,8 +138,24 @@ class Attention(nn.Module):
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
if self.fast_attn:
if self.relative_position_bias_table is not None:
rel_pos_bias = self._get_rel_pos_bias()
if shared_rel_pos_bias is not None:
rel_pos_bias = rel_pos_bias + shared_rel_pos_bias
elif shared_rel_pos_bias is not None:
rel_pos_bias = shared_rel_pos_bias
else:
rel_pos_bias = None
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=rel_pos_bias,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
@@ -154,8 +166,9 @@ class Attention(nn.Module):
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
-x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@@ -164,19 +177,53 @@ class Attention(nn.Module):
class Block(nn.Module):
def __init__(
-self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
-drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
-window_size=None, attn_head_dim=None):
self,
dim: int,
num_heads: int,
qkv_bias: bool = False,
mlp_ratio: float = 4.,
scale_mlp: bool = False,
swiglu_mlp: bool = False,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
init_values: Optional[float] = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
window_size: Optional[Tuple[int, int]] = None,
attn_head_dim: Optional[int] = None,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
-dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
-window_size=window_size, attn_head_dim=attn_head_dim)
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
window_size=window_size,
attn_head_dim=attn_head_dim,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
-self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
-mlp_hidden_dim = int(dim * mlp_ratio)
-self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if swiglu_mlp:
self.mlp = SwiGLU(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
norm_layer=norm_layer if scale_mlp else None,
drop=proj_drop,
)
else:
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
norm_layer=norm_layer if scale_mlp else None,
drop=proj_drop,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if init_values:
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
@@ -186,11 +233,11 @@ class Block(nn.Module):
def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
if self.gamma_1 is None:
-x = x + self.drop_path(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
-x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x + self.drop_path1(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
x = x + self.drop_path2(self.mlp(self.norm2(x)))
else:
-x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
-x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
return x
@@ -216,19 +263,42 @@ class Beit(nn.Module):
"""
def __init__(
-self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
-embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.,
-attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
-init_values=None, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
-head_init_scale=0.001):
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
qkv_bias: bool = True,
mlp_ratio: float = 4.,
swiglu_mlp: bool = False,
scale_mlp: bool = False,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
norm_layer: Callable = LayerNorm,
init_values: Optional[float] = None,
use_abs_pos_emb: bool = True,
use_rel_pos_bias: bool = False,
use_shared_rel_pos_bias: bool = False,
head_init_scale: float = 0.001,
):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
self.num_prefix_tokens = 1
self.grad_checkpointing = False
self.patch_embed = PatchEmbed(
-img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
@@ -237,27 +307,41 @@ class Beit(nn.Module):
self.pos_drop = nn.Dropout(p=drop_rate)
if use_shared_rel_pos_bias:
-self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.grid_size, num_heads=num_heads)
self.rel_pos_bias = RelativePositionBias(
window_size=self.patch_embed.grid_size,
num_heads=num_heads,
)
else:
self.rel_pos_bias = None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
-dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
-drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
-init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None)
dim=embed_dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
mlp_ratio=mlp_ratio,
scale_mlp=scale_mlp,
swiglu_mlp=swiglu_mlp,
proj_drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
)
for i in range(depth)])
use_fc_norm = self.global_pool == 'avg'
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
-self.fc_norm = norm_layer(embed_dim) if use_fc_norm else None
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
-# trunc_normal_(self.mask_token, std=.02)
self.fix_init_weight()
if isinstance(self.head, nn.Linear):
trunc_normal_(self.head.weight, std=.02)
@@ -328,11 +412,9 @@ class Beit(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
-if self.fc_norm is not None:
-x = x[:, 1:].mean(dim=1)
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
-else:
-x = x[:, 0]
return x if pre_logits else self.head(x)
def forward(self, x):
@@ -405,27 +487,6 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
-'eva_giant_patch14_224.clip_ft_in1k': _cfg(
-# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
-hf_hub_id='timm/',
-mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
-),
-'eva_giant_patch14_336.clip_ft_in1k': _cfg(
-# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
-hf_hub_id='timm/',
-mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
-'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
-# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
-hf_hub_id='timm/',
-mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
-input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
-'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
-# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
-hf_hub_id='timm/',
-mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
-input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
})
@@ -509,30 +570,3 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs):
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
-@register_model
-def eva_giant_patch14_224(pretrained=False, **kwargs):
-""" EVA-g model https://arxiv.org/abs/2211.07636 """
-model_kwargs = dict(
-patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
-model = _create_beit('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs)
-return model
-@register_model
-def eva_giant_patch14_336(pretrained=False, **kwargs):
-""" EVA-g model https://arxiv.org/abs/2211.07636 """
-model_kwargs = dict(
-patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
-model = _create_beit('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs)
-return model
-@register_model
-def eva_giant_patch14_560(pretrained=False, **kwargs):
-""" EVA-g model https://arxiv.org/abs/2211.07636 """
-model_kwargs = dict(
-patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
-model = _create_beit('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs)
-return model

timm/models/eva.py (new file, +730 lines)

@ -0,0 +1,730 @@
""" EVA
EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
@article{EVA,
title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
Tiejun and Wang, Xinlong and Cao, Yue},
journal={arXiv preprint arXiv:2211.07636},
year={2022}
}
EVA-02: A Visual Representation for Neon Genesis - https://arxiv.org/abs/2303.11331
@article{EVA02,
title={EVA-02: A Visual Representation for Neon Genesis},
author={Fang, Yuxin and Sun, Quan and Wang, Xinggang and Huang, Tiejun and Wang, Xinlong and Cao, Yue},
journal={arXiv preprint arXiv:2303.11331},
year={2023}
}
This file contains EVA & EVA02 model implementations evolved from BEiT, additional models in vision_transformer.py.
Modifications by / Copyright 2023 Ross Wightman, original copyrights below
"""
# EVA models Copyright (c) 2022 BAAI-Vision
# EVA02 models Copyright (c) 2023 BAAI-Vision
import math
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, RotaryEmbeddingCat, \
apply_rot_embed_cat, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, to_2tuple
from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model
__all__ = ['Eva']
class EvaAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
qkv_fused: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
attn_head_dim: Optional[int] = None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = head_dim ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if qkv_fused:
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
self.q_proj = self.k_proj = self.v_proj = None
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = self.k_bias = self.v_bias = None
else:
self.q_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
self.v_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
self.qkv = None
self.q_bias = self.k_bias = self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(
self,
x,
rope: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
B, N, C = x.shape
if self.qkv is not None:
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
else:
q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C
k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
if rope is not None:
q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], 2).type_as(v)
k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], 2).type_as(v)
if self.fast_attn:
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if attn_mask is not None:
attn_mask = attn_mask.to(torch.bool)
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class EvaBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
qkv_bias: bool = True,
qkv_fused: bool = True,
mlp_ratio: float = 4.,
scale_mlp: bool = False,
swiglu_mlp: bool = False,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
init_values: Optional[float] = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
attn_head_dim: Optional[int] = None,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = EvaAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qkv_fused=qkv_fused,
attn_drop=attn_drop,
proj_drop=proj_drop,
attn_head_dim=attn_head_dim,
)
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
hidden_features = int(dim * mlp_ratio)
if swiglu_mlp:
if scale_mlp:
# when norm in SwiGLU used, an impl with separate fc for gate & x is used
self.mlp = SwiGLU(
in_features=dim,
hidden_features=hidden_features,
norm_layer=norm_layer if scale_mlp else None,
drop=proj_drop,
)
else:
# w/o any extra norm, an impl with packed weights is used, matches existing GluMLP
self.mlp = GluMlp(
in_features=dim,
hidden_features=hidden_features * 2,
norm_layer=norm_layer if scale_mlp else None,
act_layer=nn.SiLU,
gate_last=False,
drop=proj_drop,
)
else:
self.mlp = Mlp(
in_features=dim,
hidden_features=hidden_features,
act_layer=act_layer,
norm_layer=norm_layer if scale_mlp else None,
drop=proj_drop,
)
self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
if self.gamma_1 is None:
x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
x = x + self.drop_path2(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class Eva(nn.Module):
""" Eva Vision Transformer w/ Abs & Rotary Pos Embed
This class implements the EVA and EVA02 models that were based on the BEiT ViT variant
* EVA - abs pos embed, global avg pool
* EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale Norm in MLP (ala normformer)
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
qkv_bias: bool = True,
qkv_fused: bool = True,
mlp_ratio: float = 4.,
swiglu_mlp: bool = False,
scale_mlp: bool = False,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
norm_layer: Callable = LayerNorm,
init_values: Optional[float] = None,
use_abs_pos_emb: bool = True,
use_rot_pos_emb: bool = False,
ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
head_init_scale: float = 0.001,
):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1
self.grad_checkpointing = False
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
self.pos_drop = nn.Dropout(p=drop_rate)
if use_rot_pos_emb:
ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
self.rope = RotaryEmbeddingCat(
embed_dim // num_heads,
in_pixels=False,
feat_shape=self.patch_embed.grid_size,
ref_feat_shape=ref_feat_shape,
)
else:
self.rope = None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
EvaBlock(
dim=embed_dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qkv_fused=qkv_fused,
mlp_ratio=mlp_ratio,
scale_mlp=scale_mlp,
swiglu_mlp=swiglu_mlp,
proj_drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
)
for i in range(depth)])
use_fc_norm = self.global_pool == 'avg'
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.fix_init_weight()
if isinstance(self.head, nn.Linear):
trunc_normal_(self.head.weight, std=.02)
self.head.weight.data.mul_(head_init_scale)
self.head.bias.data.mul_(head_init_scale)
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
nwd = {'pos_embed', 'cls_token'}
return nwd
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
)
return matcher
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, rope=rot_pos_embed)
else:
x = blk(x, rope=rot_pos_embed)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(
state_dict,
model,
interpolation='bicubic',
antialias=True,
):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
state_dict = state_dict.get('model_ema', state_dict)
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('module', state_dict)
state_dict = state_dict.get('state_dict', state_dict)
no_qkv = 'blocks.0.attn.q_proj.weight' in state_dict
mim_weights = 'mask_token' in state_dict
for k, v in state_dict.items():
if 'rope' in k:
# fixed embedding no need to load buffer from checkpoint
continue
if 'patch_embed.proj.weight' in k:
_, _, H, W = model.patch_embed.proj.weight.shape
if v.shape[-1] != W or v.shape[-2] != H:
v = resample_patch_embed(
v,
(H, W),
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
# To resize pos embedding when using model at different size from pretrained weights
num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
v = resample_abs_pos_embed(
v,
new_size=model.patch_embed.grid_size,
num_prefix_tokens=num_prefix_tokens,
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
k = k.replace('mlp.ffn_ln', 'mlp.norm')
k = k.replace('mlp.w12', 'mlp.fc1')
k = k.replace('mlp.w1', 'mlp.fc1_g')
k = k.replace('mlp.w2', 'mlp.fc1_x')
k = k.replace('mlp.w3', 'mlp.fc2')
if no_qkv:
k = k.replace('q_bias', 'q_proj.bias')
k = k.replace('v_bias', 'v_proj.bias')
if mim_weights and k in ('mask_token', 'lm_head.weight', 'lm_head.bias', 'norm.weight', 'norm.bias'):
if k == 'norm.weight' or k == 'norm.bias':
# try moving norm -> fc norm on fine-tune, probably a better starting point than new init
k = k.replace('norm', 'fc_norm')
else:
# skip pretrain mask token & head weights
continue
out_dict[k] = v
return out_dict
def _create_eva(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Eva models.')
model = build_model_with_cfg(
Eva, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
hf_hub_id='timm/',
),
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
hf_hub_id='timm/',
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_B_pt_in21k_medft_in21k_ft_in1k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
),
'eva02_large_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_in21k_medft_in21k_ft_in1k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
),
'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_m38m_medft_in21k_ft_in1k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
),
'eva02_tiny_patch14_336.mim_in22k_ft_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt',
input_size=(3, 336, 336), crop_pct=1.0,
),
'eva02_small_patch14_336.mim_in22k_ft_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt',
input_size=(3, 336, 336), crop_pct=1.0,
),
'eva02_base_patch14_448.mim_in22k_ft_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0,
),
'eva02_large_patch14_448.mim_in22k_ft_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in1k/eva02_L_pt_in21k_ft_in1k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0,
),
'eva02_large_patch14_448.mim_m38m_ft_in1k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02',
hf_hub_filename='eva02/cls/in1k/eva02_L_pt_m38m_ft_in1k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0,
),
'eva02_base_patch14_448.mim_in22k_ft_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
),
'eva02_large_patch14_448.mim_in22k_ft_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
),
'eva02_large_patch14_448.mim_m38m_ft_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
),
'eva02_tiny_patch14_224.mim_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_Ti_pt_in21k_p14.pt',
num_classes=0,
),
'eva02_small_patch14_224.mim_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_S_pt_in21k_p14.pt',
num_classes=0,
),
'eva02_base_patch14_224.mim_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_B_pt_in21k_p14.pt',
num_classes=0,
),
'eva02_large_patch14_224.mim_in22k': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_in21k_p14.pt',
num_classes=0,
),
'eva02_large_patch14_224.mim_m38m': _cfg(
hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_m38m_p14.pt',
num_classes=0,
),
})
@register_model
def eva_giant_patch14_224(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """
model_kwargs = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
model = _create_eva('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def eva_giant_patch14_336(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """
model_kwargs = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
model = _create_eva('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs)
return model
@register_model
def eva_giant_patch14_560(pretrained=False, **kwargs):
""" EVA-g model https://arxiv.org/abs/2211.07636 """
model_kwargs = dict(
patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
model = _create_eva('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs)
return model
@register_model
def eva02_tiny_patch14_224(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=224,
patch_size=14,
embed_dim=192,
depth=12,
num_heads=3,
mlp_ratio=4 * 2 / 3,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_tiny_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_small_patch14_224(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=224,
patch_size=14,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4 * 2 / 3,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_small_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_base_patch14_224(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=224,
patch_size=14,
embed_dim=768,
depth=12,
num_heads=12,
qkv_fused=False,
mlp_ratio=4 * 2 / 3,
scale_mlp=True,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_base_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_large_patch14_224(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=224,
patch_size=14,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4 * 2 / 3,
qkv_fused=False,
scale_mlp=True,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_large_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_tiny_patch14_336(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=336,
patch_size=14,
embed_dim=192,
depth=12,
num_heads=3,
mlp_ratio=4 * 2 / 3,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_tiny_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_small_patch14_336(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=336,
patch_size=14,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4 * 2 / 3,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_small_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_base_patch14_448(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=448,
patch_size=14,
embed_dim=768,
depth=12,
num_heads=12,
qkv_fused=False,
mlp_ratio=4 * 2 / 3,
scale_mlp=True,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_base_patch14_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva02_large_patch14_448(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=448,
patch_size=14,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4 * 2 / 3,
qkv_fused=False,
scale_mlp=True,
swiglu_mlp=True,
use_rot_pos_emb=True,
ref_feat_shape=(16, 16), # 224/14
)
model = _create_eva('eva02_large_patch14_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model