diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index d4eab660..afae6415 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -27,7 +27,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, from .inplace_abn import InplaceAbn from .linear import Linear from .mixed_conv2d import MixedConv2d -from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp +from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, ConvMlp, GlobalResponseNormMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\ @@ -37,8 +37,8 @@ from .patch_embed import PatchEmbed, resample_patch_embed from .pool2d_same import AvgPool2dSame, create_pool2d from .pos_embed import resample_abs_pos_embed from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords -from .pos_embed_sincos import build_sincos2d_pos_embed, build_fourier_pos_embed, build_rotary_pos_embed, \ - FourierEmbed, RotaryEmbedding +from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \ + build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite from .selective_kernel import SelectiveKernel from .separable_conv import SeparableConv2d, SeparableConvNormAct diff --git a/timm/layers/mlp.py b/timm/layers/mlp.py index d0188291..c4edf1b1 100644 --- a/timm/layers/mlp.py +++ b/timm/layers/mlp.py @@ -19,6 +19,7 @@ class Mlp(nn.Module): hidden_features=None, out_features=None, act_layer=nn.GELU, + norm_layer=None, bias=True, drop=0., use_conv=False, @@ -33,6 +34,7 @@ class Mlp(nn.Module): self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) @@ -55,9 +57,11 @@ class GluMlp(nn.Module): hidden_features=None, out_features=None, act_layer=nn.Sigmoid, + norm_layer=None, bias=True, drop=0., use_conv=False, + gate_last=True, ): super().__init__() out_features = out_features or in_features @@ -67,10 +71,12 @@ class GluMlp(nn.Module): drop_probs = to_2tuple(drop) linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear self.chunk_dim = 1 if use_conv else -1 + self.gate_last = gate_last # use second half of width for gate self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity() self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) @@ -82,9 +88,57 @@ class GluMlp(nn.Module): def forward(self, x): x = self.fc1(x) - x, gates = x.chunk(2, dim=self.chunk_dim) - x = x * self.act(gates) + x1, x2 = x.chunk(2, dim=self.chunk_dim) + x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2 x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class SwiGLU(nn.Module): + """ SwiGLU + NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 
and + better matches some other common impl which makes mapping checkpoints simpler. + """ + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.SiLU, + norm_layer=nn.LayerNorm, + bias=True, + drop=0., + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + self.drop = nn.Dropout(drop) + + def init_weights(self): + # override init of fc1 w/ gate portion set to weight near zero, bias=1 + nn.init.ones_(self.fc1a.bias) + nn.init.normal_(self.fc1a.weight, std=1e-6) + + def forward(self, x): + x_gate = self.fc1_g(x) + x = self.fc1_x(x) + x = self.act(x_gate) * x + x = self.drop1(x) + x = self.norm(x) x = self.fc2(x) x = self.drop2(x) return x @@ -99,6 +153,7 @@ class GatedMlp(nn.Module): hidden_features=None, out_features=None, act_layer=nn.GELU, + norm_layer=None, gate_layer=None, bias=True, drop=0., @@ -118,6 +173,7 @@ class GatedMlp(nn.Module): hidden_features = hidden_features // 2 # FIXME base reduction on gate property? else: self.gate = nn.Identity() + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) @@ -126,6 +182,7 @@ class GatedMlp(nn.Module): x = self.act(x) x = self.drop1(x) x = self.gate(x) + x = self.norm(x) x = self.fc2(x) x = self.drop2(x) return x diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 5603a5cd..a305aa8a 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -23,15 +23,15 @@ def pixel_freq_bands( return bands * torch.pi -def inv_freq_bands( +def freq_bands( num_bands: int, - temperature: float = 100000., + temperature: float = 10000., step: int = 2, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ) -> torch.Tensor: - inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands)) - return inv_freq + bands = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands)) + return bands def build_sincos2d_pos_embed( @@ -59,12 +59,12 @@ def build_sincos2d_pos_embed( """ assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding' pos_dim = dim // 4 - bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device) + bands = freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device) if reverse_coord: feat_shape = feat_shape[::-1] # stack W, H instead of H, W - grid = torch.stack( - torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1) + grid = torch.stack(torch.meshgrid( + [torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1) pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0) # FIXME add support for unflattened spatial dim? 
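As a quick check of the note in the SwiGLU docstring above: the split-fc1 SwiGLU and the packed GluMlp (SiLU activation, gate_last=False) compute the same function once the split weights are packed into a single fc1. A minimal sketch, assuming only the layer signatures added in this diff (tensor sizes are arbitrary):

import torch
import torch.nn as nn
from timm.layers import GluMlp, SwiGLU

x = torch.randn(2, 196, 768)

# split variant: separate gate (fc1_g) and value (fc1_x) projections
swiglu = SwiGLU(in_features=768, hidden_features=512, norm_layer=None)
# packed variant: single fc1 of width 2 * hidden, first half acting as the gate
glu = GluMlp(in_features=768, hidden_features=2 * 512, act_layer=nn.SiLU, gate_last=False)

with torch.no_grad():
    # pack gate weights into the first half of fc1, value weights into the second half
    glu.fc1.weight.copy_(torch.cat([swiglu.fc1_g.weight, swiglu.fc1_x.weight], 0))
    glu.fc1.bias.copy_(torch.cat([swiglu.fc1_g.bias, swiglu.fc1_x.bias], 0))
    glu.fc2.weight.copy_(swiglu.fc2.weight)
    glu.fc2.bias.copy_(swiglu.fc2.bias)

torch.testing.assert_close(swiglu(x), glu(x))  # both compute fc2(SiLU(W_g x) * (W_x x))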
@@ -78,18 +78,49 @@ def build_fourier_pos_embed( bands: Optional[torch.Tensor] = None, num_bands: int = 64, max_res: int = 224, + temperature: float = 10000., linear_bands: bool = False, include_grid: bool = False, - concat_out: bool = True, in_pixels: bool = True, + ref_feat_shape: Optional[List[int]] = None, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ) -> List[torch.Tensor]: + """ + + Args: + feat_shape: Feature shape for embedding. + bands: Pre-calculated frequency bands. + num_bands: Number of frequency bands (determines output dim). + max_res: Maximum resolution for pixel based freq. + temperature: Temperature for non-pixel freq. + linear_bands: Linear band spacing for pixel based freq. + include_grid: Include the spatial grid in output. + in_pixels: Output in pixel freq. + ref_feat_shape: Reference feature shape for resize / fine-tune. + dtype: Output dtype. + device: Output device. + + Returns: + + """ if bands is None: if in_pixels: - bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device) + bands = pixel_freq_bands( + num_bands, + float(max_res), + linear_bands=linear_bands, + dtype=dtype, + device=device, + ) else: - bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device) + bands = freq_bands( + num_bands, + temperature=temperature, + step=1, + dtype=dtype, + device=device, + ) else: if device is None: device = bands.device @@ -97,31 +128,42 @@ def build_fourier_pos_embed( dtype = bands.dtype if in_pixels: - grid = torch.stack(torch.meshgrid( - [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1) + t = [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape] else: - grid = torch.stack(torch.meshgrid( - [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1) + t = [torch.arange(s, device=device, dtype=dtype) for s in feat_shape] + + if ref_feat_shape is not None: + # eva's scheme for resizing rope embeddings (ref shape = pretrain) + t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)] + + grid = torch.stack(torch.meshgrid(t), dim=-1) grid = grid.unsqueeze(-1) pos = grid * bands pos_sin, pos_cos = pos.sin(), pos.cos() - out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos) - # FIXME torchscript doesn't like multiple return types, probably need to always cat? 
- if concat_out: - out = torch.cat(out, dim=-1) + out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos] return out class FourierEmbed(nn.Module): - def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False): + def __init__( + self, + max_res: int = 224, + num_bands: int = 64, + concat_grid=True, + keep_spatial=False, + ): super().__init__() self.max_res = max_res self.num_bands = num_bands self.concat_grid = concat_grid self.keep_spatial = keep_spatial - self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False) + self.register_buffer( + 'bands', + pixel_freq_bands(max_res, num_bands), + persistent=False, + ) def forward(self, x): B, C = x.shape[:2] @@ -131,7 +173,9 @@ class FourierEmbed(nn.Module): self.bands, include_grid=self.concat_grid, dtype=x.dtype, - device=x.device) + device=x.device, + ) + emb = torch.cat(emb, dim=-1) emb = emb.transpose(-1, -2).flatten(len(feat_shape)) batch_expand = (B,) + (-1,) * (x.ndim - 1) @@ -159,38 +203,57 @@ def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): return [t * cos_emb + rot(t) * sin_emb for t in x] -def apply_rot_embed_split(x: torch.Tensor, emb): - split = emb.shape[-1] // 2 - return x * emb[:, :split] + rot(x) * emb[:, split:] +def apply_rot_embed_cat(x: torch.Tensor, emb): + sin_emb, cos_emb = emb.tensor_split(2, -1) + return x * cos_emb + rot(x) * sin_emb def build_rotary_pos_embed( feat_shape: List[int], bands: Optional[torch.Tensor] = None, dim: int = 64, - max_freq: float = 224, + max_res: int = 224, + temperature: float = 10000., linear_bands: bool = False, + in_pixels: bool = True, + ref_feat_shape: Optional[List[int]] = None, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ): """ - NOTE: shape arg should include spatial dim only - """ - feat_shape = torch.Size(feat_shape) + Args: + feat_shape: Spatial shape of the target tensor for embedding. + bands: Optional pre-generated frequency bands + dim: Output dimension of embedding tensor. + max_res: Maximum resolution for pixel mode. + temperature: Temperature (inv freq) for non-pixel mode + linear_bands: Linearly (instead of log) spaced bands for pixel mode + in_pixels: Pixel vs language (inv freq) mode. + dtype: Output dtype. + device: Output device. 
+ + Returns: + + """ sin_emb, cos_emb = build_fourier_pos_embed( feat_shape, bands=bands, num_bands=dim // 4, - max_res=max_freq, + max_res=max_res, + temperature=temperature, linear_bands=linear_bands, - concat_out=False, + in_pixels=in_pixels, + ref_feat_shape=ref_feat_shape, device=device, dtype=dtype, ) - N = feat_shape.numel() - sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1) - cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1) + num_spatial_dim = 1 + # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks + for x in feat_shape: + num_spatial_dim *= x + sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1) + cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1) return sin_emb, cos_emb @@ -205,15 +268,164 @@ class RotaryEmbedding(nn.Module): * https://blog.eleuther.ai/rotary-embeddings/ """ - def __init__(self, dim, max_res=224, linear_bands: bool = False): + def __init__( + self, + dim, + max_res=224, + temperature=10000, + in_pixels=True, + linear_bands: bool = False, + feat_shape: Optional[List[int]] = None, + ref_feat_shape: Optional[List[int]] = None, + ): super().__init__() self.dim = dim - self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False) + self.max_res = max_res + self.temperature = temperature + self.in_pixels = in_pixels + self.feat_shape = feat_shape + self.ref_feat_shape = ref_feat_shape - def get_embed(self, shape: List[int]): - return build_rotary_pos_embed(shape, self.bands) + if feat_shape is None: + # only cache bands + if in_pixels: + bands = pixel_freq_bands( + dim // 4, + float(max_res), + linear_bands=linear_bands, + ) + else: + bands = freq_bands( + dim // 4, + temperature=temperature, + step=1, + ) + print(bands) + self.register_buffer( + 'bands', + bands, + persistent=False, + ) + self.pos_embed_sin = None + self.pos_embed_cos = None + else: + # cache full sin/cos embeddings if shape provided up front + emb_sin, emb_cos = build_rotary_pos_embed( + feat_shape=feat_shape, + dim=dim, + max_res=max_res, + linear_bands=linear_bands, + in_pixels=in_pixels, + ref_feat_shape=self.ref_feat_shape, + ) + self.bands = None + self.register_buffer( + 'pos_embed_sin', + emb_sin, + persistent=False, + ) + self.register_buffer( + 'pos_embed_cos', + emb_cos, + persistent=False, + ) + + def get_embed(self, shape: Optional[List[int]] = None): + if self.bands is not None: + # rebuild embeddings every call, use if target shape changes + assert shape is not None + return build_rotary_pos_embed( + shape, + self.bands, + in_pixels=self.in_pixels, + ) + else: + return self.pos_embed_sin, self.pos_embed_cos def forward(self, x): # assuming channel-first tensor where spatial dim are >= 2 sin_emb, cos_emb = self.get_embed(x.shape[2:]) return apply_rot_embed(x, sin_emb, cos_emb) + + +class RotaryEmbeddingCat(nn.Module): + """ Rotary position embedding w/ concatenatd sin & cos + + The following impl/resources were referenced for this impl: + * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py + * https://blog.eleuther.ai/rotary-embeddings/ + """ + + def __init__( + self, + dim, + max_res=224, + temperature=10000, + in_pixels=True, + linear_bands: bool = False, + feat_shape: Optional[List[int]] = None, + ref_feat_shape: Optional[List[int]] = None, + ): + super().__init__() + self.dim = dim + self.max_res = max_res + self.temperature = temperature + self.in_pixels = in_pixels + self.feat_shape = 
feat_shape + self.ref_feat_shape = ref_feat_shape + + if feat_shape is None: + # only cache bands + if in_pixels: + bands = pixel_freq_bands( + dim // 4, + float(max_res), + linear_bands=linear_bands, + ) + else: + bands = freq_bands( + dim // 4, + temperature=temperature, + step=1, + ) + print(bands) + self.register_buffer( + 'bands', + bands, + persistent=False, + ) + self.embed = None + else: + # cache full sin/cos embeddings if shape provided up front + embeds = build_rotary_pos_embed( + feat_shape=feat_shape, + dim=dim, + max_res=max_res, + linear_bands=linear_bands, + in_pixels=in_pixels, + ref_feat_shape=self.ref_feat_shape, + ) + self.bands = None + self.register_buffer( + 'pos_embed', + torch.cat(embeds, -1), + persistent=False, + ) + + def get_embed(self, shape: Optional[List[int]] = None): + if self.bands is not None: + # rebuild embeddings every call, use if target shape changes + assert shape is not None + embeds = build_rotary_pos_embed( + shape, + self.bands, + in_pixels=self.in_pixels, + ) + return torch.cat(embeds, -1) + else: + return self.pos_embed + + def forward(self, x): + # assuming channel-first tensor where spatial dim are >= 2 + pos_embed = self.get_embed(x.shape[2:]) + return apply_rot_embed_cat(x, pos_embed) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index fba8d44f..33680704 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -17,6 +17,7 @@ from .edgenext import * from .efficientformer import * from .efficientformer_v2 import * from .efficientnet import * +from .eva import * from .focalnet import * from .gcvit import * from .ghostnet import * diff --git a/timm/models/beit.py b/timm/models/beit.py index a7084aeb..3ba67031 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -21,17 +21,6 @@ archivePrefix={arXiv}, primaryClass={cs.CV} } -EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636 - -@article{EVA, - title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale}, - author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang, - Tiejun and Wang, Xinlong and Cao, Yue}, - journal={arXiv preprint arXiv:2211.07636}, - year={2022} -} - - At this point only the 1k fine-tuned classification weights and model configs have been added, see original source above for pre-training models and procedure. 
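For orientation, a rough sketch of how the concatenated rotary embedding added in pos_embed_sincos.py above might be consumed on its own (shapes are illustrative; the module and helper names are the ones exported by this diff):

import torch
from timm.layers import RotaryEmbeddingCat, apply_rot_embed_cat

head_dim, grid = 64, (14, 14)              # 14 x 14 = 196 patch tokens
rope = RotaryEmbeddingCat(head_dim, in_pixels=False, feat_shape=list(grid))

q = torch.randn(2, 12, 14 * 14, head_dim)  # B, num_heads, N, head_dim (no class token here)
pos = rope.get_embed()                     # cached (196, 2 * head_dim) concatenated sin|cos
q_rot = apply_rot_embed_cat(q, pos)        # rotated queries, same shape as q
print(q_rot.shape)                         # torch.Size([2, 12, 196, 64])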
@@ -49,19 +38,18 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below # https://github.com/facebookresearch/dino # --------------------------------------------------------' -# EVA models Copyright (c) 2022 BAAI-Vision - import math from functools import partial -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_ + from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model from .vision_transformer import checkpoint_filter_fn @@ -93,8 +81,15 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: class Attention(nn.Module): def __init__( - self, dim, num_heads=8, qkv_bias=False, attn_drop=0., - proj_drop=0., window_size=None, attn_head_dim=None): + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + window_size: Optional[Tuple[int, int]] = None, + attn_head_dim: Optional[int] = None, + ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads @@ -102,6 +97,7 @@ class Attention(nn.Module): head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads self.scale = head_dim ** -0.5 + self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) if qkv_bias: @@ -142,20 +138,37 @@ class Attention(nn.Module): qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) + if self.fast_attn: + if self.relative_position_bias_table is not None: + rel_pos_bias = self._get_rel_pos_bias() + if shared_rel_pos_bias is not None: + rel_pos_bias = rel_pos_bias + shared_rel_pos_bias + elif shared_rel_pos_bias is not None: + rel_pos_bias = shared_rel_pos_bias + else: + rel_pos_bias = None - if self.relative_position_bias_table is not None: - attn = attn + self._get_rel_pos_bias() - if shared_rel_pos_bias is not None: - attn = attn + shared_rel_pos_bias + x = F.scaled_dot_product_attention( + q, k, v, + attn_mask=rel_pos_bias, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + if self.relative_position_bias_table is not None: + attn = attn + self._get_rel_pos_bias() + if shared_rel_pos_bias is not None: + attn = attn + shared_rel_pos_bias - x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -164,19 +177,53 @@ class Attention(nn.Module): class Block(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., 
init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, - window_size=None, attn_head_dim=None): + self, + dim: int, + num_heads: int, + qkv_bias: bool = False, + mlp_ratio: float = 4., + scale_mlp: bool = False, + swiglu_mlp: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + init_values: Optional[float] = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + window_size: Optional[Tuple[int, int]] = None, + attn_head_dim: Optional[int] = None, + ): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( - dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, - window_size=window_size, attn_head_dim=attn_head_dim) + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + window_size=window_size, + attn_head_dim=attn_head_dim, + ) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + if swiglu_mlp: + self.mlp = SwiGLU( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + norm_layer=norm_layer if scale_mlp else None, + drop=proj_drop, + ) + else: + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + norm_layer=norm_layer if scale_mlp else None, + drop=proj_drop, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() if init_values: self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) @@ -186,11 +233,11 @@ class Block(nn.Module): def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None): if self.gamma_1 is None: - x = x + self.drop_path(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias)) - x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x + self.drop_path1(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path2(self.mlp(self.norm2(x))) else: - x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias)) - x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x))) return x @@ -216,19 +263,42 @@ class Beit(nn.Module): """ def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., - attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), - init_values=None, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, - head_init_scale=0.001): + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + qkv_bias: bool = True, + mlp_ratio: float = 4., + swiglu_mlp: bool = False, + scale_mlp: bool = False, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_layer: Callable = LayerNorm, + init_values: Optional[float] = None, + 
use_abs_pos_emb: bool = True, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = False, + head_init_scale: float = 0.001, + ): super().__init__() self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_prefix_tokens = 1 self.grad_checkpointing = False self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) @@ -237,27 +307,41 @@ class Beit(nn.Module): self.pos_drop = nn.Dropout(p=drop_rate) if use_shared_rel_pos_bias: - self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.grid_size, num_heads=num_heads) + self.rel_pos_bias = RelativePositionBias( + window_size=self.patch_embed.grid_size, + num_heads=num_heads, + ) else: self.rel_pos_bias = None dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None) + dim=embed_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + scale_mlp=scale_mlp, + swiglu_mlp=swiglu_mlp, + proj_drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + init_values=init_values, + window_size=self.patch_embed.grid_size if use_rel_pos_bias else None, + ) for i in range(depth)]) + use_fc_norm = self.global_pool == 'avg' self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) - self.fc_norm = norm_layer(embed_dim) if use_fc_norm else None + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) if self.pos_embed is not None: trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) - # trunc_normal_(self.mask_token, std=.02) + self.fix_init_weight() if isinstance(self.head, nn.Linear): trunc_normal_(self.head.weight, std=.02) @@ -328,11 +412,9 @@ class Beit(nn.Module): return x def forward_head(self, x, pre_logits: bool = False): - if self.fc_norm is not None: - x = x[:, 1:].mean(dim=1) - x = self.fc_norm(x) - else: - x = x[:, 0] + if self.global_pool: + x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) return x if pre_logits else self.head(x) def forward(self, x): @@ -405,27 +487,6 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='timm/', num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), - - 'eva_giant_patch14_224.clip_ft_in1k': _cfg( - # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt', - hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, - ), - 'eva_giant_patch14_336.clip_ft_in1k': _cfg( - # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt', - hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), - 'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg( - # hf_hub_id='BAAI/EVA', 
hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt', - hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), - 'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg( - # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt', - hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'), }) @@ -509,30 +570,3 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs): use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs) return model - - -@register_model -def eva_giant_patch14_224(pretrained=False, **kwargs): - """ EVA-g model https://arxiv.org/abs/2211.07636 """ - model_kwargs = dict( - patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) - model = _create_beit('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def eva_giant_patch14_336(pretrained=False, **kwargs): - """ EVA-g model https://arxiv.org/abs/2211.07636 """ - model_kwargs = dict( - patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) - model = _create_beit('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def eva_giant_patch14_560(pretrained=False, **kwargs): - """ EVA-g model https://arxiv.org/abs/2211.07636 """ - model_kwargs = dict( - patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) - model = _create_beit('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs) - return model diff --git a/timm/models/eva.py b/timm/models/eva.py new file mode 100644 index 00000000..2daa04c0 --- /dev/null +++ b/timm/models/eva.py @@ -0,0 +1,730 @@ +""" EVA + +EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636 + +@article{EVA, + title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale}, + author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang, + Tiejun and Wang, Xinlong and Cao, Yue}, + journal={arXiv preprint arXiv:2211.07636}, + year={2022} +} + +EVA-02: A Visual Representation for Neon Genesis - https://arxiv.org/abs/2303.11331 +@article{EVA02, + title={EVA-02: A Visual Representation for Neon Genesis}, + author={Fang, Yuxin and Sun, Quan and Wang, Xinggang and Huang, Tiejun and Wang, Xinlong and Cao, Yue}, + journal={arXiv preprint arXiv:2303.11331}, + year={2023} +} + +This file contains EVA & EVA02 model implementations evolved from BEiT, additional models in vision_transformer.py. 
+ +Modifications by / Copyright 2023 Ross Wightman, original copyrights below +""" +# EVA models Copyright (c) 2022 BAAI-Vision +# EVA02 models Copyright (c) 2023 BAAI-Vision + +import math +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, RotaryEmbeddingCat, \ + apply_rot_embed_cat, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, to_2tuple + +from ._builder import build_model_with_cfg +from ._registry import generate_default_cfgs, register_model + +__all__ = ['Eva'] + + +class EvaAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + qkv_fused: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + attn_head_dim: Optional[int] = None, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = head_dim ** -0.5 + self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + + if qkv_fused: + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + self.q_proj = self.k_proj = self.v_proj = None + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = self.k_bias = self.v_bias = None + else: + self.q_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, all_head_dim, bias=False) + self.v_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias) + self.qkv = None + self.q_bias = self.k_bias = self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, + x, + rope: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + B, N, C = x.shape + + if self.qkv is not None: + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim + else: + q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C + k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) + v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) + + if rope is not None: + q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], 2).type_as(v) + k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], 2).type_as(v) + + if self.fast_attn: + x = F.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + if attn_mask is not None: + attn_mask = attn_mask.to(torch.bool) + attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class EvaBlock(nn.Module): + + def __init__( + 
self, + dim: int, + num_heads: int, + qkv_bias: bool = True, + qkv_fused: bool = True, + mlp_ratio: float = 4., + scale_mlp: bool = False, + swiglu_mlp: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + init_values: Optional[float] = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + attn_head_dim: Optional[int] = None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = EvaAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qkv_fused=qkv_fused, + attn_drop=attn_drop, + proj_drop=proj_drop, + attn_head_dim=attn_head_dim, + ) + self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + hidden_features = int(dim * mlp_ratio) + if swiglu_mlp: + if scale_mlp: + # when norm in SwiGLU used, an impl with separate fc for gate & x is used + self.mlp = SwiGLU( + in_features=dim, + hidden_features=hidden_features, + norm_layer=norm_layer if scale_mlp else None, + drop=proj_drop, + ) + else: + # w/o any extra norm, an impl with packed weights is used, matches existing GluMLP + self.mlp = GluMlp( + in_features=dim, + hidden_features=hidden_features * 2, + norm_layer=norm_layer if scale_mlp else None, + act_layer=nn.SiLU, + gate_last=False, + drop=proj_drop, + ) + else: + self.mlp = Mlp( + in_features=dim, + hidden_features=hidden_features, + act_layer=act_layer, + norm_layer=norm_layer if scale_mlp else None, + drop=proj_drop, + ) + self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None): + if self.gamma_1 is None: + x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask)) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask)) + x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class Eva(nn.Module): + """ Eva Vision Transformer w/ Abs & Rotary Pos Embed + + This class implements the EVA and EVA02 models that were based on the BEiT ViT variant + * EVA - abs pos embed, global avg pool + * EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale Norm in MLP (ala normformer) + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + qkv_bias: bool = True, + qkv_fused: bool = True, + mlp_ratio: float = 4., + swiglu_mlp: bool = False, + scale_mlp: bool = False, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_layer: Callable = LayerNorm, + init_values: Optional[float] = None, + use_abs_pos_emb: bool = True, + use_rot_pos_emb: bool = False, + ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None, + head_init_scale: float = 0.001, + ): + super().__init__() + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_prefix_tokens = 1 + self.grad_checkpointing = False + + self.patch_embed = PatchEmbed( + 
img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_rot_pos_emb: + ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None + self.rope = RotaryEmbeddingCat( + embed_dim // num_heads, + in_pixels=False, + feat_shape=self.patch_embed.grid_size, + ref_feat_shape=ref_feat_shape, + ) + else: + self.rope = None + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + EvaBlock( + dim=embed_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qkv_fused=qkv_fused, + mlp_ratio=mlp_ratio, + scale_mlp=scale_mlp, + swiglu_mlp=swiglu_mlp, + proj_drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + init_values=init_values, + ) + for i in range(depth)]) + + use_fc_norm = self.global_pool == 'avg' + self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + + self.fix_init_weight() + if isinstance(self.head, nn.Linear): + trunc_normal_(self.head.weight, std=.02) + self.head.weight.data.mul_(head_init_scale) + self.head.bias.data.mul_(head_init_scale) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + nwd = {'pos_embed', 'cls_token'} + return nwd + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))], + ) + return matcher + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rot_pos_embed = self.rope.get_embed() if self.rope is not None else None + + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, rope=rot_pos_embed) + else: + x = blk(x, rope=rot_pos_embed) + + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = 
False): + if self.global_pool: + x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn( + state_dict, + model, + interpolation='bicubic', + antialias=True, +): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + state_dict = state_dict.get('model_ema', state_dict) + state_dict = state_dict.get('model', state_dict) + state_dict = state_dict.get('module', state_dict) + state_dict = state_dict.get('state_dict', state_dict) + no_qkv = 'blocks.0.attn.q_proj.weight' in state_dict + mim_weights = 'mask_token' in state_dict + + for k, v in state_dict.items(): + if 'rope' in k: + # fixed embedding no need to load buffer from checkpoint + continue + + if 'patch_embed.proj.weight' in k: + _, _, H, W = model.patch_embed.proj.weight.shape + if v.shape[-1] != W or v.shape[-2] != H: + v = resample_patch_embed( + v, + (H, W), + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) + v = resample_abs_pos_embed( + v, + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + + k = k.replace('mlp.ffn_ln', 'mlp.norm') + k = k.replace('mlp.w12', 'mlp.fc1') + k = k.replace('mlp.w1', 'mlp.fc1_g') + k = k.replace('mlp.w2', 'mlp.fc1_x') + k = k.replace('mlp.w3', 'mlp.fc2') + if no_qkv: + k = k.replace('q_bias', 'q_proj.bias') + k = k.replace('v_bias', 'v_proj.bias') + + if mim_weights and k in ('mask_token', 'lm_head.weight', 'lm_head.bias', 'norm.weight', 'norm.bias'): + if k == 'norm.weight' or k == 'norm.bias': + # try moving norm -> fc norm on fine-tune, probably a better starting point than new init + k = k.replace('norm', 'fc_norm') + else: + # skip pretrain mask token & head weights + continue + + out_dict[k] = v + + return out_dict + + +def _create_eva(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Eva models.') + + model = build_model_with_cfg( + Eva, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + + 'eva_giant_patch14_224.clip_ft_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt', + hf_hub_id='timm/', + ), + 'eva_giant_patch14_336.clip_ft_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt', + hf_hub_id='timm/', + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + 
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'), + + 'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', + hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_B_pt_in21k_medft_in21k_ft_in1k_p14.pt', + input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', + ), + 'eva02_large_patch14_448.mim_in22k_ft_in22k_in1k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', + hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_in21k_medft_in21k_ft_in1k_p14.pt', + input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', + ), + 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', + hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_m38m_medft_in21k_ft_in1k_p14.pt', + input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', + ), + + 'eva02_tiny_patch14_336.mim_in22k_ft_in1k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', + hf_hub_filename='eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt', + input_size=(3, 336, 336), crop_pct=1.0, + ), + 'eva02_small_patch14_336.mim_in22k_ft_in1k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', + hf_hub_filename='eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt', + input_size=(3, 336, 336), crop_pct=1.0, + ), + 'eva02_base_patch14_448.mim_in22k_ft_in1k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', + hf_hub_filename='eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt', + input_size=(3, 448, 448), crop_pct=1.0, + ), + 'eva02_large_patch14_448.mim_in22k_ft_in1k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', + hf_hub_filename='eva02/cls/in1k/eva02_L_pt_in21k_ft_in1k_p14.pt', + input_size=(3, 448, 448), crop_pct=1.0, + ), + 'eva02_large_patch14_448.mim_m38m_ft_in1k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', + hf_hub_filename='eva02/cls/in1k/eva02_L_pt_m38m_ft_in1k_p14.pt', + input_size=(3, 448, 448), crop_pct=1.0, + ), + + 'eva02_base_patch14_448.mim_in22k_ft_in22k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt', + input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841, + ), + 'eva02_large_patch14_448.mim_in22k_ft_in22k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt', + input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841, + ), + 'eva02_large_patch14_448.mim_m38m_ft_in22k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt', + input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841, + ), + + 'eva02_tiny_patch14_224.mim_in22k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_Ti_pt_in21k_p14.pt', + num_classes=0, + ), + 'eva02_small_patch14_224.mim_in22k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_S_pt_in21k_p14.pt', + num_classes=0, + ), + 'eva02_base_patch14_224.mim_in22k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_B_pt_in21k_p14.pt', + num_classes=0, + ), + 'eva02_large_patch14_224.mim_in22k': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_in21k_p14.pt', + num_classes=0, + ), + 'eva02_large_patch14_224.mim_m38m': _cfg( + hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_m38m_p14.pt', + num_classes=0, + ), + +}) + + 
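As a reference for the renames performed by checkpoint_filter_fn above, a small sketch of how original EVA02 parameter names map onto the timm module names (keys here are hypothetical examples; the substitution order matters for the mlp.w12 vs mlp.w1 prefixes):

def _rename(k: str, no_qkv: bool = True) -> str:
    # mirrors the replace chain in checkpoint_filter_fn
    k = k.replace('mlp.ffn_ln', 'mlp.norm')   # scale norm inside the MLP
    k = k.replace('mlp.w12', 'mlp.fc1')       # packed GluMlp weights
    k = k.replace('mlp.w1', 'mlp.fc1_g')      # SwiGLU gate projection
    k = k.replace('mlp.w2', 'mlp.fc1_x')      # SwiGLU value projection
    k = k.replace('mlp.w3', 'mlp.fc2')        # SwiGLU output projection
    if no_qkv:                                # only applies when q/k/v projections are unfused
        k = k.replace('q_bias', 'q_proj.bias')
        k = k.replace('v_bias', 'v_proj.bias')
    return k

assert _rename('blocks.0.mlp.w1.weight') == 'blocks.0.mlp.fc1_g.weight'
assert _rename('blocks.0.mlp.ffn_ln.bias') == 'blocks.0.mlp.norm.bias'
assert _rename('blocks.0.attn.q_bias') == 'blocks.0.attn.q_proj.bias'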
+@register_model +def eva_giant_patch14_224(pretrained=False, **kwargs): + """ EVA-g model https://arxiv.org/abs/2211.07636 """ + model_kwargs = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) + model = _create_eva('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def eva_giant_patch14_336(pretrained=False, **kwargs): + """ EVA-g model https://arxiv.org/abs/2211.07636 """ + model_kwargs = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) + model = _create_eva('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def eva_giant_patch14_560(pretrained=False, **kwargs): + """ EVA-g model https://arxiv.org/abs/2211.07636 """ + model_kwargs = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs) + model = _create_eva('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def eva02_tiny_patch14_224(pretrained=False, **kwargs): + model_kwargs = dict( + img_size=224, + patch_size=14, + embed_dim=192, + depth=12, + num_heads=3, + mlp_ratio=4 * 2 / 3, + swiglu_mlp=True, + use_rot_pos_emb=True, + ref_feat_shape=(16, 16), # 224/14 + ) + model = _create_eva('eva02_tiny_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model + + +@register_model +def eva02_small_patch14_224(pretrained=False, **kwargs): + model_kwargs = dict( + img_size=224, + patch_size=14, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4 * 2 / 3, + swiglu_mlp=True, + use_rot_pos_emb=True, + ref_feat_shape=(16, 16), # 224/14 + ) + model = _create_eva('eva02_small_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model + + +@register_model +def eva02_base_patch14_224(pretrained=False, **kwargs): + model_kwargs = dict( + img_size=224, + patch_size=14, + embed_dim=768, + depth=12, + num_heads=12, + qkv_fused=False, + mlp_ratio=4 * 2 / 3, + scale_mlp=True, + swiglu_mlp=True, + use_rot_pos_emb=True, + ref_feat_shape=(16, 16), # 224/14 + ) + model = _create_eva('eva02_base_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model + + +@register_model +def eva02_large_patch14_224(pretrained=False, **kwargs): + model_kwargs = dict( + img_size=224, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4 * 2 / 3, + qkv_fused=False, + scale_mlp=True, + swiglu_mlp=True, + use_rot_pos_emb=True, + ref_feat_shape=(16, 16), # 224/14 + ) + model = _create_eva('eva02_large_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model + + +@register_model +def eva02_tiny_patch14_336(pretrained=False, **kwargs): + model_kwargs = dict( + img_size=336, + patch_size=14, + embed_dim=192, + depth=12, + num_heads=3, + mlp_ratio=4 * 2 / 3, + swiglu_mlp=True, + use_rot_pos_emb=True, + ref_feat_shape=(16, 16), # 224/14 + ) + model = _create_eva('eva02_tiny_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model + + +@register_model +def eva02_small_patch14_336(pretrained=False, **kwargs): + model_kwargs = dict( + img_size=336, + patch_size=14, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4 * 2 / 3, + swiglu_mlp=True, + use_rot_pos_emb=True, + ref_feat_shape=(16, 16), # 224/14 + ) + model = _create_eva('eva02_small_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model + + +@register_model +def 
eva02_base_patch14_448(pretrained=False, **kwargs): + model_kwargs = dict( + img_size=448, + patch_size=14, + embed_dim=768, + depth=12, + num_heads=12, + qkv_fused=False, + mlp_ratio=4 * 2 / 3, + scale_mlp=True, + swiglu_mlp=True, + use_rot_pos_emb=True, + ref_feat_shape=(16, 16), # 224/14 + ) + model = _create_eva('eva02_base_patch14_448', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model + + +@register_model +def eva02_large_patch14_448(pretrained=False, **kwargs): + model_kwargs = dict( + img_size=448, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4 * 2 / 3, + qkv_fused=False, + scale_mlp=True, + swiglu_mlp=True, + use_rot_pos_emb=True, + ref_feat_shape=(16, 16), # 224/14 + ) + model = _create_eva('eva02_large_patch14_448', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model
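Finally, a hedged end-to-end sketch of exercising one of the registrations above with random weights (pretrained=True would instead pull the checkpoints listed in default_cfgs, which requires network access):

import torch
import timm

# build an EVA02 model from the registry with the classification head removed
model = timm.create_model('eva02_base_patch14_448', pretrained=False, num_classes=0)
model.eval()

x = torch.randn(1, 3, 448, 448)
with torch.no_grad():
    feats = model(x)
print(feats.shape)  # torch.Size([1, 768]) -- avg-pooled, fc_norm'd pre-logit features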