Add siglip2 compatible naflex encoders. Add support for factorized pos embeds and 'aspect preserving mode' to Flex Embeds. Some more docstrings and typing.

Ross Wightman 2025-05-30 16:15:37 -07:00
parent b7ced7c40c
commit 72858c193c
2 changed files with 328 additions and 111 deletions

View File

@@ -749,11 +749,9 @@ def patchify_image(
     # Ensure the image is divisible by patch size
     if pad and (h % ph != 0 or w % pw != 0):
-        new_h = math.ceil(h / ph) * ph
-        new_w = math.ceil(w / pw) * pw
-        padded_img = torch.zeros(c, new_h, new_w, dtype=img.dtype)
-        padded_img[:, :h, :w] = img
-        img = padded_img
+        pad_h = (ph - h % ph) % ph  # amount to add on bottom
+        pad_w = (pw - w % pw) % pw  # amount to add on right
+        img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h))
         c, h, w = img.shape
 
     # Calculate number of patches in each dimension
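For reference, a small standalone sketch (not part of the commit) of the new padding arithmetic, using illustrative sizes; it produces the same result as the old zeros-then-copy approach without the manual allocation:

import math
import torch
import torch.nn.functional as F

img = torch.randn(3, 50, 70)          # example C, H, W input, values are illustrative
c, h, w = img.shape
ph, pw = 16, 16                       # patch size

# New approach: compute exact bottom/right padding and let F.pad do the copy
pad_h = (ph - h % ph) % ph            # 14 -> padded height 64
pad_w = (pw - w % pw) % pw            # 10 -> padded width 80
padded = F.pad(img, (0, pad_w, 0, pad_h))

# Same shape/content as allocating a zero tensor of the rounded-up size and copying into it
assert padded.shape == (c, math.ceil(h / ph) * ph, math.ceil(w / pw) * pw)
assert torch.equal(padded[:, :h, :w], img)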

View File

@@ -24,7 +24,16 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert
+from timm.layers import (
+    AttentionPoolLatent,
+    Mlp,
+    to_2tuple,
+    get_act_layer,
+    get_norm_layer,
+    LayerNorm,
+    LayerType,
+    _assert,
+)
 from timm.models._builder import build_model_with_cfg
 from timm.models._features import feature_take_indices
 from timm.models._features_fx import register_notrace_function, register_notrace_module
@@ -59,27 +68,46 @@ def batch_patchify(
 
 @register_notrace_module
 class FlexEmbeds(nn.Module):
-    """ Na(Flex) Embedding module for Vision Transformers
+    """NaFlex Embedding module for Vision Transformers.
 
-    This module encapsulates the complete embedding process for Vision Transformers:
+    This module encapsulates the complete embedding process for Vision Transformers,
+    supporting both standard and NaFlex (NaViT + FlexiViT) functionality:
     1. Patch embedding (via Conv2d or Linear)
     2. Class and register token preparation
-    3. Position embedding addition
+    3. Position embedding addition with interpolation support
     4. Pre-normalization (if requested)
     5. Dropout application
 
-    Also supports NaFlex functionality (NaViT + FlexiViT):
+    NaFlex capabilities include:
     - Variable aspect ratio and resolution via patch coordinates
     - Patch type indicators for handling padding tokens in attention
     - Flexible position embedding interpolation for arbitrary grid sizes
+    - Support for factorized position embeddings
 
-    Note: Only supports non-overlapping position and register tokens
-    (i.e., position embeddings do not include class/register tokens)
-
     The patch embedding can be one of two types:
-    1. Conv2d-based (default): For standard image inputs [B, C, H, W]
-    2. Linear-based: For pre-patchified inputs [B, N, P*P*C]
+    - Conv2d-based (default): For standard image inputs [B, C, H, W]
+    - Linear-based: For pre-patchified inputs [B, N, P*P*C]
+
+    Args:
+        patch_size: Size of patches for patch embedding
+        in_chans: Number of input image channels
+        embed_dim: Dimensionality of patch embedding
+        embed_layer: Type of embedding layer ('conv' or 'linear')
+        input_norm_layer: Normalization layer applied to input (linear mode only)
+        proj_norm_layer: Normalization layer applied after projection
+        final_norm_layer: Final normalization layer before output
+        pos_embed: Type of position embedding ('learned', 'factorized', 'rope', 'none')
+        pos_drop_rate: Dropout rate for position embeddings
+        patch_drop_rate: Dropout rate for patch tokens
+        class_token: Whether to include a class token
+        reg_tokens: Number of register tokens to include
+        bias: Whether to use bias in projection layers
+        dynamic_img_pad: Whether to enable dynamic padding for variable resolution
+        pos_embed_grid_size: Grid size for position embedding initialization
+        pos_embed_interp_mode: Interpolation mode for position embedding resizing
+        pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation
+        default_img_size: Default image size for position embedding grid calculation
     """
 
     def __init__(
@@ -100,12 +128,14 @@ class FlexEmbeds(nn.Module):
             dynamic_img_pad: bool = False,
             pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14),
             pos_embed_interp_mode: str = 'bicubic',
+            pos_embed_ar_preserving: bool = False,
             default_img_size: Union[int, Tuple[int, int]] = None,
     ):
         super().__init__()
         self.has_class_token = class_token
         self.num_reg_tokens = reg_tokens
         self.pos_embed_interp_mode = pos_embed_interp_mode
+        self.pos_embed_ar_preserving = pos_embed_ar_preserving
         self.patch_size = to_2tuple(patch_size)
         self.in_chans = in_chans
         self.embed_dim = embed_dim
@@ -153,22 +183,28 @@ class FlexEmbeds(nn.Module):
         self.norm_proj = proj_norm_layer(embed_dim) if proj_norm_layer else nn.Identity()
 
         # Create position embedding if needed - only for patches, never for prefix tokens
+        if pos_embed in ('factorized', 'learned') and self.pos_embed_grid_size is None:
+            raise ValueError(
+                "Cannot initialize position embeddings without grid_size."
+                "Please provide img_size or pos_embed_grid_size.")
+        self.pos_embed: Optional[torch.Tensor] = None
+        self.pos_embed_y: Optional[torch.Tensor] = None
+        self.pos_embed_x: Optional[torch.Tensor] = None
         if not pos_embed or pos_embed == 'none':
-            self.pos_embed = None
             self.pos_embed_type = 'none'
         elif pos_embed == 'rope':
-            self.pos_embed = None
             self.pos_embed_type = 'rope'
             # Rotary embeddings will be computed on-the-fly in the forward pass
+        elif pos_embed == 'factorized':
+            h, w = self.pos_embed_grid_size
+            self.pos_embed_type = 'factorized'
+            self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim) * .02)
+            self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim) * .02)
         else:
             # Store position embedding in (1, H, W, dim) format for easier resizing
-            if self.pos_embed_grid_size is not None:
-                h, w = self.pos_embed_grid_size
-                self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02)
-                self.pos_embed_type = 'learned'
-            else:
-                raise ValueError("Cannot initialize position embeddings without grid_size. "
-                                 "Please provide img_size or pos_embed_grid_size")
+            h, w = self.pos_embed_grid_size
+            self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02)
+            self.pos_embed_type = 'learned'
 
         # Pre-normalization layer (separate from the main norm layer)
         self.norm_final = final_norm_layer(embed_dim) if final_norm_layer is not None else nn.Identity()
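To make the new 'factorized' option concrete, a quick sketch (illustrative sizes, mirroring the parameter shapes created above) comparing its parameter count against the full learned table:

import torch
import torch.nn as nn

h, w, embed_dim = 16, 16, 768                         # illustrative grid size and width

# 'learned': a single table covering the full (H, W) grid
pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02)

# 'factorized': separate row (Y) and column (X) tables, combined later by broadcasting
pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim) * .02)
pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim) * .02)

print(pos_embed.numel())                              # 196608 (= 16 * 16 * 768)
print(pos_embed_y.numel() + pos_embed_x.numel())      # 24576 (= (16 + 16) * 768)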
@@ -184,35 +220,57 @@ class FlexEmbeds(nn.Module):
         else:
             self.patch_drop = nn.Identity()
 
-    def feature_info(self, location):
-        """Feature info utility method for feature extraction."""
+    def feature_info(self, location) -> Dict[str, Any]:
+        """Get feature information for feature extraction.
+
+        Args:
+            location: Feature extraction location identifier
+
+        Returns:
+            Dictionary containing feature channel count and reduction factor
+        """
         return dict(num_chs=self.embed_dim, reduction=self.patch_size)
 
-    def feat_ratio(self, as_scalar=True):
-        """Return the feature reduction ratio (stride)."""
+    def feat_ratio(self, as_scalar: bool = True) -> Union[int, Tuple[int, int]]:
+        """Get the feature reduction ratio (stride) of the patch embedding.
+
+        Args:
+            as_scalar: Whether to return the maximum dimension as a scalar
+
+        Returns:
+            Feature reduction ratio as scalar or tuple
+        """
         if as_scalar:
             return max(self.patch_size)
         else:
             return self.patch_size
 
     def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
-        """ Get grid (feature) size for given image size taking account of dynamic padding.
+        """Calculate grid (feature) size for given image size.
+
+        Takes into account dynamic padding when enabled.
+
+        Args:
+            img_size: Input image size as (height, width)
+
+        Returns:
+            Grid size as (grid_height, grid_width)
         """
         if self.dynamic_img_pad:
             return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
         else:
            return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
 
-    def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
-        """Forward pass for combined embedding
+    def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Forward pass for patch embedding with position encoding.
 
         Args:
-            x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C]
-            patch_coord: Optional patch coordinates [B, N, 2] for NaFlex
+            x: Input tensor [B, C, H, W] for conv mode or [B, N, P*P*C] for pre-patchified linear mode
+            patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode
 
         Returns:
-            Embedded tensor with position encoding and class/register tokens applied
-            If patch_type is provided, also returns attention mask
+            Embedded tensor with position encoding and class/register tokens applied.
+            Shape: [B, num_prefix_tokens + N, embed_dim]
         """
         # Apply patch embedding
         naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
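A quick worked example of the two dynamic_feat_size branches (numbers are illustrative):

import math

img_size, patch_size = (250, 320), (16, 16)

# dynamic_img_pad=True: the image is padded up, so partially covered patches still count
print(math.ceil(img_size[0] / patch_size[0]), math.ceil(img_size[1] / patch_size[1]))  # 16 20

# dynamic_img_pad=False: trailing pixels that do not fill a whole patch are dropped
print(img_size[0] // patch_size[0], img_size[1] // patch_size[1])                      # 15 20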
@@ -259,6 +317,9 @@ class FlexEmbeds(nn.Module):
             else:
                 assert grid_size is not None
                 self._apply_learned_pos_embed(x, grid_size=grid_size)
+        elif self.pos_embed_type == 'factorized':
+            if naflex_grid_sizes is not None:
+                self._apply_factorized_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
         elif self.pos_embed_type == 'rope':
             assert False, "ROPE not yet implemented"
@@ -285,22 +346,38 @@ class FlexEmbeds(nn.Module):
             self,
             x: torch.Tensor,
             naflex_grid_sizes: List[Tuple[int, int]],
-    ):
+    ) -> None:
+        """Apply learned position embeddings to NaFlex batch in-place.
+
+        Interpolates learned position embeddings for each sample in the batch
+        based on their individual grid sizes.
+
+        Args:
+            x: Input tensor to add position embeddings to
+            naflex_grid_sizes: List of (height, width) grid sizes for each batch element
+        """
         # Handle each batch element separately with its own grid size
         orig_h, orig_w = self.pos_embed.shape[1:3]
         pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
 
-        def _interp(_size):
-            if (_size[0] == orig_h) and (_size[1] == orig_w):
+        def _interp2d(size):
+            """
+            Return a flattened positional-embedding grid at an arbitrary spatial resolution.
+
+            Converts the learned 2-D table stored in NCHW format (pos_embed_nchw) into
+            a (1, H*W, C) sequence that matches the requested size.
+            """
+            if (size[0] == orig_h) and (size[1] == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
+                _interp_size = to_2tuple(max(size)) if self.pos_embed_ar_preserving else size
                 pos_embed_flat = F.interpolate(
                     pos_embed_nchw,
-                    size=_size,
+                    size=_interp_size,
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
-                ).flatten(2).transpose(1, 2)
+                )[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2)
             return pos_embed_flat.to(dtype=x.dtype)
 
         # FIXME leaving alternative code commented here for now for comparisons
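A condensed, standalone sketch of what the renamed _interp2d helper does when pos_embed_ar_preserving is enabled (sizes and the interpolation mode here are illustrative): the table is resized to a square max(H, W) grid so both axes share one scale factor, then cropped back to the target grid before flattening.

import torch
import torch.nn.functional as F
from timm.layers import to_2tuple

pos_embed = torch.randn(1, 14, 14, 384)              # illustrative (1, H, W, C) table
pos_embed_nchw = pos_embed.permute(0, 3, 1, 2)       # (1, C, H, W) layout for F.interpolate

def interp2d(size, ar_preserving):
    # AR preserving: resize to max(H, W) x max(H, W), then crop; otherwise resize directly
    interp_size = to_2tuple(max(size)) if ar_preserving else size
    out = F.interpolate(
        pos_embed_nchw, size=interp_size, mode='bicubic', align_corners=False, antialias=True)
    return out[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2)  # (1, H*W, C)

# Output shape is identical either way; only the sampling of the table differs
print(interp2d((8, 20), ar_preserving=False).shape)  # torch.Size([1, 160, 384])
print(interp2d((8, 20), ar_preserving=True).shape)   # torch.Size([1, 160, 384])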
@@ -315,7 +392,7 @@ class FlexEmbeds(nn.Module):
         #     seq_len = min(x.shape[1], pos_embed_flat.shape[1])
         #     x[i, :seq_len] += pos_embed_flat[0, :seq_len]
 
-        # Determine unique grid sizes
+        # Determine unique grid sizes to avoid duplicate interpolation
         size_to_indices: Dict[Tuple[int, int], List[int]] = {}
         for bi, k in enumerate(naflex_grid_sizes):
             # k = h << 16 | w  # FIXME can get jit compat with this
@@ -324,7 +401,7 @@ class FlexEmbeds(nn.Module):
         for k, batch_indices in size_to_indices.items():
             # h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
             # Interpolate only once for this (h, w)
-            pos_embed_flat = _interp(k)
+            pos_embed_flat = _interp2d(k)
             seq_len = min(x.shape[1], pos_embed_flat.shape[1])
             x[:, :seq_len].index_add_(
                 0,
@@ -336,7 +413,15 @@ class FlexEmbeds(nn.Module):
             self,
             x: torch.Tensor,
             grid_size: List[int],
-    ):
+    ) -> None:
+        """Apply learned position embeddings to standard batch in-place.
+
+        Interpolates learned position embeddings to match the specified grid size.
+
+        Args:
+            x: Input tensor to add position embeddings to
+            grid_size: Target grid size as [height, width]
+        """
         orig_h, orig_w = self.pos_embed.shape[1:3]
         if grid_size[0] == orig_h or grid_size[1] == orig_w:
             # No resize needed, just flatten
@@ -354,6 +439,63 @@ class FlexEmbeds(nn.Module):
 
             x.add_(pos_embed_flat)
 
+    def _apply_factorized_naflex_pos_embed(
+            self,
+            x: torch.Tensor,
+            naflex_grid_sizes: List[Tuple[int, int]],
+    ) -> None:
+        """Apply factorized position embeddings to NaFlex batch in-place.
+
+        Uses separate Y and X position embedding tables that are interpolated
+        and combined for each sample's grid size.
+
+        Args:
+            x: Input tensor to add position embeddings to
+            naflex_grid_sizes: List of (height, width) grid sizes for each batch element
+        """
+        assert len(naflex_grid_sizes) == x.size(0)  # one (H,W) per sample
+
+        # Handle each batch element separately with its own grid size
+        orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1]
+
+        # bucket samples that share the same (H,W) so we build each grid once
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
+        for bi, k in enumerate(naflex_grid_sizes):
+            size_to_indices.setdefault(k, []).append(bi)
+
+        def _interp1d(table: torch.Tensor, length: int) -> torch.Tensor:
+            """
+            Resample a 1-D positional-embedding table to specified length
+            and return it in (1, L, C) layout, dtype matching x.
+            """
+            return F.interpolate(
+                table.permute(0, 2, 1).float(),  # (1,C,L) → (1,C,L_out)
+                size=length,
+                mode='linear',
+                align_corners=False,
+            ).permute(0, 2, 1).to(dtype=x.dtype)  # → (1,L_out,C)
+
+        for k, batch_indices in size_to_indices.items():
+            target_h, target_w = k
+            if self.pos_embed_ar_preserving:
+                len_y = len_x = max(target_h, target_w)
+            else:
+                len_y, len_x = target_h, target_w
+            pe_y = _interp1d(self.pos_embed_y, len_y)[:, :target_h]  # (1,H,C)
+            pe_x = _interp1d(self.pos_embed_x, len_x)[:, :target_w]  # (1,W,C)
+
+            # Broadcast, add and flatten to sequence layout (row major)
+            pos = pe_y.unsqueeze(2) + pe_x.unsqueeze(1)  # (1,H,W,C)
+            pos = pos.flatten(1, 2)
+
+            seq_len = min(x.shape[1], pos.shape[1])
+            x[:, :seq_len].index_add_(
+                0,
+                torch.as_tensor(batch_indices, device=x.device),
+                pos[:, :seq_len].expand(len(batch_indices), -1, -1)
+            )
+
 
 @register_notrace_function
 def create_attention_mask(
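The core of the new factorized path is a broadcast combine of two 1-D tables; a minimal sketch with made-up sizes:

import torch
import torch.nn.functional as F

embed_dim = 384
pos_embed_y = torch.randn(1, 16, embed_dim)        # (1, H_orig, C) row table
pos_embed_x = torch.randn(1, 16, embed_dim)        # (1, W_orig, C) column table

def interp1d(table, length):
    # (1, L, C) -> (1, C, L) -> linear resize -> back to (1, length, C)
    return F.interpolate(
        table.permute(0, 2, 1), size=length, mode='linear', align_corners=False,
    ).permute(0, 2, 1)

target_h, target_w = 10, 24                        # one sample's NaFlex grid size
pe_y = interp1d(pos_embed_y, target_h)             # (1, 10, C)
pe_x = interp1d(pos_embed_x, target_w)             # (1, 24, C)

# Broadcast-add row and column embeddings, then flatten row-major to the token sequence
pos = pe_y.unsqueeze(2) + pe_x.unsqueeze(1)        # (1, 10, 24, C)
pos = pos.flatten(1, 2)                            # (1, 240, C), matches a 10x24 patch grid
print(pos.shape)                                   # torch.Size([1, 240, 384])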
@@ -430,7 +572,21 @@ def global_pool_naflex(
         patch_valid: Optional[torch.Tensor] = None,
         pool_type: str = 'token',
         num_prefix_tokens: int = 1,
-):
+) -> torch.Tensor:
+    """Global pooling with NaFlex support for masked tokens.
+
+    Applies global pooling while respecting patch validity masks to exclude
+    padding tokens from pooling operations.
+
+    Args:
+        x: Input tensor with shape [B, N, C]
+        patch_valid: Optional validity mask for patches [B, N-num_prefix_tokens]
+        pool_type: Type of pooling ('token', 'avg', 'avgmax', 'max')
+        num_prefix_tokens: Number of prefix tokens (class/register) to exclude from masking
+
+    Returns:
+        Pooled tensor with shape [B, C]
+    """
     if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
         # Fall back to standard pooling
         x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens)
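The hunk above only shows the fallback path; for the masked 'avg' case the docstring describes, the behaviour amounts to a validity-weighted mean over patch tokens. A minimal sketch with illustrative shapes (not the library's exact implementation):

import torch

B, N, C, num_prefix_tokens = 2, 6, 4, 1
x = torch.randn(B, num_prefix_tokens + N, C)
patch_valid = torch.tensor([[1, 1, 1, 1, 0, 0],    # sample 0: last two patches are padding
                            [1, 1, 1, 1, 1, 1]], dtype=torch.bool)

tokens = x[:, num_prefix_tokens:]                  # drop class/register tokens before pooling
mask = patch_valid.unsqueeze(-1).float()           # (B, N, 1)
pooled = (tokens * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)  # masked mean, (B, C)
print(pooled.shape)                                # torch.Size([2, 4])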
@@ -472,12 +628,16 @@ def global_pool_naflex(
 
 class VisionTransformerFlex(nn.Module):
-    """ Vision Transformer (Na)Flex
+    """Vision Transformer with NaFlex support for flexible input handling.
 
-    A flexible implementation of Vision Transformer with:
-    1. Encapsulated embedding and position encoding in a single module
-    2. Support for linear patch embedding on pre-patchified inputs
-    3. Support for variable sequence length / aspect ratio images (NaFlex)
+    A flexible implementation of Vision Transformer that supports:
+    - Standard image classification with various pooling strategies
+    - NaFlex functionality for variable aspect ratios and resolutions
+    - Linear patch embedding for pre-patchified inputs
+    - Multiple position embedding strategies (learned, factorized, rope)
+    - Comprehensive attention masking for efficient batch processing
+    - Encapsulated embedding and position encoding in FlexEmbeds module
+    - Compatible with standard ViT checkpoints through checkpoint filtering
     """
 
     def __init__(
@@ -494,9 +654,10 @@ class VisionTransformerFlex(nn.Module):
             init_values: Optional[float] = None,
             class_token: bool = False,
             reg_tokens: int = 0,
-            pos_embed: str = 'learn',
+            pos_embed: str = 'learned',
             pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16),
             pos_embed_interp_mode: str = 'bicubic',
+            pos_embed_ar_preserving: bool = False,
             default_img_size: Union[int, Tuple[int, int]] = 256,
             dynamic_img_pad: bool = False,
             pre_norm: bool = False,
@@ -519,44 +680,52 @@ class VisionTransformerFlex(nn.Module):
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
     ) -> None:
-        """
+        """Initialize VisionTransformerFlex model.
+
         Args:
-            patch_size: Patch size.
-            in_chans: Number of image input channels.
-            embed_dim: Transformer embedding dimension.
-            depth: Depth of transformer.
-            num_heads: Number of attention heads.
-            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
-            qkv_bias: Enable bias for qkv projections if True.
-            init_values: Layer-scale init values (layer-scale enabled if not None).
-            class_token: Use class token.
-            reg_tokens: Number of register tokens.
-            pos_embed: Type of position embedding.
-            pos_embed_grid_size: Size of position embedding grid.
-            pos_embed_interp_mode: Interpolation mode for position embedding.
-            default_img_size: Input image size.
-            pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
-            final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
-            fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
-            num_classes: Number of classes for classification head.
-            global_pool: Type of global pooling for final sequence (default: 'token').
-            drop_rate: Head dropout rate.
-            pos_drop_rate: Position embedding dropout rate.
-            attn_drop_rate: Attention dropout rate.
-            drop_path_rate: Stochastic depth rate.
-            weight_init: Weight initialization scheme.
-            fix_init: Apply weight initialization fix (scaling w/ layer index).
-            embed_layer_type: Patch embedding implementation (e.g., 'linear', 'conv').
-            embed_norm_layer: Normalization layer to use / override in patch embed module.
-            norm_layer: Normalization layer.
-            act_layer: MLP activation layer.
-            block_fn: Transformer block layer.
+            patch_size: Size of patches for patch embedding
+            in_chans: Number of input image channels
+            embed_dim: Transformer embedding dimension
+            depth: Number of transformer blocks
+            num_heads: Number of attention heads
+            mlp_ratio: Ratio of MLP hidden dimension to embedding dimension
+            qkv_bias: Whether to use bias in query, key, value projections
+            qk_norm: Whether to normalize query and key projections
+            proj_bias: Whether to use bias in linear projections
+            init_values: Layer-scale init values (layer-scale enabled if not None)
+            class_token: Whether to include a class token
+            reg_tokens: Number of register tokens to include
+            pos_embed: Type of position embedding
+            pos_embed_grid_size: Grid size for position embedding initialization
+            pos_embed_interp_mode: Interpolation mode for position embedding resizing
+            pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation
+            default_img_size: Default image size for position embedding grid calculation
+            dynamic_img_pad: Whether to enable dynamic padding for variable resolution
+            pre_norm: Whether to apply normalization before attention/MLP layers
+            final_norm: Whether to apply final normalization before classifier
+            fc_norm: Whether to normalize features before final classifier
+            num_classes: Number of classification classes
+            global_pool: Type of global pooling for final sequence
+            drop_rate: Dropout rate for classifier
+            pos_drop_rate: Dropout rate for position embeddings
+            patch_drop_rate: Dropout rate for patch tokens
+            proj_drop_rate: Dropout rate for linear projections
+            attn_drop_rate: Dropout rate for attention weights
+            drop_path_rate: Stochastic depth drop rate
+            weight_init: Weight initialization scheme
+            fix_init: Apply weight initialization fix (scaling w/ layer index)
+            embed_layer_type: Type of embedding layer ('conv' or 'linear')
+            embed_norm_layer: Normalization layer for embeddings
+            norm_layer: Normalization layer for transformer blocks
+            act_layer: Activation layer for MLP blocks
+            block_fn: Transformer block implementation class
+            mlp_layer: MLP implementation class
         """
         super().__init__()
         assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         assert class_token or global_pool != 'token'
-        assert pos_embed in ('', 'none', 'learn')
+        assert pos_embed in ('', 'none', 'learned', 'factorized')
 
-        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+        norm_layer = get_norm_layer(norm_layer) or LayerNorm
         embed_norm_layer = get_norm_layer(embed_norm_layer)
         act_layer = get_act_layer(act_layer) or nn.GELU
@@ -581,6 +750,7 @@ class VisionTransformerFlex(nn.Module):
             pos_embed=pos_embed,
             pos_embed_grid_size=pos_embed_grid_size,
             pos_embed_interp_mode=pos_embed_interp_mode,
+            pos_embed_ar_preserving=pos_embed_ar_preserving,
             pos_drop_rate=pos_drop_rate,
             patch_drop_rate=patch_drop_rate,
             class_token=class_token,
@@ -638,8 +808,9 @@ class VisionTransformerFlex(nn.Module):
         if fix_init:
             self.fix_init_weight()
 
-    def fix_init_weight(self):
-        def rescale(param, _layer_id):
+    def fix_init_weight(self) -> None:
+        """Apply initialization weight fix with layer-wise scaling."""
+        def rescale(param: torch.Tensor, _layer_id: int) -> None:
             param.div_(math.sqrt(2.0 * _layer_id))
 
         for layer_id, layer in enumerate(self.blocks):
@@ -647,6 +818,11 @@ class VisionTransformerFlex(nn.Module):
             rescale(layer.mlp.fc2.weight.data, layer_id + 1)
 
     def init_weights(self, mode: str = '') -> None:
+        """Initialize model weights according to specified scheme.
+
+        Args:
+            mode: Initialization mode ('jax', 'jax_nlhb', 'moco', or '')
+        """
         assert mode in ('jax', 'jax_nlhb', 'moco', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
         named_apply(get_init_weights_vit(mode, head_bias), self)
@@ -679,11 +855,24 @@ class VisionTransformerFlex(nn.Module):
 
     @torch.jit.ignore
    def no_weight_decay(self) -> Set:
+        """Get set of parameter names that should not have weight decay applied.
+
+        Returns:
+            Set of parameter names to skip during weight decay
+        """
         skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'}
         return skip_list
 
     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict:
+        """Get parameter group matcher for optimizer parameter grouping.
+
+        Args:
+            coarse: Whether to use coarse-grained grouping
+
+        Returns:
+            Dictionary mapping group names to regex patterns
+        """
         return dict(
             stem=r'^embeds',  # stem and embed
             blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
@@ -691,15 +880,31 @@ class VisionTransformerFlex(nn.Module):
 
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable: bool = True) -> None:
+        """Enable or disable gradient checkpointing for memory efficiency.
+
+        Args:
+            enable: Whether to enable gradient checkpointing
+        """
         self.grad_checkpointing = enable
         if hasattr(self.embeds, 'patch_embed') and hasattr(self.embeds.patch_embed, 'set_grad_checkpointing'):
             self.embeds.patch_embed.set_grad_checkpointing(enable)
 
     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
+        """Get the classification head module.
+
+        Returns:
+            Classification head module
+        """
         return self.head
 
-    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
+        """Reset the classification head with new number of classes and pooling.
+
+        Args:
+            num_classes: Number of classes for new classification head
+            global_pool: Optional new global pooling type
+        """
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
@@ -953,7 +1158,7 @@ def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
     return init_weights_vit_timm
 
 
-def checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict: Dict[str, Any], model: VisionTransformerFlex) -> Dict[str, Any]:
     """Handle state dict conversion from original ViT to the new version with combined embedding."""
     from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
@@ -1049,7 +1254,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
 default_cfgs = generate_default_cfgs({
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
-    'vit_naflex_so400m_patch16_map': _cfg(),
+
+    'vit_naflex_base_patch16_siglip': _cfg(),
+    'vit_naflex_so400m_patch16_siglip': _cfg(),
 
     # sbb model testijg
     'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
@@ -1073,7 +1280,7 @@ default_cfgs = generate_default_cfgs({
 })
 
 
-def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
+def _create_vision_transformer_flex(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
     model = build_model_with_cfg(
         VisionTransformerFlex, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
@@ -1083,19 +1290,19 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
 
 @register_model
-def vit_naflex_base_patch16_gap(pretrained=False, **kwargs):
+def vit_naflex_base_patch16_gap(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
-        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
+        global_pool='avg', reg_tokens=4, fc_norm=True, **kwargs)
     model = _create_vision_transformer_flex(
         'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
+def vit_naflex_base_patch16_map(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
@@ -1107,7 +1314,7 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
 
 @register_model
-def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
+def vit_naflex_so150m2_patch16_reg1_gap(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
 
     This model supports:
@@ -1117,31 +1324,43 @@ def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
     """
     model_args = dict(
         patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
-        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True)
+        qkv_bias=False, reg_tokens=1, global_pool='avg', fc_norm=True)
     model = _create_vision_transformer_flex(
         'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
+def vit_naflex_base_patch16(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12,
         global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
-    model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_vision_transformer_flex(
+        'vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vit_naflex_so400m_patch16_map(pretrained=False, **kwargs):
+def vit_naflex_base_patch16_siglip(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, act_layer='gelu_tanh', global_pool='map')
+    model = _create_vision_transformer_flex(
+        'vit_naflex_base_patch16_siglip', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_naflex_so400m_patch16_siglip(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
     """ViT-SO400M with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, init_values=1e-5,
-        global_pool='map', class_token=False, reg_tokens=1, act_layer='gelu_tanh')
+        patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        act_layer='gelu_tanh', global_pool='map')
     model = _create_vision_transformer_flex(
-        'vit_naflex_so400m_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_naflex_so400m_patch16_siglip', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
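The new variants are regular registered timm models; a usage sketch, assuming the standard tensor path accepts a plain [B, 3, H, W] batch as in other timm ViTs (no pretrained weights implied):

import timm
import torch

# One of the SigLIP-style encoders registered in this commit
model = timm.create_model('vit_naflex_base_patch16_siglip', pretrained=False, num_classes=0)
feats = model(torch.randn(1, 3, 256, 256))   # pooled features, head removed via num_classes=0
print(feats.shape)

# The kwargs introduced in this commit can be overridden at creation time
model_factorized = timm.create_model(
    'vit_naflex_base_patch16_gap',
    pretrained=False,
    pos_embed='factorized',
    pos_embed_ar_preserving=True,
)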