pytorch-image-models/timm/models/vision_transformer_flex.py


""" Vision Transformer (New)
An improved version of the Vision Transformer with:
1. Encapsulated embedding and position encoding in a single module
2. Support for linear patch embedding on pre-patchified inputs
3. Support for NaFlex functionality (NaViT + FlexiViT)
Based on:
- Original Vision Transformer: https://arxiv.org/abs/2010.11929
- FlexiViT: https://arxiv.org/abs/2212.08013
- NaViT: https://arxiv.org/abs/2307.06304
Copyright 2025
"""
import logging
import math
from collections import OrderedDict
from functools import partial
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Final, Any, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert
from timm.models._builder import build_model_with_cfg
from timm.models._features import feature_take_indices
from timm.models._features_fx import register_notrace_function, register_notrace_module
from timm.models._registry import register_model, generate_default_cfgs
from timm.models._manipulate import checkpoint_seq, named_apply
from .vision_transformer import Block, global_pool_nlc
_logger = logging.getLogger(__name__)
def batch_patchify(
x: torch.Tensor,
patch_size: Tuple[int, int],
pad: bool = True,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
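    """Patchify a batch of images into flattened, channels-last patches.
    For example, a (2, 3, 32, 48) input with 16x16 patches yields patches of shape
    (2, 6, 768) with (nh, nw) == (2, 3).
    Args:
        x: Image tensor of shape [B, C, H, W].
        patch_size: Patch (height, width).
        pad: Pad H and W up to a multiple of the patch size before patchifying.
    Returns:
        Tuple of (patches, (nh, nw)), where patches has shape [B, nh * nw, ph * pw * C].
    """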
B, C, H, W = x.shape
ph, pw = patch_size
    # Ensure the image is divisible by the patch size, padding if requested
    if pad and (H % ph != 0 or W % pw != 0):
        pad_h = (ph - H % ph) % ph
        pad_w = (pw - W % pw) % pw
        x = F.pad(x, (0, pad_w, 0, pad_h))
        H, W = H + pad_h, W + pad_w
    nh, nw = H // ph, W // pw
patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
    # FIXME confirm we want 'channels last' in the patch channel layout, e.g. (ph, pw, C) instead of (C, ph, pw)
return patches, (nh, nw)
@register_notrace_module
class FlexEmbeds(nn.Module):
""" Na(Flex) Embedding module for Vision Transformers
This module encapsulates the complete embedding process for Vision Transformers:
1. Patch embedding (via Conv2d or Linear)
2. Class and register token preparation
3. Position embedding addition
4. Pre-normalization (if requested)
5. Dropout application
Also supports NaFlex functionality (NaViT + FlexiViT):
- Variable aspect ratio and resolution via patch coordinates
- Patch type indicators for handling padding tokens in attention
- Flexible position embedding interpolation for arbitrary grid sizes
Note: Only supports non-overlapping position and register tokens
(i.e., position embeddings do not include class/register tokens)
The patch embedding can be one of two types:
1. Conv2d-based (default): For standard image inputs [B, C, H, W]
2. Linear-based: For pre-patchified inputs [B, N, P*P*C]
"""
def __init__(
self,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
            embed_layer: Optional[str] = None,  # 'conv' or 'linear'; None defaults to 'conv'
input_norm_layer: Optional[Type[nn.Module]] = None,
proj_norm_layer: Optional[Type[nn.Module]] = None,
final_norm_layer: Optional[Type[nn.Module]] = None,
pos_embed: str = 'learned',
pos_drop_rate: float = 0.,
patch_drop_rate: float = 0.,
class_token: bool = True,
reg_tokens: int = 0,
bias: bool = True,
dynamic_img_pad: bool = False,
pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14),
pos_embed_interp_mode: str = 'bicubic',
            default_img_size: Optional[Union[int, Tuple[int, int]]] = None,
):
super().__init__()
self.has_class_token = class_token
self.num_reg_tokens = reg_tokens
self.pos_embed_interp_mode = pos_embed_interp_mode
self.patch_size = to_2tuple(patch_size)
self.in_chans = in_chans
self.embed_dim = embed_dim
self.dynamic_img_pad = dynamic_img_pad
# Calculate number of prefix tokens
self.num_prefix_tokens = 1 if class_token else 0
self.num_prefix_tokens += reg_tokens
# Create class and register tokens
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
# Calculate grid size and number of patches
self.default_img_size: Optional[Tuple[int, int]] = None
self.pos_embed_grid_size: Optional[
Tuple[int, int]] = None # Stores the grid size used for learned pos embed init
if pos_embed_grid_size is None and default_img_size is not None:
self.default_img_size = to_2tuple(default_img_size)
self.pos_embed_grid_size = tuple([s // p for s, p in zip(self.default_img_size, self.patch_size)])
elif pos_embed_grid_size is not None:
# Use provided pos_embed_grid_size for NaFlex mode
self.pos_embed_grid_size = pos_embed_grid_size
# Determine patch embedding type (linear or conv2d)
if embed_layer == 'linear':
# Create linear projection for pre-patchified inputs
# Input dimension is patch_size^2 * in_chans
patch_dim = self.patch_size[0] * self.patch_size[1] * in_chans
            self.norm_input = input_norm_layer(patch_dim) if input_norm_layer else None
self.proj = nn.Linear(patch_dim, embed_dim, bias=bias)
self.flatten = False
self.is_linear = True
else:
# Default to convolutional patch embedding for image inputs
assert input_norm_layer is None
self.norm_input = None
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
)
self.flatten = True
self.is_linear = False
# Create normalization layer after the projection
self.norm_proj = proj_norm_layer(embed_dim) if proj_norm_layer else nn.Identity()
# Create position embedding if needed - only for patches, never for prefix tokens
if not pos_embed or pos_embed == 'none':
self.pos_embed = None
self.pos_embed_type = 'none'
elif pos_embed == 'rope':
self.pos_embed = None
self.pos_embed_type = 'rope'
# Rotary embeddings will be computed on-the-fly in the forward pass
else:
# Store position embedding in (1, H, W, dim) format for easier resizing
if self.pos_embed_grid_size is not None:
h, w = self.pos_embed_grid_size
self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02)
self.pos_embed_type = 'learned'
else:
raise ValueError("Cannot initialize position embeddings without grid_size. "
"Please provide img_size or pos_embed_grid_size")
# Pre-normalization layer (separate from the main norm layer)
self.norm_final = final_norm_layer(embed_dim) if final_norm_layer is not None else nn.Identity()
# Dropout layers
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
from timm.layers.patch_dropout import PatchDropout
self.patch_drop = PatchDropout(
patch_drop_rate,
num_prefix_tokens=self.num_prefix_tokens,
)
else:
self.patch_drop = nn.Identity()
def feature_info(self, location):
"""Feature info utility method for feature extraction."""
return dict(num_chs=self.embed_dim, reduction=self.patch_size)
def feat_ratio(self, as_scalar=True):
"""Return the feature reduction ratio (stride)."""
if as_scalar:
return max(self.patch_size)
else:
return self.patch_size
def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
""" Get grid (feature) size for given image size taking account of dynamic padding.
"""
if self.dynamic_img_pad:
return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
else:
return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
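    # dynamic_feat_size example (illustrative): with patch_size=(16, 16) and img_size=(250, 250),
    # this returns (16, 16) when dynamic_img_pad is enabled (ceil) and (15, 15) otherwise (floor).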
def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
"""Forward pass for combined embedding
Args:
x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C]
patch_coord: Optional patch coordinates [B, N, 2] for NaFlex
Returns:
Embedded tensor with position encoding and class/register tokens applied
"""
# Apply patch embedding
naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
grid_size: Optional[List[int]] = None
B = x.shape[0]
if self.is_linear:
# Linear embedding path, works with NaFlex mode or standard 2D mode
if patch_coord is not None:
_assert(x.ndim == 3, 'Expecting patchified input with ndim == 3')
# Pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
# Calculate the appropriate grid size from coords
max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
else:
_assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
if self.norm_input is not None:
x = self.norm_input(x)
x = self.proj(x)
else:
assert x.ndim == 4, 'Convolutional input must be 4D'
if self.dynamic_img_pad:
H, W = x.shape[-2:]
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
x = self.proj(x)
grid_size = x.shape[-2:]
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
# Apply normalization after flattening
x = self.norm_proj(x)
if self.pos_embed_type == 'learned':
if naflex_grid_sizes is not None:
self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
else:
assert grid_size is not None
self._apply_learned_pos_embed(x, grid_size=grid_size)
elif self.pos_embed_type == 'rope':
assert False, "ROPE not yet implemented"
# Prepare and add class and register tokens
to_cat = []
if self.cls_token is not None:
to_cat.append(self.cls_token.expand(B, -1, -1))
if self.reg_token is not None:
to_cat.append(self.reg_token.expand(B, -1, -1))
# Add tokens to the beginning
if to_cat:
x = torch.cat(to_cat + [x], dim=1)
# Apply final pre-transformer normalization if specified
x = self.norm_final(x)
# Apply dropouts
x = self.patch_drop(self.pos_drop(x))
return x
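    # Shape sketch for forward() above (illustrative): with embed_layer='linear', patch_size=16 and
    # in_chans=3, a NaFlex input of [B, N, 768] plus patch_coord [B, N, 2] produces output of shape
    # [B, num_prefix_tokens + N, embed_dim]; with the conv path and a [B, 3, H, W] image, the patch
    # sequence length is (H // 16) * (W // 16) before prefix tokens are prepended.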
#@torch.compiler.disable()
def _apply_learned_naflex_pos_embed(
self,
x: torch.Tensor,
naflex_grid_sizes: List[Tuple[int, int]],
):
# Handle each batch element separately with its own grid size
orig_h, orig_w = self.pos_embed.shape[1:3]
pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float() # B,C,H,W
def _interp(_size):
if (_size[0] == orig_h) and (_size[1] == orig_w):
pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
else:
pos_embed_flat = F.interpolate(
pos_embed_nchw,
size=_size,
mode=self.pos_embed_interp_mode,
align_corners=False,
antialias=True,
).flatten(2).transpose(1, 2)
return pos_embed_flat.to(dtype=x.dtype)
# FIXME leaving alternative code commented here for now for comparisons
# pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {}
# for i, s in enumerate(naflex_grid_sizes):
# if s in pos_embed_cache:
# pos_embed_flat = pos_embed_cache[s]
# else:
# pos_embed_flat = _interp(s)
# pos_embed_cache[s] = pos_embed_flat
#
# seq_len = min(x.shape[1], pos_embed_flat.shape[1])
# x[i, :seq_len] += pos_embed_flat[0, :seq_len]
# Determine unique grid sizes
size_to_indices: Dict[Tuple[int, int], List[int]] = {}
for bi, k in enumerate(naflex_grid_sizes):
# k = h << 16 | w # FIXME can get jit compat with this
size_to_indices.setdefault(k, []).append(bi)
for k, batch_indices in size_to_indices.items():
# h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
# Interpolate only once for this (h, w)
pos_embed_flat = _interp(k)
seq_len = min(x.shape[1], pos_embed_flat.shape[1])
x[:, :seq_len].index_add_(
0,
torch.as_tensor(batch_indices, device=x.device),
pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
)
def _apply_learned_pos_embed(
self,
x: torch.Tensor,
grid_size: List[int],
):
orig_h, orig_w = self.pos_embed.shape[1:3]
        if grid_size[0] == orig_h and grid_size[1] == orig_w:
# No resize needed, just flatten
pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
else:
# Resize if needed - directly using F.interpolate
pos_embed_flat = F.interpolate(
self.pos_embed.permute(0, 3, 1, 2).float(), # B,C,H,W
size=grid_size,
mode=self.pos_embed_interp_mode,
align_corners=False,
antialias=True,
).flatten(2).transpose(1, 2)
pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
x.add_(pos_embed_flat)
@register_notrace_function
def create_attention_mask(
patch_valid: torch.Tensor,
num_prefix_tokens: int = 0,
symmetric: bool = True,
q_len: Optional[int] = None,
dtype: torch.dtype = torch.float32,
) -> Optional[torch.Tensor]:
"""Creates an attention mask from patch validity information.
Supports two modes controlled by `symmetric`:
1. `symmetric=True` (default): Creates a symmetric mask of shape
[B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if
both token i and token j are valid. Suitable for standard self-attention.
2. `symmetric=False`: Creates a potentially non-square mask of shape
[B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if
the key/value token k is valid. Query token validity is not checked
in the mask itself. Useful for cross-attention or specific self-attention
       implementations where `q_len` can be specified.
Used for NaFlex mode to handle variable token counts and padding tokens.
Args:
patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding.
num_prefix_tokens: Number of prefix tokens (class token, register tokens)
to prepend, which are always considered valid.
symmetric: If True, create a symmetric mask.
If False, create an expanded mask based only on key/value validity.
q_len: Query sequence length override. Only used when `symmetric` is False.
Defaults to the key/value sequence length (`kv_len`) if None.
dtype: Dtype of the output attention mask (e.g., torch.float32).
Returns:
Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked).
Shape is [B, 1, seq_len, seq_len] if symmetric=True,
or [B, 1, q_len, kv_len] if symmetric=False.
"""
if patch_valid is None:
return None
patch_valid = patch_valid.bool() # Ensure boolean type
B, N = patch_valid.shape
kv_len = N # Initial key/value length is the number of patches
# Prepend prefix tokens if any
if num_prefix_tokens > 0:
# Create prefix validity tensor on the same device/dtype base as patch_valid
prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
# Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N]
patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
kv_len += num_prefix_tokens # Update total key/value sequence length
if symmetric:
# Symmetric mask is True where BOTH query and key are valid
mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)
mask_bool = mask_bool.unsqueeze(1) # Add head dimension: [B, 1, seq_len, seq_len]
else:
# Expanded mask
q_len = q_len or kv_len
mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len)
# Create the float mask and apply masking using additive mask convention
mask_float = torch.zeros_like(mask_bool, dtype=dtype)
# Fill with negative infinity where mask_bool is False (masked positions)
mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min)
return mask_float
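# create_attention_mask example (illustrative): patch_valid = torch.tensor([[True, True, False]])
# with num_prefix_tokens=1 yields an additive mask of shape [1, 1, 4, 4] in which every row/column
# involving the padding token is filled with torch.finfo(dtype).min and all other entries are 0.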
@register_notrace_function
def global_pool_naflex(
x: torch.Tensor,
patch_valid: Optional[torch.Tensor] = None,
pool_type: str = 'token',
num_prefix_tokens: int = 1,
):
if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
# Fall back to standard pooling
x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens)
return x
# For NaFlex mode, we need to apply masked pooling to exclude padding tokens
# Extract only the patch part of the mask (excluding prefix tokens)
if num_prefix_tokens > 0:
# Apply the mask to extract only valid tokens
x = x[:, num_prefix_tokens:] # prefix tokens not included in pooling
patch_valid_float = patch_valid.to(x.dtype)
if pool_type == 'avg':
# Compute masked average pooling, sum valid tokens and divide by count of valid tokens
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
pooled = masked_sums / valid_counts
return pooled
elif pool_type == 'avgmax':
# For avgmax, compute masked average and masked max
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
masked_avg = masked_sums / valid_counts
# For max pooling we set masked positions to large negative value
masked_x = x.clone()
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
masked_max = masked_x.amax(dim=1)
# Combine average and max
return 0.5 * (masked_avg + masked_max)
elif pool_type == 'max':
# For max pooling we set masked positions to large negative value
masked_x = x.clone()
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
return masked_x.amax(dim=1)
else:
assert False
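# Masked pooling sketch for global_pool_naflex above: with pool_type='avg', each sample is reduced
# as sum(x * patch_valid) / clamp(count(patch_valid), min=1) over the token dimension, so padding
# tokens contribute neither to the numerator nor to the divisor.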
class VisionTransformerFlex(nn.Module):
""" Vision Transformer (Na)Flex
A flexible implementation of Vision Transformer with:
1. Encapsulated embedding and position encoding in a single module
2. Support for linear patch embedding on pre-patchified inputs
3. Support for variable sequence length / aspect ratio images (NaFlex)
"""
def __init__(
self,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
qk_norm: bool = False,
proj_bias: bool = True,
init_values: Optional[float] = None,
class_token: bool = False,
reg_tokens: int = 0,
pos_embed: str = 'learn',
pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16),
pos_embed_interp_mode: str = 'bicubic',
default_img_size: Union[int, Tuple[int, int]] = 256,
dynamic_img_pad: bool = False,
pre_norm: bool = False,
final_norm: bool = True,
fc_norm: Optional[bool] = None,
num_classes: int = 1000,
global_pool: str = 'map',
drop_rate: float = 0.,
pos_drop_rate: float = 0.,
patch_drop_rate: float = 0.,
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
weight_init: str = '',
fix_init: bool = True,
embed_layer_type: str = 'linear',
embed_norm_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
act_layer: Optional[LayerType] = None,
block_fn: Type[nn.Module] = Block,
mlp_layer: Type[nn.Module] = Mlp,
) -> None:
"""
Args:
patch_size: Patch size.
in_chans: Number of image input channels.
embed_dim: Transformer embedding dimension.
depth: Depth of transformer.
num_heads: Number of attention heads.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: Enable bias for qkv projections if True.
init_values: Layer-scale init values (layer-scale enabled if not None).
class_token: Use class token.
reg_tokens: Number of register tokens.
pos_embed: Type of position embedding.
pos_embed_grid_size: Size of position embedding grid.
pos_embed_interp_mode: Interpolation mode for position embedding.
default_img_size: Input image size.
pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
num_classes: Number of classes for classification head.
            global_pool: Type of global pooling for final sequence (default: 'map').
drop_rate: Head dropout rate.
pos_drop_rate: Position embedding dropout rate.
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
weight_init: Weight initialization scheme.
fix_init: Apply weight initialization fix (scaling w/ layer index).
embed_layer_type: Patch embedding implementation (e.g., 'linear', 'conv').
embed_norm_layer: Normalization layer to use / override in patch embed module.
norm_layer: Normalization layer.
act_layer: MLP activation layer.
block_fn: Transformer block layer.
"""
super().__init__()
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
assert class_token or global_pool != 'token'
assert pos_embed in ('', 'none', 'learn')
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
embed_norm_layer = get_norm_layer(embed_norm_layer)
act_layer = get_act_layer(act_layer) or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0
self.num_prefix_tokens += reg_tokens
self.num_reg_tokens = reg_tokens
self.has_class_token = class_token
self.grad_checkpointing = False
# Initialize embedding module (includes patch, position embedding, and class/reg tokens)
        # FlexEmbeds is always used - it handles both linear and conv patch embedding
self.embeds = FlexEmbeds(
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
embed_layer=embed_layer_type,
proj_norm_layer=embed_norm_layer,
final_norm_layer=norm_layer if pre_norm else None,
pos_embed=pos_embed,
pos_embed_grid_size=pos_embed_grid_size,
pos_embed_interp_mode=pos_embed_interp_mode,
pos_drop_rate=pos_drop_rate,
patch_drop_rate=patch_drop_rate,
class_token=class_token,
reg_tokens=reg_tokens,
default_img_size=default_img_size,
dynamic_img_pad=dynamic_img_pad,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
)
# Transformer blocks
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
proj_bias=proj_bias,
init_values=init_values,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
mlp_layer=mlp_layer,
)
for i in range(depth)])
# Feature info for downstream tasks
patch_reduction = to_2tuple(patch_size)
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_reduction) for i in range(depth)]
self.norm = norm_layer(embed_dim) if final_norm and not fc_norm else nn.Identity()
# Classifier Head
if global_pool == 'map':
self.attn_pool = AttentionPoolLatent(
self.embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
act_layer=act_layer,
)
else:
self.attn_pool = None
self.fc_norm = norm_layer(embed_dim) if final_norm and fc_norm else nn.Identity()
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if weight_init != 'skip':
self.init_weights(weight_init)
if fix_init:
self.fix_init_weight()
def fix_init_weight(self):
def rescale(param, _layer_id):
param.div_(math.sqrt(2.0 * _layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def init_weights(self, mode: str = '') -> None:
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
named_apply(get_init_weights_vit(mode, head_bias), self)
@torch.jit.ignore()
def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None:
# Custom loading for the new model structure
from .vision_transformer import _load_weights as _orig_load_weights
def _load_weights_adapter(model, checkpoint_path, prefix=''):
"""Adapter function to handle the different model structure"""
state_dict = torch.load(checkpoint_path, map_location='cpu')
if isinstance(state_dict, dict) and 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
# Map original keys to new structure
for k in list(state_dict.keys()):
if k.startswith('cls_token'):
state_dict['embeds.' + k] = state_dict.pop(k)
elif k.startswith('reg_token'):
state_dict['embeds.' + k] = state_dict.pop(k)
elif k.startswith('pos_embed'):
state_dict['embeds.' + k] = state_dict.pop(k)
elif k.startswith('patch_embed'):
state_dict['embeds.' + k[12:]] = state_dict.pop(k)
return _orig_load_weights(model, state_dict, prefix)
_load_weights_adapter(self, checkpoint_path, prefix)
@torch.jit.ignore
def no_weight_decay(self) -> Set:
skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'}
return skip_list
@torch.jit.ignore
def group_matcher(self, coarse: bool = False) -> Dict:
return dict(
stem=r'^embeds', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
self.grad_checkpointing = enable
if hasattr(self.embeds, 'patch_embed') and hasattr(self.embeds.patch_embed, 'set_grad_checkpointing'):
self.embeds.patch_embed.set_grad_checkpointing(enable)
@torch.jit.ignore
def get_classifier(self) -> nn.Module:
return self.head
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
if global_pool == 'map' and self.attn_pool is None:
assert False, "Cannot currently add attention pooling in reset_classifier()."
elif global_pool != 'map' and self.attn_pool is not None:
self.attn_pool = None # remove attention pooling
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
output_dict: bool = False,
patch_coord: Optional[torch.Tensor] = None,
patch_valid: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
return_prefix_tokens: Return both prefix and spatial intermediate tokens
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode
patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
mask: Optional attention mask
Returns:
A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
"""
# FIXME unfinished / untested
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
        # Create attention mask if patch_valid is provided and mask is not
        if mask is None and patch_valid is not None:
            mask = create_attention_mask(patch_valid, num_prefix_tokens=self.num_prefix_tokens, dtype=x.dtype)
        # Forward pass through embedding
        x = self.embeds(x, patch_coord=patch_coord)
# Forward pass through blocks
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x, attn_mask=mask)
if i in take_indices:
# normalize intermediates with final norm layer if enabled
intermediates.append(self.norm(x) if norm else x)
# Process intermediates
if self.num_prefix_tokens:
# split prefix (e.g. class, distill) and spatial feature tokens
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
else:
prefix_tokens = None
if reshape:
# reshape to BCHW output format
grid_size = self.embeds.pos_embed_grid_size
if hasattr(self.embeds, 'dynamic_feat_size') and len(x.shape) >= 4:
_, height, width, _ = x.shape if len(x.shape) == 4 else (None, *x.shape[-3:-1], None)
H, W = self.embeds.dynamic_feat_size((height, width))
else:
H, W = grid_size
intermediates = [y.reshape(y.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()
for y in intermediates]
# For dictionary output
if output_dict:
result_dict = {}
# Intermediates are always included
result_dict['image_intermediates'] = intermediates
if prefix_tokens is not None and return_prefix_tokens:
result_dict['image_intermediates_prefix'] = prefix_tokens
# Only include features if not intermediates_only
if not intermediates_only:
x_final = self.norm(x)
result_dict['image_features'] = x_final
return result_dict
# For non-dictionary output, maintain the original behavior
if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
            # return_prefix not supported in torchscript due to poor type handling
intermediates = list(zip(intermediates, prefix_tokens))
if intermediates_only:
return intermediates
x = self.norm(x)
return x, intermediates
def forward_features(
self,
x: torch.Tensor,
patch_coord: Optional[torch.Tensor] = None,
patch_valid: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if attn_mask is None and patch_valid is not None:
attn_mask = create_attention_mask(
patch_valid,
num_prefix_tokens=self.num_prefix_tokens,
dtype=x.dtype
)
# Pass through embedding module with patch coordinate/type support
x = self.embeds(x, patch_coord=patch_coord)
# Apply transformer blocks with masked attention if mask provided
if attn_mask is not None:
# We need to apply blocks one by one with mask
for blk in self.blocks:
x = blk(x, attn_mask=attn_mask)
elif self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
x = self.norm(x)
return x
def _pool(
self,
x: torch.Tensor,
pool_type: Optional[str] = None,
patch_valid: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.attn_pool is not None:
# For attention pooling, we need to pass the mask for NaFlex models
attn_mask = create_attention_mask(
patch_valid,
symmetric=False,
q_len=1,
dtype=x.dtype,
)
x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
return x
pool_type = self.global_pool if pool_type is None else pool_type
x = global_pool_naflex(
x,
patch_valid,
pool_type=pool_type,
num_prefix_tokens=self.num_prefix_tokens,
)
return x
def forward_head(
self,
x: torch.Tensor,
pre_logits: bool = False,
patch_valid: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x = self._pool(x, patch_valid=patch_valid)
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(
self,
x: Union[torch.Tensor, Dict[str, torch.Tensor]],
patch_coord: Optional[torch.Tensor] = None,
patch_valid: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with optional NaFlex support.
Args:
x: Input tensor [B, C, H, W] or pre-patchified tensor [B, N, P*P*C] or NaFlex dict
patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode
patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
Returns:
Model output tensor
"""
if isinstance(x, Dict):
# Handle dictionary input from NaFlex collator
patch_coord = x['patch_coord']
patch_valid = x['patch_valid']
patches = x['patches']
# DEBUG, reconstruct patches
# for i in range(len(patches)):
# patch = patches[i][patch_valid[i]]
# h = (patch_coord[i, :, 0].max() + 1).item()
# w = (patch_coord[i, :, 1].max() + 1).item()
# patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
# patch = patch.reshape(3, h*16, w*16)
# from torchvision.utils import save_image
# save_image(patch, f'patch_{i}.jpg', normalize=True)
else:
patches = x
        # Create attention mask if patch_valid is provided
attn_mask = create_attention_mask(
patch_valid,
num_prefix_tokens=self.num_prefix_tokens,
dtype=patches.dtype,
)
# Forward features with mask
x = self.forward_features(
patches,
patch_coord=patch_coord,
patch_valid=patch_valid,
attn_mask=attn_mask,
)
# Pass mask to forward_head for masked pooling
x = self.forward_head(
x,
patch_valid=patch_valid,
)
return x
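# Usage sketch for VisionTransformerFlex.forward (illustrative values; the dict keys match the
# NaFlex collator format consumed above):
#   inputs = {
#       'patches': torch.zeros(2, 256, 16 * 16 * 3),               # [B, N, P*P*C]
#       'patch_coord': torch.zeros(2, 256, 2, dtype=torch.long),   # (y, x) patch grid coordinates
#       'patch_valid': torch.ones(2, 256, dtype=torch.bool),       # False marks padding tokens
#   }
#   logits = model(inputs)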
def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
"""Function imported from vision_transformer.py to maintain compatibility"""
from .vision_transformer import init_weights_vit_jax, init_weights_vit_moco, init_weights_vit_timm
if 'jax' in mode:
return partial(init_weights_vit_jax, head_bias=head_bias)
elif 'moco' in mode:
return init_weights_vit_moco
else:
return init_weights_vit_timm
def checkpoint_filter_fn(state_dict, model):
"""Handle state dict conversion from original ViT to the new version with combined embedding."""
from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
# Handle CombinedEmbed module pattern
out_dict = {}
for k, v in state_dict.items():
# Convert tokens and embeddings to combined_embed structure
if k == 'pos_embed':
# Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
num_cls_token = 0
num_reg_token = 0
if 'reg_token' in state_dict:
num_reg_token = state_dict['reg_token'].shape[1]
if 'cls_token' in state_dict:
num_cls_token = state_dict['cls_token'].shape[1]
num_prefix_tokens = num_cls_token + num_reg_token
# Original format is (1, N, C), need to reshape to (1, H, W, C)
num_patches = v.shape[1]
num_patches_no_prefix = num_patches - num_prefix_tokens
grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
grid_size = math.sqrt(num_patches)
if (grid_size_no_prefix != grid_size and (
grid_size_no_prefix.is_integer() and not grid_size.is_integer())):
# make a decision, did the pos_embed of the original include the prefix tokens?
num_patches = num_patches_no_prefix
cls_token_emb = v[:, 0:num_cls_token]
if cls_token_emb.numel():
state_dict['cls_token'] += cls_token_emb
                    reg_token_emb = v[:, num_cls_token:num_cls_token + num_reg_token]
if reg_token_emb.numel():
state_dict['reg_token'] += reg_token_emb
v = v[:, num_prefix_tokens:]
grid_size = grid_size_no_prefix
grid_size = int(grid_size)
# Check if it's a perfect square for a standard grid
if grid_size * grid_size == num_patches:
# Reshape from (1, N, C) to (1, H, W, C)
v = v.reshape(1, grid_size, grid_size, v.shape[2])
else:
# Not a square grid, we need to get the actual dimensions
                    if model.embeds.pos_embed_grid_size is not None:
                        h, w = model.embeds.pos_embed_grid_size
if h * w == num_patches:
# We have the right dimensions
v = v.reshape(1, h, w, v.shape[2])
else:
# Dimensions don't match, use interpolation
_logger.warning(
f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
f"Using default initialization and will resize in forward pass."
)
# Keep v as is, the forward pass will handle resizing
out_dict['embeds.pos_embed'] = v
elif k == 'cls_token':
out_dict['embeds.cls_token'] = v
elif k == 'reg_token':
out_dict['embeds.reg_token'] = v
# Convert patch_embed.X to embeds.patch_embed.X
elif k.startswith('patch_embed.'):
suffix = k[12:]
if suffix == 'proj.weight':
# FIXME confirm patchify memory layout across use cases
v = v.permute(0, 2, 3, 1).flatten(1)
new_key = 'embeds.' + suffix
out_dict[new_key] = v
else:
out_dict[k] = v
return out_dict
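# checkpoint_filter_fn example (illustrative): a ViT-B/16 checkpoint with 'pos_embed' of shape
# (1, 197, 768) and a single class token maps to 'embeds.pos_embed' of shape (1, 14, 14, 768),
# while 'patch_embed.proj.weight' of shape (768, 3, 16, 16) becomes a (768, 768) linear weight
# under 'embeds.proj.weight' for the pre-patchified linear embedding path.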
def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
return {
'url': url,
'num_classes': 1000,
'input_size': (3, 256, 256),
'pool_size': None,
'crop_pct': 0.95,
'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN,
'std': IMAGENET_INCEPTION_STD,
'first_conv': 'embeds.proj',
'classifier': 'head',
'license': 'apache-2.0',
**kwargs,
}
default_cfgs = generate_default_cfgs({
'vit_naflex_base_patch16_gap': _cfg(),
'vit_naflex_base_patch16_map': _cfg(),
'vit_naflex_so400m_patch16_map': _cfg(),
    # sbb model testing
'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
hf_hub_id='timm/vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_naflex_so150m2_patch16_reg1_gap.sbb_r256_e200_in12k_ft_in1k': _cfg(
hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k',
input_size=(3, 256, 256), crop_pct=1.0),
'vit_naflex_so150m2_patch16_reg1_gap.sbb_r384_e200_in12k_ft_in1k': _cfg(
hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_naflex_so150m2_patch16_reg1_gap.sbb_r448_e200_in12k_ft_in1k': _cfg(
hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'),
# traditional vit testing
'vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
'vit_naflex_base_patch8.augreg2_r224_in21k_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
})
def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
model = build_model_with_cfg(
VisionTransformerFlex, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs,
)
return model
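# The registered variants below are instantiated through timm's factory functions, e.g. (illustrative):
#   model = timm.create_model('vit_naflex_base_patch16_gap', pretrained=False)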
@register_model
def vit_naflex_base_patch16_gap(pretrained=False, **kwargs):
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
"""
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True)
model = _create_vision_transformer_flex(
'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
"""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
global_pool='map', reg_tokens=1)
model = _create_vision_transformer_flex(
'vit_naflex_base_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
This model supports:
1. Variable aspect ratios and resolutions via patch coordinates
2. Position embedding interpolation for arbitrary grid sizes
3. Explicit patch coordinates and valid token masking
"""
model_args = dict(
patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True)
model = _create_vision_transformer_flex(
'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12,
global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_naflex_so400m_patch16_map(pretrained=False, **kwargs):
"""ViT-SO400M with NaFlex functionality for variable aspect ratios and resolutions.
"""
model_args = dict(
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, init_values=1e-5,
global_pool='map', class_token=False, reg_tokens=1, act_layer='gelu_tanh')
model = _create_vision_transformer_flex(
'vit_naflex_so400m_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
return model