Merge branch 'main' into fast_model

Ryan 2025-05-17 22:25:34 +08:00
commit bea1137583
10 changed files with 976 additions and 162 deletions

View File

@@ -49,7 +49,7 @@ parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN'
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
-parser.add_argument('--num-classes', type=int, default=1000,
parser.add_argument('--num-classes', type=int, default=None,
help='Number classes in dataset')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to checkpoint (default: none)')

View File

@@ -53,7 +53,7 @@ FEAT_INTER_FILTERS = [
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*',
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet',

View File

@@ -1,6 +1,7 @@
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .attention import Attention, AttentionRope
from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
from .attention_pool import AttentionPoolLatent
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
@@ -41,6 +42,7 @@ from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct,
from .padding import get_padding, get_same_padding, pad_same
from .patch_dropout import PatchDropout
from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
from .pool1d import global_pool_nlc
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \

timm/layers/attention.py (new file, 212 lines added)
View File

@@ -0,0 +1,212 @@
from typing import Final, Optional, Type
import torch
from torch import nn as nn
from torch.nn import functional as F
from .config import use_fused_attn
from .pos_embed_sincos import apply_rot_embed_cat
class Attention(nn.Module):
"""Standard Multi-head Self Attention module with QKV projection.
This module implements the standard multi-head attention mechanism used in transformers.
It supports both the fused attention implementation (scaled_dot_product_attention) for
efficiency when available, and a manual implementation otherwise. The module includes
options for QK normalization, attention dropout, and projection dropout.
"""
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: Type[nn.Module] = nn.LayerNorm,
) -> None:
"""Initialize the Attention module.
Args:
dim: Input dimension of the token embeddings
num_heads: Number of attention heads
qkv_bias: Whether to use bias in the query, key, value projections
qk_norm: Whether to apply normalization to query and key vectors
proj_bias: Whether to use bias in the output projection
attn_drop: Dropout rate applied to the attention weights
proj_drop: Dropout rate applied after the output projection
norm_layer: Normalization layer constructor for QK normalization if enabled
"""
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if attn_mask is not None:
attn = attn + attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentionRope(nn.Module):
""" A Self Attention module with ROPE support.
Includes options for:
* QK normalization option
* Attention output (scale) normalization
* Fused or unfused QKV projection support
"""
fused_attn: torch.jit.Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
qkv_fused: bool = True,
num_prefix_tokens: int = 1,
attn_drop: float = 0.,
proj_drop: float = 0.,
attn_head_dim: Optional[int] = None,
norm_layer: Type[nn.Module] = None,
qk_norm: bool = False,
scale_norm: bool = False,
):
"""Initialize the Attention module.
Args:
dim: Input dimension of the token embeddings
num_heads: Number of attention heads
qkv_bias: Whether to add a bias term to the query, key, and value projections
num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
should not have position embeddings applied
attn_drop: Dropout rate for attention weights
proj_drop: Dropout rate for the output projection
attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
norm_layer: Normalization layer constructor to use for QK and scale normalization
qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
scale_norm: Enable normalization (scaling) of attention output with norm_layer
"""
super().__init__()
if scale_norm or qk_norm:
assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
attn_dim = head_dim * self.num_heads
self.scale = head_dim ** -0.5
self.num_prefix_tokens = num_prefix_tokens
self.fused_attn = use_fused_attn()
if qkv_fused:
self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
self.q_proj = self.k_proj = self.v_proj = None
else:
self.qkv = None
self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
self.proj = nn.Linear(attn_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(
self,
x,
rope: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
"""Forward pass for the attention module.
Args:
x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
rope: Rotary position embeddings tensor for position-aware attention
attn_mask: Optional attention mask to apply during attention computation
Returns:
Tensor of shape (batch_size, sequence_length, embedding_dim)
"""
B, N, C = x.shape
if self.qkv is not None:
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
else:
q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C
k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
q, k = self.q_norm(q), self.k_norm(k)
if rope is not None:
npt = self.num_prefix_tokens
q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if attn_mask is not None:
attn_mask = attn_mask.to(torch.bool)
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.norm(x)
x = self.proj(x)
x = self.proj_drop(x)
return x
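As a quick usage sketch (not part of the commit; the shapes and the pairing with RotaryEmbeddingCat are illustrative assumptions), the two new layers can be exercised standalone:

import torch
from timm.layers import Attention, AttentionRope, RotaryEmbeddingCat

x = torch.randn(2, 197, 768)                     # (batch, 1 cls + 14*14 patch tokens, dim)

attn = Attention(dim=768, num_heads=12, qk_norm=True)
y = attn(x)                                      # (2, 197, 768)

# rope table for the 14x14 patch grid, per-head dim 768 // 12 = 64
rope = RotaryEmbeddingCat(64, in_pixels=False, feat_shape=(14, 14), ref_feat_shape=(14, 14))
rope_embed = rope.get_embed()                    # concatenated sin/cos positions
attn_rope = AttentionRope(dim=768, num_heads=12, num_prefix_tokens=1)
y = attn_rope(x, rope=rope_embed)                # (2, 197, 768), cls token left un-rotated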

View File

@@ -1,4 +1,4 @@
-from typing import Optional
from typing import Optional, Type
import torch
import torch.nn as nn
@@ -28,8 +28,8 @@ class AttentionPoolLatent(nn.Module):
latent_dim: int = None,
pos_embed: str = '',
pool_type: str = 'token',
-norm_layer: Optional[nn.Module] = None,
-act_layer: Optional[nn.Module] = nn.GELU,
norm_layer: Optional[Type[nn.Module]] = None,
act_layer: Optional[Type[nn.Module]] = nn.GELU,
drop: float = 0.0,
):
super().__init__()

timm/layers/pool1d.py (new file, 26 lines added)
View File

@@ -0,0 +1,26 @@
import torch
def global_pool_nlc(
x: torch.Tensor,
pool_type: str = 'token',
num_prefix_tokens: int = 1,
reduce_include_prefix: bool = False,
):
if not pool_type:
return x
if pool_type == 'token':
x = x[:, 0] # class token
else:
x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
if pool_type == 'avg':
x = x.mean(dim=1)
elif pool_type == 'avgmax':
x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
elif pool_type == 'max':
x = x.amax(dim=1)
else:
assert not pool_type, f'Unknown pool type {pool_type}'
return x
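A small illustrative example (assumed shapes, not from the commit) of how this helper pools an NLC token sequence:

import torch
from timm.layers import global_pool_nlc  # export added in this commit

x = torch.randn(2, 1 + 196, 768)                                  # cls token + 14*14 patch tokens
cls = global_pool_nlc(x, pool_type='token')                       # (2, 768), the class token
avg = global_pool_nlc(x, pool_type='avg', num_prefix_tokens=1)    # (2, 768), mean over patch tokens only
raw = global_pool_nlc(x, pool_type='')                            # unpooled, (2, 197, 768)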

View File

@@ -87,6 +87,8 @@ def build_fourier_pos_embed(
include_grid: bool = False,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
grid_offset: float = 0.,
grid_indexing: str = 'ij',
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
@@ -102,6 +104,8 @@ def build_fourier_pos_embed(
include_grid: Include the spatial grid in output.
in_pixels: Output in pixel freq.
ref_feat_shape: Reference feature shape for resize / fine-tune.
grid_offset: Constant offset to add to grid for non-pixel freq.
grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
dtype: Output dtype.
device: Output device.
@@ -130,15 +134,21 @@ def build_fourier_pos_embed(
dtype = bands.dtype
if in_pixels:
-t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape]
t = [
    torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32)
    for s in feat_shape
]
else:
-t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]
t = [
    torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) + grid_offset
    for s in feat_shape
]
if ref_feat_shape is not None:
# eva's scheme for resizing rope embeddings (ref shape = pretrain)
t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
-grid = torch.stack(ndgrid(t), dim=-1)
grid = torch.stack(torch.meshgrid(t, indexing=grid_indexing), dim=-1)
grid = grid.unsqueeze(-1)
pos = grid * bands
@@ -229,6 +239,8 @@ def build_rotary_pos_embed(
linear_bands: bool = False,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
grid_offset: float = 0.,
grid_indexing: str = 'ij',
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
@@ -242,6 +254,9 @@ def build_rotary_pos_embed(
temperature: Temperature (inv freq) for non-pixel mode
linear_bands: Linearly (instead of log) spaced bands for pixel mode
in_pixels: Pixel vs language (inv freq) mode.
ref_feat_shape: Reference feature shape for resize / fine-tune.
grid_offset: Constant offset to add to grid for non-pixel freq.
grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
dtype: Output dtype.
device: Output device.
@@ -257,6 +272,8 @@ def build_rotary_pos_embed(
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=ref_feat_shape,
grid_offset=grid_offset,
grid_indexing=grid_indexing,
device=device,
dtype=dtype,
)
@@ -289,6 +306,8 @@ class RotaryEmbedding(nn.Module):
linear_bands: bool = False,
feat_shape: Optional[List[int]] = None,
ref_feat_shape: Optional[List[int]] = None,
grid_offset: float = 0.,
grid_indexing: str = 'ij',
):
super().__init__()
self.dim = dim
@@ -297,6 +316,8 @@ class RotaryEmbedding(nn.Module):
self.in_pixels = in_pixels
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
self.grid_offset = grid_offset
self.grid_indexing = grid_indexing
if feat_shape is None:
# only cache bands
@@ -328,6 +349,8 @@ class RotaryEmbedding(nn.Module):
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
grid_indexing=self.grid_indexing,
)
self.bands = None
self.register_buffer(
@@ -349,6 +372,9 @@ class RotaryEmbedding(nn.Module):
shape,
self.bands,
in_pixels=self.in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
grid_indexing=self.grid_indexing,
)
else:
return self.pos_embed_sin, self.pos_embed_cos
@@ -376,6 +402,8 @@ class RotaryEmbeddingCat(nn.Module):
linear_bands: bool = False,
feat_shape: Optional[List[int]] = None,
ref_feat_shape: Optional[List[int]] = None,
grid_offset: float = 0.,
grid_indexing: str = 'ij',
):
super().__init__()
self.dim = dim
@@ -384,6 +412,8 @@ class RotaryEmbeddingCat(nn.Module):
self.in_pixels = in_pixels
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
self.grid_offset = grid_offset
self.grid_indexing = grid_indexing
if feat_shape is None:
# only cache bands
@@ -414,6 +444,8 @@ class RotaryEmbeddingCat(nn.Module):
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
grid_indexing=self.grid_indexing,
)
self.bands = None
self.register_buffer(
@@ -430,6 +462,8 @@ class RotaryEmbeddingCat(nn.Module):
self.bands,
in_pixels=self.in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
grid_indexing=self.grid_indexing,
)
return torch.cat(embeds, -1)
elif self.pos_embed is not None:
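A minimal sketch (parameter values assumed for illustration) of the new grid_indexing / grid_offset options on the concatenated rotary embedding, roughly as the PE models further below configure it:

import torch
from timm.layers import RotaryEmbeddingCat, apply_rot_embed_cat

rope = RotaryEmbeddingCat(
    64,                      # per-head dim
    in_pixels=False,
    feat_shape=(14, 14),
    ref_feat_shape=(14, 14),
    grid_indexing='xy',      # new meshgrid indexing option
    grid_offset=1.,          # new constant offset added to the non-pixel grid
)
emb = rope.get_embed()                    # cached sin/cos table for the 14x14 grid
q = torch.randn(2, 8, 196, 64)            # (B, heads, tokens, head_dim)
q_rot = apply_rot_embed_cat(q, emb)       # same shape, with rotary embedding applied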

View File

@@ -25,6 +25,7 @@ Modifications by / Copyright 2023 Ross Wightman, original copyrights below
# EVA models Copyright (c) 2022 BAAI-Vision
# EVA02 models Copyright (c) 2023 BAAI-Vision
import math
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
@@ -34,7 +35,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, \
-to_2tuple, use_fused_attn
global_pool_nlc, to_2tuple, use_fused_attn, AttentionRope, AttentionPoolLatent
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
@@ -45,6 +46,8 @@ __all__ = ['Eva']
class EvaAttention(nn.Module):
""" EVA Attention with ROPE, no k-bias, and fused/unfused qkv options
"""
fused_attn: torch.jit.Final[bool]
def __init__(
@@ -53,55 +56,64 @@ class EvaAttention(nn.Module):
num_heads: int = 8,
qkv_bias: bool = True,
qkv_fused: bool = True,
-num_prefix_tokens: int = 1,
qkv_bias_separate: bool = False,
num_prefix_tokens: int = 1,
attn_drop: float = 0.,
proj_drop: float = 0.,
attn_head_dim: Optional[int] = None,
norm_layer: Optional[Callable] = None,
qk_norm: bool = False,
scale_norm: bool = True,
):
"""
Args:
dim: Input dimension of the token embeddings
num_heads: Number of attention heads
qkv_bias: Whether to add a bias term to the query, key, and value projections
qkv_fused: Whether qkv projections are fused into one projection or separate
qkv_bias_separate: Whether to apply bias to qkv as a separate addition or part of F.linear() call
num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
    should not have position embeddings applied
attn_drop: Dropout rate for attention weights
proj_drop: Dropout rate for the output projection
attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
norm_layer: Normalization layer constructor to use for QK and scale normalization
qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
scale_norm: Enable normalization (scaling) of attention output with norm_layer
"""
super().__init__()
if scale_norm or qk_norm:
assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
-all_head_dim = head_dim * self.num_heads
attn_dim = head_dim * self.num_heads
self.scale = head_dim ** -0.5
self.num_prefix_tokens = num_prefix_tokens
self.fused_attn = use_fused_attn()
self.qkv_bias_separate = qkv_bias_separate
if qkv_fused:
-self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
self.qkv = nn.Linear(dim, attn_dim * 3, bias=False)
self.q_proj = self.k_proj = self.v_proj = None
if qkv_bias:
-self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
-self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
-self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
self.q_bias = nn.Parameter(torch.zeros(attn_dim))
self.register_buffer('k_bias', torch.zeros(attn_dim), persistent=False)
self.v_bias = nn.Parameter(torch.zeros(attn_dim))
else:
self.q_bias = self.k_bias = self.v_bias = None
else:
-self.q_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
-self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
-self.v_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, attn_dim, bias=False)
self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
self.qkv = None
self.q_bias = self.k_bias = self.v_bias = None
self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
-self.norm = norm_layer(all_head_dim) if norm_layer is not None else nn.Identity()
-self.proj = nn.Linear(all_head_dim, dim)
self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
self.proj = nn.Linear(attn_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(
@@ -110,6 +122,16 @@ class EvaAttention(nn.Module):
rope: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
"""Forward pass for the attention module.
Args:
x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
rope: Rotary position embeddings tensor for position-aware attention
attn_mask: Optional attention mask to apply during attention computation
Returns:
Tensor of shape (batch_size, sequence_length, embedding_dim)
"""
B, N, C = x.shape
if self.qkv is not None:
@@ -129,6 +151,8 @@ class EvaAttention(nn.Module):
k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
q, k = self.q_norm(q), self.k_norm(k)
if rope is not None:
npt = self.num_prefix_tokens
q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
@@ -172,6 +196,7 @@ class EvaBlock(nn.Module):
scale_mlp: bool = False,
scale_attn_inner: bool = False,
num_prefix_tokens: int = 1,
attn_type: str = 'eva',
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
@@ -180,28 +205,31 @@ class EvaBlock(nn.Module):
norm_layer: Callable = LayerNorm,
attn_head_dim: Optional[int] = None,
):
""" Initialize the EVA transformer block.
Args:
dim: Input dimension of the token embeddings
num_heads: Number of attention heads
qkv_bias: Whether to use bias terms in query, key, value projections
qkv_fused: Whether to use a single projection for query, key, value
mlp_ratio: Ratio of MLP hidden dimension to input dimension
swiglu_mlp: Whether to use SwiGLU activation in the MLP
scale_mlp: Whether to use normalization in the MLP
scale_attn_inner: Whether to use normalization within the attention mechanism
num_prefix_tokens: Number of tokens at the beginning of the sequence (class tokens, etc.)
attn_type: Type of attention module to use ('eva' or 'rope')
proj_drop: Dropout rate for projection layers
attn_drop: Dropout rate for attention matrix
drop_path: Stochastic depth rate
init_values: Initial value for LayerScale, None = no LayerScale
act_layer: Activation layer constructor
norm_layer: Normalization layer constructor
attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
"""
super().__init__()
self.norm1 = norm_layer(dim)
-self.attn = EvaAttention(
attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
self.attn = attn_cls(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
@@ -210,7 +238,8 @@ class EvaBlock(nn.Module):
attn_drop=attn_drop,
proj_drop=proj_drop,
attn_head_dim=attn_head_dim,
-norm_layer=norm_layer if scale_attn_inner else None,
norm_layer=norm_layer,
scale_norm=scale_attn_inner,
)
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -266,6 +295,7 @@ class EvaBlockPostNorm(nn.Module):
qkv_bias: bool = True,
qkv_fused: bool = True,
mlp_ratio: float = 4.,
attn_type: str = 'eva',
swiglu_mlp: bool = False,
scale_mlp: bool = False,
scale_attn_inner: bool = False,
@@ -278,27 +308,30 @@ class EvaBlockPostNorm(nn.Module):
norm_layer: Callable = nn.LayerNorm,
attn_head_dim: Optional[int] = None,
):
""" Initialize the post-norm EVA transformer block.
Args:
dim: Input dimension of the token embeddings
num_heads: Number of attention heads
qkv_bias: Whether to use bias terms in query, key, value projections
qkv_fused: Whether to use a single projection for query, key, value
mlp_ratio: Ratio of MLP hidden dimension to input dimension
swiglu_mlp: Whether to use SwiGLU activation in the MLP
scale_mlp: Whether to use normalization in the MLP
scale_attn_inner: Whether to use normalization within the attention mechanism
num_prefix_tokens: Number of tokens at the beginning of the sequence (class tokens, etc.)
attn_type: Type of attention module to use ('eva' or 'rope')
proj_drop: Dropout rate for projection layers
attn_drop: Dropout rate for attention matrix
drop_path: Stochastic depth rate
init_values: Initial value for LayerScale, None = no LayerScale (NOTE: ignored for post-norm block)
act_layer: Activation layer constructor
norm_layer: Normalization layer constructor
attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
"""
super().__init__()
-self.attn = EvaAttention(
attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
self.attn = attn_cls(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
@@ -307,7 +340,8 @@ class EvaBlockPostNorm(nn.Module):
attn_drop=attn_drop,
proj_drop=proj_drop,
attn_head_dim=attn_head_dim,
-norm_layer=norm_layer if scale_attn_inner else None,
norm_layer=norm_layer,
scale_norm=scale_attn_inner,
)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -373,6 +407,7 @@ class Eva(nn.Module):
swiglu_mlp: bool = False,
scale_mlp: bool = False,
scale_attn_inner: bool = False,
attn_type: str = 'eva',
drop_rate: float = 0.,
pos_drop_rate: float = 0.,
patch_drop_rate: float = 0.,
@@ -385,44 +420,64 @@ class Eva(nn.Module):
num_reg_tokens: int = 0,
use_abs_pos_emb: bool = True,
use_rot_pos_emb: bool = False,
rope_grid_offset: float = 0.,
rope_grid_indexing: str = 'ij',
use_post_norm: bool = False,
use_pre_transformer_norm: bool = False,
use_post_transformer_norm: Optional[bool] = None,
use_fc_norm: Optional[bool] = None,
attn_pool_num_heads: Optional[int] = None,
attn_pool_mlp_ratio: Optional[float] = None,
dynamic_img_size: bool = False,
dynamic_img_pad: bool = False,
ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
head_init_scale: float = 0.001,
):
"""Initialize the EVA Vision Transformer model.
Args:
img_size: Input image size (single int for square, or tuple for rectangular)
patch_size: Patch size to divide image into tokens (single int for square, or tuple)
in_chans: Number of input image channels
num_classes: Number of classes (output dim) for classification head (final projection), 0 for pass-through
global_pool: Type of global pooling for final sequence ('avg', 'token', 'map', etc.)
embed_dim: Embedding dimension for tokens
depth: Number of transformer blocks
num_heads: Number of attention heads
qkv_bias: Enable bias for query, key, value projections
qkv_fused: Use a single projection for query, key, value
mlp_ratio: Ratio of mlp hidden dim to embedding dim
swiglu_mlp: Use SwiGLU activation in MLP
scale_mlp: Apply scaling normalization in MLP (normformer style)
scale_attn_inner: Apply scaling normalization inside attention
attn_type: Type of attention module to use
drop_rate: Dropout rate after final projection and pooling
pos_drop_rate: Dropout rate for positional embeddings
patch_drop_rate: Rate of dropping patches during training
proj_drop_rate: Dropout rate for projections
attn_drop_rate: Dropout rate for attention
drop_path_rate: Stochastic depth rate
norm_layer: Normalization layer constructor
init_values: Initial layer-scale values
class_token: Use class token
num_reg_tokens: Number of additional learnable 'register' tokens to add to the sequence
use_abs_pos_emb: Use absolute (learned) positional embeddings
use_rot_pos_emb: Use rotary position embeddings
rope_grid_offset: Offset for rotary position embedding grid
rope_grid_indexing: Indexing mode for rotary position embeddings ('ij' or 'xy')
use_post_norm: Use post-norm transformer block type
use_pre_transformer_norm: Use normalization layer before transformer blocks
use_post_transformer_norm: Use normalization layer after transformer blocks
use_fc_norm: Use normalization layer after pooling, before final classifier
attn_pool_num_heads: Number of heads in attention pooling
attn_pool_mlp_ratio: MLP ratio in attention pooling
dynamic_img_size: Support dynamic image sizes in forward pass
dynamic_img_pad: Apply dynamic padding for irregular image sizes
ref_feat_shape: Reference feature shape for rotary position embedding scale
head_init_scale: Initialization scale for classification head weights
"""
super().__init__()
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
@@ -430,6 +485,17 @@ class Eva(nn.Module):
self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
# resolve norm / pool usage
activate_pre_norm = use_pre_transformer_norm
if use_fc_norm is not None:
activate_fc_norm = use_fc_norm # pass through if explicit
else:
activate_fc_norm = global_pool == 'avg' # default on if avg pool used
if use_post_transformer_norm is not None:
activate_post_norm = use_post_transformer_norm # pass through if explicit
else:
activate_post_norm = not activate_fc_norm # default on if fc_norm isn't active
embed_args = {}
if dynamic_img_size:
# flatten deferred until after pos embed
@@ -440,6 +506,7 @@ class Eva(nn.Module):
in_chans=in_chans,
embed_dim=embed_dim,
dynamic_img_pad=dynamic_img_pad,
bias=not use_pre_transformer_norm,
**embed_args,
)
num_patches = self.patch_embed.num_patches
@@ -468,10 +535,14 @@ class Eva(nn.Module):
in_pixels=False,
feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
ref_feat_shape=ref_feat_shape,
grid_offset=rope_grid_offset,
grid_indexing=rope_grid_indexing,
)
else:
self.rope = None
self.norm_pre = norm_layer(embed_dim) if activate_pre_norm else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock
self.blocks = nn.ModuleList([
@@ -484,6 +555,7 @@ class Eva(nn.Module):
swiglu_mlp=swiglu_mlp,
scale_mlp=scale_mlp,
scale_attn_inner=scale_attn_inner,
attn_type=attn_type,
num_prefix_tokens=self.num_prefix_tokens,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
@@ -495,9 +567,21 @@ class Eva(nn.Module):
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
-use_fc_norm = self.global_pool == 'avg'
-self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
-self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.norm = norm_layer(embed_dim) if activate_post_norm else nn.Identity()
if global_pool == 'map':
attn_pool_num_heads = attn_pool_num_heads or num_heads
attn_pool_mlp_ratio = attn_pool_mlp_ratio or mlp_ratio
self.attn_pool = AttentionPoolLatent(
self.embed_dim,
num_heads=attn_pool_num_heads,
mlp_ratio=attn_pool_mlp_ratio,
norm_layer=norm_layer,
act_layer=nn.GELU,
)
else:
self.attn_pool = None
self.fc_norm = norm_layer(embed_dim) if activate_fc_norm else nn.Identity()
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -626,6 +710,7 @@ class Eva(nn.Module):
B, _, height, width = x.shape
x = self.patch_embed(x)
x, rot_pos_embed = self._pos_embed(x)
x = self.norm_pre(x)
if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
blocks = self.blocks
else:
@@ -668,13 +753,23 @@ class Eva(nn.Module):
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.attn_pool = None
self.fc_norm = nn.Identity()
self.reset_classifier(0, '')
return take_indices
def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
if self.attn_pool is not None:
x = self.attn_pool(x)
return x
pool_type = self.global_pool if pool_type is None else pool_type
x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
return x
def forward_features(self, x):
x = self.patch_embed(x)
x, rot_pos_embed = self._pos_embed(x)
x = self.norm_pre(x)
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, rope=rot_pos_embed)
@@ -684,8 +779,7 @@ class Eva(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
-if self.global_pool:
-x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.pool(x)
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
@@ -696,6 +790,67 @@ class Eva(nn.Module):
return x
def _convert_pe(
state_dict,
model,
prefix: str = 'visual.',
):
""" Convert Perception Encoder weights """
state_dict = state_dict.get('model', state_dict)
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
out_dict = {}
swaps = [
('conv1', 'patch_embed.proj'),
('positional_embedding', 'pos_embed'),
('transformer.resblocks.', 'blocks.'),
('ln_pre', 'norm_pre'),
('ln_post', 'norm'),
('ln_', 'norm'),
('ls_1.gamma', 'gamma_1'),
('ls_2.gamma', 'gamma_2'),
('in_proj_', 'qkv.'),
('out_proj', 'proj'),
('mlp.c_fc', 'mlp.fc1'),
('mlp.c_proj', 'mlp.fc2'),
]
len_prefix = len(prefix)
for k, v in state_dict.items():
if prefix:
if not k.startswith(prefix):
continue
k = k[len_prefix:]
for sp in swaps:
k = k.replace(sp[0], sp[1])
if k.startswith('attn_pool'):
k = k.replace('attn_pool.attn', 'attn_pool')
k = k.replace('attn_pool.layernorm', 'attn_pool.norm')
k = k.replace('attn_pool.probe', 'attn_pool.latent')
if k.startswith('attn_pool.qkv'):
dim = v.shape[0] // 3
if k.endswith('weight'):
out_dict['attn_pool.q.weight'] = v[:dim]
out_dict['attn_pool.kv.weight'] = v[dim:]
elif k.endswith('bias'):
out_dict['attn_pool.q.bias'] = v[:dim]
out_dict['attn_pool.kv.bias'] = v[dim:]
continue
elif k == 'proj':
k = 'head.weight'
v = v.transpose(0, 1)
out_dict['head.bias'] = torch.zeros(v.shape[0])
elif k == 'class_embedding':
k = 'cls_token'
v = v.unsqueeze(0).unsqueeze(1)
elif k == 'pos_embed':
v = v.unsqueeze(0)
out_dict[k] = v
return out_dict
def checkpoint_filter_fn(
state_dict,
model,
@@ -708,6 +863,13 @@ def checkpoint_filter_fn(
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('module', state_dict)
state_dict = state_dict.get('state_dict', state_dict)
# Loading Meta PE (Perception Encoder) weights
if 'visual.conv1.weight' in state_dict:
return _convert_pe(state_dict, model)
elif 'conv1.weight' in state_dict:
return _convert_pe(state_dict, model, prefix='')
# prefix for loading OpenCLIP compatible weights
if 'visual.trunk.pos_embed' in state_dict:
prefix = 'visual.trunk.'
@@ -721,10 +883,9 @@ def checkpoint_filter_fn(
len_prefix = len(prefix)
for k, v in state_dict.items():
if prefix:
-if k.startswith(prefix):
-k = k[len_prefix:]
-else:
if not k.startswith(prefix):
continue
k = k[len_prefix:]
if 'rope' in k:
# fixed embedding no need to load buffer from checkpoint
@@ -797,6 +958,17 @@
}
def _pe_cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 0, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.proj', 'classifier': 'head',
'license': 'custom', **kwargs
}
default_cfgs = generate_default_cfgs({
# EVA 01 CLIP fine-tuned on imagenet-1k
@@ -984,6 +1156,49 @@
input_size=(3, 256, 256), crop_pct=0.95,
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
),
# Perception Encoder weights
'vit_pe_core_base_patch16_224': _pe_cfg(
#hf_hub_id='facebook/pe_core_base_patch16_224_timm',
hf_hub_id='facebook/PE-Core-B16-224',
hf_hub_filename='PE-Core-B16-224.pt',
input_size=(3, 224, 224),
num_classes=1024, # output proj dim
),
'vit_pe_core_large_patch14_336': _pe_cfg(
hf_hub_id='facebook/PE-Core-L14-336',
hf_hub_filename='PE-Core-L14-336.pt',
input_size=(3, 336, 336),
num_classes=1024, # output proj dim
),
'vit_pe_core_gigantic_patch14_448': _pe_cfg(
#hf_hub_id='timm/',
hf_hub_id='facebook/PE-Core-G14-448',
hf_hub_filename='PE-Core-G14-448.pt',
input_size=(3, 448, 448),
num_classes=1280, # output proj dim
),
'vit_pe_lang_large_patch14_448': _pe_cfg(
#hf_hub_id='timm/',
hf_hub_id='facebook/PE-Lang-L14-448',
hf_hub_filename='PE-Lang-L14-448.pt',
input_size=(3, 448, 448),
num_classes=0,
),
'vit_pe_lang_gigantic_patch14_448': _pe_cfg(
#hf_hub_id='timm/',
hf_hub_id='facebook/PE-Lang-G14-448',
hf_hub_filename='PE-Lang-G14-448.pt',
input_size=(3, 448, 448),
num_classes=0,
),
'vit_pe_spatial_gigantic_patch14_448': _pe_cfg(
#hf_hub_id='timm/',
hf_hub_id='facebook/PE-Spatial-G14-448',
hf_hub_filename='PE-Spatial-G14-448.pt',
input_size=(3, 448, 448),
num_classes=0,
),
})
@@ -1330,3 +1545,142 @@ def vit_base_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva:
)
model = _create_eva('vit_base_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
model_args = dict(
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
global_pool='map',
attn_type='rope',
use_pre_transformer_norm=True,
use_rot_pos_emb=True,
ref_feat_shape=(14, 14),
rope_grid_offset=1.,
rope_grid_indexing='xy',
attn_pool_num_heads=8,
attn_pool_mlp_ratio=4.,
norm_layer=partial(LayerNorm, eps=1e-5),
#dynamic_img_size=True
)
return _create_eva('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
model_args = dict(
patch_size=14,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4.0,
global_pool='map',
attn_type='rope',
use_pre_transformer_norm=True,
use_rot_pos_emb=True,
ref_feat_shape=(24, 24),
rope_grid_offset=1.,
rope_grid_indexing='xy',
attn_pool_num_heads=8,
attn_pool_mlp_ratio=4.,
norm_layer=partial(LayerNorm, eps=1e-5),
#dynamic_img_size=True,
)
return _create_eva('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
model_args = dict(
patch_size=14,
embed_dim=1536,
depth=50,
num_heads=16,
mlp_ratio=8960 / 1536,
global_pool='map',
attn_type='rope',
class_token=False,
use_pre_transformer_norm=True,
use_rot_pos_emb=True,
ref_feat_shape=(32, 32),
rope_grid_indexing='xy',
attn_pool_num_heads=8,
attn_pool_mlp_ratio=4.,
norm_layer=partial(LayerNorm, eps=1e-5),
#dynamic_img_size=True,
)
return _create_eva('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
model_args = dict(
patch_size=14,
embed_dim=1024,
depth=23,
num_heads=16,
mlp_ratio=4.0,
attn_type='rope',
class_token=True,
use_rot_pos_emb=True,
ref_feat_shape=(32, 32),
rope_grid_offset=1.,
rope_grid_indexing='xy',
use_pre_transformer_norm=True,
use_post_transformer_norm=False,
use_fc_norm=False, # explicitly disable
init_values=0.1,
norm_layer=partial(LayerNorm, eps=1e-5),
#dynamic_img_size=True,
)
return _create_eva('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
model_args = dict(
patch_size=14,
embed_dim=1536,
depth=47,
num_heads=16,
mlp_ratio=8960 / 1536,
attn_type='rope',
class_token=False,
use_rot_pos_emb=True,
ref_feat_shape=(32, 32),
rope_grid_indexing='xy',
use_pre_transformer_norm=True,
use_post_transformer_norm=False,
use_fc_norm=False, # explicitly disable
init_values=0.1,
norm_layer=partial(LayerNorm, eps=1e-5),
#dynamic_img_size=True,
)
return _create_eva('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
model_args = dict(
patch_size=14,
embed_dim=1536,
depth=50,
num_heads=16,
mlp_ratio=8960 / 1536,
attn_type='rope',
class_token=False,
use_rot_pos_emb=True,
ref_feat_shape=(32, 32),
rope_grid_indexing='xy',
use_pre_transformer_norm=True,
use_post_transformer_norm=False,
init_values=0.1,
norm_layer=partial(LayerNorm, eps=1e-5),
#dynamic_img_size=True,
)
return _create_eva('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
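A usage sketch (not part of the diff): the newly registered PE variants behave like any other timm model; loading pretrained weights additionally assumes the facebook/PE-Core-B16-224 checkpoint referenced in the config above is reachable.

import timm
import torch

model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False)
model.eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    tokens = model.forward_features(x)   # (1, 197, 768) token sequence (cls + 14*14 patches)
    out = model(x)                        # (1, 1024) attention-pooled, projected output per the default cfg
print(tokens.shape, out.shape)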

View File

@@ -5,48 +5,30 @@ A PyTorch implement of TNT as described in
The official mindspore code is released and available at
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
The official pytorch code is released and available at
https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch
"""
import math
-from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint
-from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['TNT']  # model_registry will add each entrypoint fn to this
-def _cfg(url='', **kwargs):
-return {
-'url': url,
-'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-'first_conv': 'pixel_embed.proj', 'classifier': 'head',
-**kwargs
-}
-default_cfgs = {
-'tnt_s_patch16_224': _cfg(
-url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
-mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-),
-'tnt_b_patch16_224': _cfg(
-mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-),
-}
class Attention(nn.Module):
""" Multi-Head Attention
"""
def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.hidden_dim = hidden_dim
@@ -64,7 +46,7 @@ class Attention(nn.Module):
def forward(self, x):
B, N, C = x.shape
qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-q, k = qk.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
q, k = qk.unbind(0)
v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -80,6 +62,7 @@ class Attention(nn.Module):
class Block(nn.Module):
""" TNT Block
"""
def __init__(
self,
dim,
@@ -94,6 +77,7 @@ class Block(nn.Module):
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
legacy=False,
):
super().__init__()
# Inner transformer
@@ -106,7 +90,7 @@ class Block(nn.Module):
attn_drop=attn_drop,
proj_drop=proj_drop,
)
self.norm_mlp_in = norm_layer(dim)
self.mlp_in = Mlp(
in_features=dim,
@@ -115,9 +99,15 @@ class Block(nn.Module):
act_layer=act_layer,
drop=proj_drop,
)
self.legacy = legacy
-self.norm1_proj = norm_layer(dim)
-self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
if self.legacy:
    self.norm1_proj = norm_layer(dim)
    self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
    self.norm2_proj = None
else:
    self.norm1_proj = norm_layer(dim * num_pixel)
    self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False)
    self.norm2_proj = norm_layer(dim_out)
# Outer transformer
self.norm_out = norm_layer(dim_out)
@@ -130,7 +120,7 @@ class Block(nn.Module):
proj_drop=proj_drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm_mlp = norm_layer(dim_out)
self.mlp = Mlp(
in_features=dim_out,
@@ -146,9 +136,16 @@ class Block(nn.Module):
pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
# outer
B, N, C = patch_embed.size()
-patch_embed = torch.cat(
-[patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
-dim=1)
if self.norm2_proj is None:
    patch_embed = torch.cat([
        patch_embed[:, 0:1],
        patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
    ], dim=1)
else:
    patch_embed = torch.cat([
        patch_embed[:, 0:1],
        patch_embed[:, 1:] + self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
    ], dim=1)
patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
return pixel_embed, patch_embed
@ -157,7 +154,16 @@ class Block(nn.Module):
class PixelEmbed(nn.Module): class PixelEmbed(nn.Module):
""" Image to Pixel Embedding """ Image to Pixel Embedding
""" """
def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
in_dim=48,
stride=4,
legacy=False,
):
super().__init__() super().__init__()
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size) patch_size = to_2tuple(patch_size)
@@ -165,23 +171,45 @@ class PixelEmbed(nn.Module):
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        num_patches = (self.grid_size[0]) * (self.grid_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.legacy = legacy
        self.num_patches = num_patches
        self.in_dim = in_dim
        new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
        self.new_patch_size = new_patch_size

        self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
        if self.legacy:
            self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)
        else:
            self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
        if as_scalar:
            return max(self.patch_size)
        else:
            return self.patch_size

    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
        return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
    def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        _assert(
            H == self.img_size[0],
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        _assert(
            W == self.img_size[1],
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        if self.legacy:
            x = self.proj(x)
            x = self.unfold(x)
            x = x.transpose(1, 2).reshape(
                B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
        else:
            x = self.unfold(x)
            x = x.transpose(1, 2).reshape(B * self.num_patches, C, self.patch_size[0], self.patch_size[1])
            x = self.proj(x)
        x = x + pixel_pos
        x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
        return x
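A minimal shape-check sketch of the two orderings (illustrative only; the import path and the zero pixel_pos stand in for the real learned parameter): the legacy path runs the stride-4 conv and then unfolds, the new path unfolds full 16x16 patches and then convolves each one, and both land on the same (B * num_patches, num_pixel, in_dim) token layout.

import torch
from timm.models.tnt import PixelEmbed  # assumed import path for this patched module

x = torch.randn(2, 3, 224, 224)
pixel_pos = torch.zeros(1, 48, 4, 4)  # stand-in for the learned (1, in_dim, 4, 4) parameter

for legacy in (True, False):
    pe = PixelEmbed(img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4, legacy=legacy)
    out = pe(x, pixel_pos)
    print(legacy, tuple(out.shape))  # both orderings give (392, 16, 48): 2 * 196 patches, 4*4 pixels, in_dim 48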
@@ -190,6 +218,7 @@ class PixelEmbed(nn.Module):
class TNT(nn.Module):
    """ Transformer in Transformer - https://arxiv.org/abs/2103.00112
    """

    def __init__(
            self,
            img_size=224,
@@ -211,12 +240,14 @@ class TNT(nn.Module):
            drop_path_rate=0.,
            norm_layer=nn.LayerNorm,
            first_stride=4,
            legacy=False,
    ):
        super().__init__()
        assert global_pool in ('', 'token', 'avg')
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
        self.num_prefix_tokens = 1
        self.grad_checkpointing = False

        self.pixel_embed = PixelEmbed(
@@ -225,12 +256,14 @@ class TNT(nn.Module):
            in_chans=in_chans,
            in_dim=inner_dim,
            stride=first_stride,
            legacy=legacy,
        )
        num_patches = self.pixel_embed.num_patches
        r = self.pixel_embed.feat_ratio() if hasattr(self.pixel_embed, 'feat_ratio') else patch_size
        self.num_patches = num_patches
        new_patch_size = self.pixel_embed.new_patch_size
        num_pixel = new_patch_size[0] * new_patch_size[1]

        self.norm1_proj = norm_layer(num_pixel * inner_dim)
        self.proj = nn.Linear(num_pixel * inner_dim, embed_dim)
        self.norm2_proj = norm_layer(embed_dim)
@@ -255,10 +288,13 @@ class TNT(nn.Module):
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                legacy=legacy,
            ))
        self.blocks = nn.ModuleList(blocks)
        self.feature_info = [
            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
        self.norm = norm_layer(embed_dim)

        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -306,20 +342,105 @@ class TNT(nn.Module):
        self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            return_prefix_tokens: bool = False,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if an int, if a sequence, select by matching indices
            return_prefix_tokens: Return both prefix and spatial intermediate tokens
            norm: Apply norm layer to all intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:
            A list of intermediate features, or a tuple of (final features, intermediates)
            when intermediates_only is False.
        """
        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
        reshape = output_fmt == 'NCHW'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.blocks), indices)

        # forward pass
        B, _, height, width = x.shape
        pixel_embed = self.pixel_embed(x, self.pixel_pos)
        patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
        patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
        patch_embed = patch_embed + self.patch_pos
        patch_embed = self.pos_drop(patch_embed)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            blocks = self.blocks
        else:
            blocks = self.blocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
            if i in take_indices:
                # normalize intermediates with final norm layer if enabled
                intermediates.append(self.norm(patch_embed) if norm else patch_embed)

        # process intermediates
        if self.num_prefix_tokens:
            # split prefix (e.g. class, distill) and spatial feature tokens
            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
        if reshape:
            # reshape to BCHW output format
            H, W = self.pixel_embed.dynamic_feat_size((height, width))
            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
        if not torch.jit.is_scripting() and return_prefix_tokens:
            # return_prefix not supported in torchscript due to poor type handling
            intermediates = list(zip(intermediates, prefix_tokens))

        if intermediates_only:
            return intermediates

        patch_embed = self.norm(patch_embed)

        return patch_embed, intermediates
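A usage sketch for the new intermediates API, assuming a timm build that includes this change (random weights, so pretrained=False):

import torch
import timm

model = timm.create_model('tnt_s_patch16_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# take the last two block outputs as NCHW maps with the class token stripped
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
print([tuple(f.shape) for f in feats])  # expected [(1, 384, 14, 14), (1, 384, 14, 14)] for tnt_s at 224x224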
    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices
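Continuing the sketch above, pruning keeps only the blocks needed for the requested intermediates and can drop the final norm and classifier head (both become identity modules):

# truncate to the blocks feeding the last intermediate; drop norm and head
taken = model.prune_intermediate_layers(indices=1, prune_norm=True, prune_head=True)
feats = model.forward_intermediates(x, indices=1, intermediates_only=True)
print(taken, len(model.blocks), tuple(feats[0].shape))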
    def forward_features(self, x):
        B = x.shape[0]
        pixel_embed = self.pixel_embed(x, self.pixel_pos)
        patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
        patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
        patch_embed = patch_embed + self.patch_pos
        patch_embed = self.pos_drop(patch_embed)
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed)
            else:
                pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
        patch_embed = self.norm(patch_embed)
@@ -327,7 +448,7 @@ class TNT(nn.Module):
    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)
@@ -337,28 +458,92 @@ class TNT(nn.Module):
        return x
def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'pixel_embed.proj', 'classifier': 'head',
        'paper_ids': 'arXiv:2103.00112',
        'paper_name': 'Transformer in Transformer',
        'origin_url': 'https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'tnt_s_legacy_patch16_224.in1k': _cfg(
        hf_hub_id='timm/',
        #url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
    ),
    'tnt_s_patch16_224.in1k': _cfg(
        hf_hub_id='timm/',
        #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
    ),
    'tnt_b_patch16_224.in1k': _cfg(
        hf_hub_id='timm/',
        #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
    ),
})
def checkpoint_filter_fn(state_dict, model):
    state_dict.pop('outer_tokens', None)
    if 'patch_pos' in state_dict:
        out_dict = state_dict
    else:
        out_dict = {}
        for k, v in state_dict.items():
            k = k.replace('outer_pos', 'patch_pos')
            k = k.replace('inner_pos', 'pixel_pos')
            k = k.replace('patch_embed', 'pixel_embed')
            k = k.replace('proj_norm1', 'norm1_proj')
            k = k.replace('proj_norm2', 'norm2_proj')
            k = k.replace('inner_norm1', 'norm_in')
            k = k.replace('inner_attn', 'attn_in')
            k = k.replace('inner_norm2', 'norm_mlp_in')
            k = k.replace('inner_mlp', 'mlp_in')
            k = k.replace('outer_norm1', 'norm_out')
            k = k.replace('outer_attn', 'attn_out')
            k = k.replace('outer_norm2', 'norm_mlp')
            k = k.replace('outer_mlp', 'mlp')
            if k == 'pixel_pos' and not model.pixel_embed.legacy:
                B, N, C = v.shape
                H = W = int(N ** 0.5)
                assert H * W == N
                v = v.permute(0, 2, 1).reshape(B, C, H, W)
            out_dict[k] = v

    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    if out_dict['patch_pos'].shape != model.patch_pos.shape:
        out_dict['patch_pos'] = resample_abs_pos_embed(
            out_dict['patch_pos'],
            new_size=model.pixel_embed.grid_size,
            num_prefix_tokens=1,
        )
    return out_dict
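To illustrate the pixel_pos remap above: checkpoints from the original repo store the inner position table as tokens of shape (1, num_pixel, in_dim), while the non-legacy PixelEmbed adds it as a conv-style map of shape (1, in_dim, H, W). A standalone sketch with tnt_s-like sizes (tensor values are illustrative only):

import torch

v = torch.randn(1, 16, 24)      # (B, N, C): 4*4 pixel positions, inner_dim 24
B, N, C = v.shape
H = W = int(N ** 0.5)           # 4
v = v.permute(0, 2, 1).reshape(B, C, H, W)
print(tuple(v.shape))           # (1, 24, 4, 4), matching the new pixel_pos layout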
def _create_tnt(variant, pretrained=False, **kwargs):
    out_indices = kwargs.pop('out_indices', 3)
    model = build_model_with_cfg(
        TNT, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
        **kwargs)
    return model
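With the feature_cfg wiring above, TNT can now be built as a features_only backbone (previously this path raised a RuntimeError). A sketch, again assuming a build with this change:

import torch
import timm

backbone = timm.create_model('tnt_s_patch16_224', pretrained=False, features_only=True, out_indices=(9, 10, 11))
feats = backbone(torch.randn(1, 3, 224, 224))
print(backbone.feature_info.channels())   # expected [384, 384, 384]
print([tuple(f.shape) for f in feats])    # expected [(1, 384, 14, 14)] * 3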
@register_model
def tnt_s_legacy_patch16_224(pretrained=False, **kwargs) -> TNT:
    model_cfg = dict(
        patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6,
        qkv_bias=False, legacy=True)
    model = _create_tnt('tnt_s_legacy_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model
@register_model
def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
    model_cfg = dict(

View File

@@ -892,6 +892,7 @@ def main():
                optimizer,
                train_loss_fn,
                args,
                device=device,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,