Add siglip2-compatible NaFlex encoders. Add support for factorized pos embeds and 'aspect preserving' mode to FlexEmbeds. Some more docstrings and typing.

Ross Wightman 2025-05-30 16:15:37 -07:00
parent b7ced7c40c
commit 72858c193c
2 changed files with 328 additions and 111 deletions


@@ -749,11 +749,9 @@ def patchify_image(
# Ensure the image is divisible by patch size
if pad and (h % ph != 0 or w % pw != 0):
new_h = math.ceil(h / ph) * ph
new_w = math.ceil(w / pw) * pw
padded_img = torch.zeros(c, new_h, new_w, dtype=img.dtype)
padded_img[:, :h, :w] = img
img = padded_img
pad_h = (ph - h % ph) % ph # amount to add on bottom
pad_w = (pw - w % pw) % pw # amount to add on right
img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h))
c, h, w = img.shape
# Calculate number of patches in each dimension
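For reference, a standalone sketch of the new padding arithmetic (illustrative values, not part of the diff): a 3x50x70 image with 16x16 patches needs pad_h = (16 - 50 % 16) % 16 = 14 rows on the bottom and pad_w = (16 - 70 % 16) % 16 = 10 columns on the right, and F.pad avoids materializing a zero tensor by hand:

    import torch
    import torch.nn.functional as F

    img = torch.randn(3, 50, 70)            # C, H, W
    ph = pw = 16
    pad_h = (ph - img.shape[1] % ph) % ph   # 14, added on bottom only
    pad_w = (pw - img.shape[2] % pw) % pw   # 10, added on right only
    img = F.pad(img, (0, pad_w, 0, pad_h))  # pad spec: (left, right, top, bottom)
    assert img.shape == (3, 64, 80)         # divides evenly into a 4x5 patch grid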


@@ -24,7 +24,16 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert
from timm.layers import (
AttentionPoolLatent,
Mlp,
to_2tuple,
get_act_layer,
get_norm_layer,
LayerNorm,
LayerType,
_assert,
)
from timm.models._builder import build_model_with_cfg
from timm.models._features import feature_take_indices
from timm.models._features_fx import register_notrace_function, register_notrace_module
@@ -59,27 +68,46 @@ def batch_patchify(
@register_notrace_module
class FlexEmbeds(nn.Module):
""" Na(Flex) Embedding module for Vision Transformers
"""NaFlex Embedding module for Vision Transformers.
This module encapsulates the complete embedding process for Vision Transformers,
supporting both standard and NaFlex (NaViT + FlexiViT) functionality:
This module encapsulates the complete embedding process for Vision Transformers:
1. Patch embedding (via Conv2d or Linear)
2. Class and register token preparation
3. Position embedding addition
3. Position embedding addition with interpolation support
4. Pre-normalization (if requested)
5. Dropout application
Also supports NaFlex functionality (NaViT + FlexiViT):
NaFlex capabilities include:
- Variable aspect ratio and resolution via patch coordinates
- Patch type indicators for handling padding tokens in attention
- Flexible position embedding interpolation for arbitrary grid sizes
Note: Only supports non-overlapping position and register tokens
(i.e., position embeddings do not include class/register tokens)
- Support for factorized position embeddings
The patch embedding can be one of two types:
1. Conv2d-based (default): For standard image inputs [B, C, H, W]
2. Linear-based: For pre-patchified inputs [B, N, P*P*C]
- Conv2d-based (default): For standard image inputs [B, C, H, W]
- Linear-based: For pre-patchified inputs [B, N, P*P*C]
Args:
patch_size: Size of patches for patch embedding
in_chans: Number of input image channels
embed_dim: Dimensionality of patch embedding
embed_layer: Type of embedding layer ('conv' or 'linear')
input_norm_layer: Normalization layer applied to input (linear mode only)
proj_norm_layer: Normalization layer applied after projection
final_norm_layer: Final normalization layer before output
pos_embed: Type of position embedding ('learned', 'factorized', 'rope', 'none')
pos_drop_rate: Dropout rate for position embeddings
patch_drop_rate: Dropout rate for patch tokens
class_token: Whether to include a class token
reg_tokens: Number of register tokens to include
bias: Whether to use bias in projection layers
dynamic_img_pad: Whether to enable dynamic padding for variable resolution
pos_embed_grid_size: Grid size for position embedding initialization
pos_embed_interp_mode: Interpolation mode for position embedding resizing
pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation
default_img_size: Default image size for position embedding grid calculation
"""
def __init__(
@@ -100,12 +128,14 @@ class FlexEmbeds(nn.Module):
dynamic_img_pad: bool = False,
pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14),
pos_embed_interp_mode: str = 'bicubic',
pos_embed_ar_preserving: bool = False,
default_img_size: Optional[Union[int, Tuple[int, int]]] = None,
):
super().__init__()
self.has_class_token = class_token
self.num_reg_tokens = reg_tokens
self.pos_embed_interp_mode = pos_embed_interp_mode
self.pos_embed_ar_preserving = pos_embed_ar_preserving
self.patch_size = to_2tuple(patch_size)
self.in_chans = in_chans
self.embed_dim = embed_dim
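As a usage sketch of the constructor shown above (a hypothetical configuration exercising the two new options):

    import torch

    embeds = FlexEmbeds(
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        pos_embed='factorized',           # new: separate y/x tables instead of one 2D grid
        pos_embed_grid_size=(14, 14),
        pos_embed_interp_mode='bicubic',
        pos_embed_ar_preserving=True,     # new: one scale factor for both axes
        reg_tokens=4,
    )
    tokens = embeds(torch.randn(2, 3, 224, 224))  # [2, 4 + 196, 768]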
@@ -153,22 +183,28 @@ class FlexEmbeds(nn.Module):
self.norm_proj = proj_norm_layer(embed_dim) if proj_norm_layer else nn.Identity()
# Create position embedding if needed - only for patches, never for prefix tokens
if pos_embed in ('factorized', 'learned') and self.pos_embed_grid_size is None:
raise ValueError(
"Cannot initialize position embeddings without grid_size."
"Please provide img_size or pos_embed_grid_size.")
self.pos_embed: Optional[torch.Tensor] = None
self.pos_embed_y: Optional[torch.Tensor] = None
self.pos_embed_x: Optional[torch.Tensor] = None
if not pos_embed or pos_embed == 'none':
self.pos_embed = None
self.pos_embed_type = 'none'
elif pos_embed == 'rope':
self.pos_embed = None
self.pos_embed_type = 'rope'
# Rotary embeddings will be computed on-the-fly in the forward pass
elif pos_embed == 'factorized':
h, w = self.pos_embed_grid_size
self.pos_embed_type = 'factorized'
self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim) * .02)
self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim) * .02)
else:
# Store position embedding in (1, H, W, dim) format for easier resizing
if self.pos_embed_grid_size is not None:
h, w = self.pos_embed_grid_size
self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02)
self.pos_embed_type = 'learned'
else:
raise ValueError("Cannot initialize position embeddings without grid_size. "
"Please provide img_size or pos_embed_grid_size")
h, w = self.pos_embed_grid_size
self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02)
self.pos_embed_type = 'learned'
# Pre-normalization layer (separate from the main norm layer)
self.norm_final = final_norm_layer(embed_dim) if final_norm_layer is not None else nn.Identity()
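The factorized branch stores two 1-D tables instead of one full grid, (h + w) * dim parameters rather than h * w * dim; they are combined later by a broadcast add, roughly:

    import torch

    h, w, dim = 14, 14, 768
    pos_embed_y = torch.randn(1, h, dim) * .02   # per-row table
    pos_embed_x = torch.randn(1, w, dim) * .02   # per-column table
    pos = pos_embed_y.unsqueeze(2) + pos_embed_x.unsqueeze(1)  # (1, h, w, dim)
    pos = pos.flatten(1, 2)                      # (1, h * w, dim), row-major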
@@ -184,35 +220,57 @@ class FlexEmbeds(nn.Module):
else:
self.patch_drop = nn.Identity()
def feature_info(self, location):
"""Feature info utility method for feature extraction."""
def feature_info(self, location) -> Dict[str, Any]:
"""Get feature information for feature extraction.
Args:
location: Feature extraction location identifier
Returns:
Dictionary containing feature channel count and reduction factor
"""
return dict(num_chs=self.embed_dim, reduction=self.patch_size)
def feat_ratio(self, as_scalar=True):
"""Return the feature reduction ratio (stride)."""
def feat_ratio(self, as_scalar: bool = True) -> Union[int, Tuple[int, int]]:
"""Get the feature reduction ratio (stride) of the patch embedding.
Args:
as_scalar: Whether to return the maximum dimension as a scalar
Returns:
Feature reduction ratio as scalar or tuple
"""
if as_scalar:
return max(self.patch_size)
else:
return self.patch_size
def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
""" Get grid (feature) size for given image size taking account of dynamic padding.
"""Calculate grid (feature) size for given image size.
Takes into account dynamic padding when enabled.
Args:
img_size: Input image size as (height, width)
Returns:
Grid size as (grid_height, grid_width)
"""
if self.dynamic_img_pad:
return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
else:
return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
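Concretely, a 250x200 input at patch size 16 yields a 16x13 grid when dynamic padding rounds up, versus 15x12 with the default floor division:

    import math

    grid_padded = (math.ceil(250 / 16), math.ceil(200 / 16))  # (16, 13)
    grid_cropped = (250 // 16, 200 // 16)                     # (15, 12)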
def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
"""Forward pass for combined embedding
def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Forward pass for patch embedding with position encoding.
Args:
x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C]
patch_coord: Optional patch coordinates [B, N, 2] for NaFlex
x: Input tensor [B, C, H, W] for conv mode or [B, N, P*P*C] for pre-patchified linear mode
patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode
Returns:
Embedded tensor with position encoding and class/register tokens applied
If patch_type is provided, also returns attention mask
Embedded tensor with position encoding and class/register tokens applied.
Shape: [B, num_prefix_tokens + N, embed_dim]
"""
# Apply patch embedding
naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
@@ -259,6 +317,9 @@ class FlexEmbeds(nn.Module):
else:
assert grid_size is not None
self._apply_learned_pos_embed(x, grid_size=grid_size)
elif self.pos_embed_type == 'factorized':
if naflex_grid_sizes is not None:
self._apply_factorized_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
elif self.pos_embed_type == 'rope':
assert False, "ROPE not yet implemented"
@@ -285,22 +346,38 @@ class FlexEmbeds(nn.Module):
self,
x: torch.Tensor,
naflex_grid_sizes: List[Tuple[int, int]],
):
) -> None:
"""Apply learned position embeddings to NaFlex batch in-place.
Interpolates learned position embeddings for each sample in the batch
based on their individual grid sizes.
Args:
x: Input tensor to add position embeddings to
naflex_grid_sizes: List of (height, width) grid sizes for each batch element
"""
# Handle each batch element separately with its own grid size
orig_h, orig_w = self.pos_embed.shape[1:3]
pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float() # B,C,H,W
def _interp(_size):
if (_size[0] == orig_h) and (_size[1] == orig_w):
def _interp2d(size):
"""
Return a flattened positional-embedding grid at an arbitrary spatial resolution.
Converts the learned 2-D table stored in NCHW format (pos_embed_nchw) into
a (1, H*W, C) sequence that matches the requested size.
"""
if (size[0] == orig_h) and (size[1] == orig_w):
pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
else:
_interp_size = to_2tuple(max(size)) if self.pos_embed_ar_preserving else size
pos_embed_flat = F.interpolate(
pos_embed_nchw,
size=_size,
size=_interp_size,
mode=self.pos_embed_interp_mode,
align_corners=False,
antialias=True,
).flatten(2).transpose(1, 2)
)[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2)
return pos_embed_flat.to(dtype=x.dtype)
# FIXME leaving alternative code commented here for now for comparisons
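A standalone sketch of the aspect-preserving path in _interp2d above: the table is interpolated to a square max(h, w) grid so both axes share a single scale factor, then cropped back to the requested size:

    import torch
    import torch.nn.functional as F

    pos_nchw = torch.randn(1, 768, 14, 14)   # learned table in NCHW layout
    target_h, target_w = 12, 20
    side = max(target_h, target_w)           # 20
    out = F.interpolate(
        pos_nchw, size=(side, side), mode='bicubic',
        align_corners=False, antialias=True,
    )[:, :, :target_h, :target_w]            # crop back to (12, 20)
    flat = out.flatten(2).transpose(1, 2)    # (1, 240, 768)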
@@ -315,7 +392,7 @@ class FlexEmbeds(nn.Module):
# seq_len = min(x.shape[1], pos_embed_flat.shape[1])
# x[i, :seq_len] += pos_embed_flat[0, :seq_len]
# Determine unique grid sizes
# Determine unique grid sizes to avoid duplicate interpolation
size_to_indices: Dict[Tuple[int, int], List[int]] = {}
for bi, k in enumerate(naflex_grid_sizes):
# k = h << 16 | w # FIXME can get jit compat with this
@@ -324,7 +401,7 @@ class FlexEmbeds(nn.Module):
for k, batch_indices in size_to_indices.items():
# h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
# Interpolate only once for this (h, w)
pos_embed_flat = _interp(k)
pos_embed_flat = _interp2d(k)
seq_len = min(x.shape[1], pos_embed_flat.shape[1])
x[:, :seq_len].index_add_(
0,
@@ -336,7 +413,15 @@ class FlexEmbeds(nn.Module):
self,
x: torch.Tensor,
grid_size: List[int],
):
) -> None:
"""Apply learned position embeddings to standard batch in-place.
Interpolates learned position embeddings to match the specified grid size.
Args:
x: Input tensor to add position embeddings to
grid_size: Target grid size as [height, width]
"""
orig_h, orig_w = self.pos_embed.shape[1:3]
if grid_size[0] == orig_h and grid_size[1] == orig_w:
# No resize needed, just flatten
@@ -354,6 +439,63 @@ class FlexEmbeds(nn.Module):
x.add_(pos_embed_flat)
def _apply_factorized_naflex_pos_embed(
self,
x: torch.Tensor,
naflex_grid_sizes: List[Tuple[int, int]],
) -> None:
"""Apply factorized position embeddings to NaFlex batch in-place.
Uses separate Y and X position embedding tables that are interpolated
and combined for each sample's grid size.
Args:
x: Input tensor to add position embeddings to
naflex_grid_sizes: List of (height, width) grid sizes for each batch element
"""
assert len(naflex_grid_sizes) == x.size(0) # one (H,W) per sample
# Handle each batch element separately with its own grid size
orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1]
# bucket samples that share the same (H,W) so we build each grid once
size_to_indices: Dict[Tuple[int, int], List[int]] = {}
for bi, k in enumerate(naflex_grid_sizes):
size_to_indices.setdefault(k, []).append(bi)
def _interp1d(table: torch.Tensor, length: int) -> torch.Tensor:
"""
Resample a 1-D positional-embedding table to specified length
and return it in (1, L, C) layout, dtype matching x.
"""
return F.interpolate(
table.permute(0, 2, 1).float(), # (1,C,L) → (1,C,L_out)
size=length,
mode='linear',
align_corners=False,
).permute(0, 2, 1).to(dtype=x.dtype) # → (1,L_out,C)
for k, batch_indices in size_to_indices.items():
target_h, target_w = k
if self.pos_embed_ar_preserving:
len_y = len_x = max(target_h, target_w)
else:
len_y, len_x = target_h, target_w
pe_y = _interp1d(self.pos_embed_y, len_y)[:, :target_h] # (1,H,C)
pe_x = _interp1d(self.pos_embed_x, len_x)[:, :target_w] # (1,W,C)
# Broadcast, add and flatten to sequence layout (row major)
pos = pe_y.unsqueeze(2) + pe_x.unsqueeze(1) # (1,H,W,C)
pos = pos.flatten(1, 2)
seq_len = min(x.shape[1], pos.shape[1])
x[:, :seq_len].index_add_(
0,
torch.as_tensor(batch_indices, device=x.device),
pos[:, :seq_len].expand(len(batch_indices), -1, -1)
)
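Both NaFlex paths share the same bucketing plus index_add_ pattern; reduced to a standalone sketch (shapes illustrative), each unique (H, W) is built once and scatter-added to every batch row with that grid size:

    import torch

    x = torch.zeros(4, 8, 16)                      # B, N, C after patch projection
    grid_sizes = [(2, 4), (2, 2), (2, 4), (2, 3)]  # per-sample (H, W)
    size_to_indices = {}
    for bi, k in enumerate(grid_sizes):
        size_to_indices.setdefault(k, []).append(bi)
    for (h, w), idxs in size_to_indices.items():
        pos = torch.randn(1, h * w, 16)            # stand-in for the interpolated embed
        seq_len = min(x.shape[1], pos.shape[1])
        x[:, :seq_len].index_add_(
            0,
            torch.as_tensor(idxs),
            pos[:, :seq_len].expand(len(idxs), -1, -1),
        )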
@register_notrace_function
def create_attention_mask(
@@ -430,7 +572,21 @@ def global_pool_naflex(
patch_valid: Optional[torch.Tensor] = None,
pool_type: str = 'token',
num_prefix_tokens: int = 1,
):
) -> torch.Tensor:
"""Global pooling with NaFlex support for masked tokens.
Applies global pooling while respecting patch validity masks to exclude
padding tokens from pooling operations.
Args:
x: Input tensor with shape [B, N, C]
patch_valid: Optional validity mask for patches [B, N-num_prefix_tokens]
pool_type: Type of pooling ('token', 'avg', 'avgmax', 'max')
num_prefix_tokens: Number of prefix tokens (class/register) to exclude from masking
Returns:
Pooled tensor with shape [B, C]
"""
if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
# Fall back to standard pooling
x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens)
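The masked branches that follow (elided from this hunk) must exclude padding tokens from the pooling statistics; masked average pooling reduces to something like this sketch, assuming a boolean patch_valid of shape [B, N]:

    import torch

    x = torch.randn(2, 5, 8)                       # B, N, C with prefix tokens removed
    patch_valid = torch.tensor([[1, 1, 1, 0, 0],
                                [1, 1, 1, 1, 1]], dtype=torch.bool)
    mask = patch_valid.unsqueeze(-1).to(x.dtype)   # B, N, 1
    pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)  # B, C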
@@ -472,12 +628,16 @@ def global_pool_naflex(
class VisionTransformerFlex(nn.Module):
""" Vision Transformer (Na)Flex
"""Vision Transformer with NaFlex support for flexible input handling.
A flexible implementation of Vision Transformer with:
1. Encapsulated embedding and position encoding in a single module
2. Support for linear patch embedding on pre-patchified inputs
3. Support for variable sequence length / aspect ratio images (NaFlex)
A flexible implementation of Vision Transformer that supports:
- Standard image classification with various pooling strategies
- NaFlex functionality for variable aspect ratios and resolutions
- Linear patch embedding for pre-patchified inputs
- Multiple position embedding strategies (learned, factorized, rope)
- Comprehensive attention masking for efficient batch processing
- Encapsulated embedding and position encoding in FlexEmbeds module
- Compatible with standard ViT checkpoints through checkpoint filtering
"""
def __init__(
@@ -494,9 +654,10 @@ class VisionTransformerFlex(nn.Module):
init_values: Optional[float] = None,
class_token: bool = False,
reg_tokens: int = 0,
pos_embed: str = 'learn',
pos_embed: str = 'learned',
pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16),
pos_embed_interp_mode: str = 'bicubic',
pos_embed_ar_preserving: bool = False,
default_img_size: Union[int, Tuple[int, int]] = 256,
dynamic_img_pad: bool = False,
pre_norm: bool = False,
@@ -519,44 +680,52 @@ class VisionTransformerFlex(nn.Module):
block_fn: Type[nn.Module] = Block,
mlp_layer: Type[nn.Module] = Mlp,
) -> None:
"""
"""Initialize VisionTransformerFlex model.
Args:
patch_size: Patch size.
in_chans: Number of image input channels.
embed_dim: Transformer embedding dimension.
depth: Depth of transformer.
num_heads: Number of attention heads.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: Enable bias for qkv projections if True.
init_values: Layer-scale init values (layer-scale enabled if not None).
class_token: Use class token.
reg_tokens: Number of register tokens.
pos_embed: Type of position embedding.
pos_embed_grid_size: Size of position embedding grid.
pos_embed_interp_mode: Interpolation mode for position embedding.
default_img_size: Input image size.
pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
num_classes: Number of classes for classification head.
global_pool: Type of global pooling for final sequence (default: 'token').
drop_rate: Head dropout rate.
pos_drop_rate: Position embedding dropout rate.
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
weight_init: Weight initialization scheme.
fix_init: Apply weight initialization fix (scaling w/ layer index).
embed_layer_type: Patch embedding implementation (e.g., 'linear', 'conv').
embed_norm_layer: Normalization layer to use / override in patch embed module.
norm_layer: Normalization layer.
act_layer: MLP activation layer.
block_fn: Transformer block layer.
patch_size: Size of patches for patch embedding
in_chans: Number of input image channels
embed_dim: Transformer embedding dimension
depth: Number of transformer blocks
num_heads: Number of attention heads
mlp_ratio: Ratio of MLP hidden dimension to embedding dimension
qkv_bias: Whether to use bias in query, key, value projections
qk_norm: Whether to normalize query and key projections
proj_bias: Whether to use bias in linear projections
init_values: Layer-scale init values (layer-scale enabled if not None)
class_token: Whether to include a class token
reg_tokens: Number of register tokens to include
pos_embed: Type of position embedding
pos_embed_grid_size: Grid size for position embedding initialization
pos_embed_interp_mode: Interpolation mode for position embedding resizing
pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation
default_img_size: Default image size for position embedding grid calculation
dynamic_img_pad: Whether to enable dynamic padding for variable resolution
pre_norm: Whether to apply normalization before attention/MLP layers
final_norm: Whether to apply final normalization before classifier
fc_norm: Whether to normalize features before final classifier
num_classes: Number of classification classes
global_pool: Type of global pooling for final sequence
drop_rate: Dropout rate for classifier
pos_drop_rate: Dropout rate for position embeddings
patch_drop_rate: Dropout rate for patch tokens
proj_drop_rate: Dropout rate for linear projections
attn_drop_rate: Dropout rate for attention weights
drop_path_rate: Stochastic depth drop rate
weight_init: Weight initialization scheme
fix_init: Apply weight initialization fix (scaling w/ layer index)
embed_layer_type: Type of embedding layer ('conv' or 'linear')
embed_norm_layer: Normalization layer for embeddings
norm_layer: Normalization layer for transformer blocks
act_layer: Activation layer for MLP blocks
block_fn: Transformer block implementation class
mlp_layer: MLP implementation class
"""
super().__init__()
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
assert class_token or global_pool != 'token'
assert pos_embed in ('', 'none', 'learn')
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
assert pos_embed in ('', 'none', 'learned', 'factorized')
norm_layer = get_norm_layer(norm_layer) or LayerNorm
embed_norm_layer = get_norm_layer(embed_norm_layer)
act_layer = get_act_layer(act_layer) or nn.GELU
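A construction sketch exercising the renamed and new options ('learn' is no longer accepted; 'learned' and 'factorized' are the valid names per the assert above):

    model = VisionTransformerFlex(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        pos_embed='factorized',
        pos_embed_ar_preserving=True,
        pos_embed_grid_size=(16, 16),
        global_pool='avg',
        reg_tokens=4,
        fc_norm=True,
    )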
@@ -581,6 +750,7 @@ class VisionTransformerFlex(nn.Module):
pos_embed=pos_embed,
pos_embed_grid_size=pos_embed_grid_size,
pos_embed_interp_mode=pos_embed_interp_mode,
pos_embed_ar_preserving=pos_embed_ar_preserving,
pos_drop_rate=pos_drop_rate,
patch_drop_rate=patch_drop_rate,
class_token=class_token,
@@ -638,8 +808,9 @@ class VisionTransformerFlex(nn.Module):
if fix_init:
self.fix_init_weight()
def fix_init_weight(self):
def rescale(param, _layer_id):
def fix_init_weight(self) -> None:
"""Apply initialization weight fix with layer-wise scaling."""
def rescale(param: torch.Tensor, _layer_id: int) -> None:
param.div_(math.sqrt(2.0 * _layer_id))
for layer_id, layer in enumerate(self.blocks):
@@ -647,6 +818,11 @@ class VisionTransformerFlex(nn.Module):
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def init_weights(self, mode: str = '') -> None:
"""Initialize model weights according to specified scheme.
Args:
mode: Initialization mode ('jax', 'jax_nlhb', 'moco', or '')
"""
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
named_apply(get_init_weights_vit(mode, head_bias), self)
@@ -679,11 +855,24 @@ class VisionTransformerFlex(nn.Module):
@torch.jit.ignore
def no_weight_decay(self) -> Set:
"""Get set of parameter names that should not have weight decay applied.
Returns:
Set of parameter names to skip during weight decay
"""
skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'}
return skip_list
@torch.jit.ignore
def group_matcher(self, coarse: bool = False) -> Dict:
"""Get parameter group matcher for optimizer parameter grouping.
Args:
coarse: Whether to use coarse-grained grouping
Returns:
Dictionary mapping group names to regex patterns
"""
return dict(
stem=r'^embeds', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
@@ -691,15 +880,31 @@ class VisionTransformerFlex(nn.Module):
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
"""Enable or disable gradient checkpointing for memory efficiency.
Args:
enable: Whether to enable gradient checkpointing
"""
self.grad_checkpointing = enable
if hasattr(self.embeds, 'patch_embed') and hasattr(self.embeds.patch_embed, 'set_grad_checkpointing'):
self.embeds.patch_embed.set_grad_checkpointing(enable)
@torch.jit.ignore
def get_classifier(self) -> nn.Module:
"""Get the classification head module.
Returns:
Classification head module
"""
return self.head
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
"""Reset the classification head with new number of classes and pooling.
Args:
num_classes: Number of classes for new classification head
global_pool: Optional new global pooling type
"""
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
@@ -762,7 +967,7 @@ class VisionTransformerFlex(nn.Module):
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x, attn_mask=mask)
if i in take_indices:
@@ -785,7 +990,7 @@ class VisionTransformerFlex(nn.Module):
H, W = self.embeds.dynamic_feat_size((height, width))
else:
H, W = grid_size
intermediates = [y.reshape(y.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()
for y in intermediates]
# For dictionary output
@@ -795,14 +1000,14 @@ class VisionTransformerFlex(nn.Module):
result_dict['image_intermediates'] = intermediates
if prefix_tokens is not None and return_prefix_tokens:
result_dict['image_intermediates_prefix'] = prefix_tokens
# Only include features if not intermediates_only
if not intermediates_only:
x_final = self.norm(x)
result_dict['image_features'] = x_final
return result_dict
# For non-dictionary output, maintain the original behavior
if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
# return_prefix not supported in torchscript due to poor type handling
@@ -832,7 +1037,7 @@ class VisionTransformerFlex(nn.Module):
# Pass through embedding module with patch coordinate/type support
x = self.embeds(x, patch_coord=patch_coord)
# Apply transformer blocks with masked attention if mask provided
if attn_mask is not None:
# We need to apply blocks one by one with mask
@@ -842,7 +1047,7 @@ class VisionTransformerFlex(nn.Module):
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
x = self.norm(x)
return x
@@ -862,7 +1067,7 @@ class VisionTransformerFlex(nn.Module):
)
x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
return x
pool_type = self.global_pool if pool_type is None else pool_type
x = global_pool_naflex(
@@ -891,12 +1096,12 @@ class VisionTransformerFlex(nn.Module):
patch_valid: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with optional NaFlex support.
Args:
x: Input tensor [B, C, H, W] or pre-patchified tensor [B, N, P*P*C] or NaFlex dict
patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode
patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
Returns:
Model output tensor
"""
@@ -944,7 +1149,7 @@ class VisionTransformerFlex(nn.Module):
def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
"""Function imported from vision_transformer.py to maintain compatibility"""
from .vision_transformer import init_weights_vit_jax, init_weights_vit_moco, init_weights_vit_timm
if 'jax' in mode:
return partial(init_weights_vit_jax, head_bias=head_bias)
elif 'moco' in mode:
@@ -953,7 +1158,7 @@ def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
return init_weights_vit_timm
def checkpoint_filter_fn(state_dict, model):
def checkpoint_filter_fn(state_dict: Dict[str, Any], model: VisionTransformerFlex) -> Dict[str, Any]:
"""Handle state dict conversion from original ViT to the new version with combined embedding."""
from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
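The core of the conversion as a rough sketch: token and position-embedding keys move under the new embeds. module (the targets match the no_weight_decay names above), and pos_embed additionally needs reshaping to the (1, H, W, dim) storage format; the real filter handles more keys and cases:

    # illustrative subset only, not the full filter
    remap = {
        'pos_embed': 'embeds.pos_embed',   # also reshaped (1, N, C) -> (1, H, W, C)
        'cls_token': 'embeds.cls_token',
        'reg_token': 'embeds.reg_token',
    }
    state_dict = {remap.get(k, k): v for k, v in state_dict.items()}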
@@ -1049,7 +1254,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
default_cfgs = generate_default_cfgs({
'vit_naflex_base_patch16_gap': _cfg(),
'vit_naflex_base_patch16_map': _cfg(),
'vit_naflex_so400m_patch16_map': _cfg(),
'vit_naflex_base_patch16_siglip': _cfg(),
'vit_naflex_so400m_patch16_siglip': _cfg(),
# sbb model testing
'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
@@ -1073,7 +1280,7 @@ default_cfgs = generate_default_cfgs({
})
def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
def _create_vision_transformer_flex(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
model = build_model_with_cfg(
VisionTransformerFlex, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
@@ -1083,19 +1290,19 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
@register_model
def vit_naflex_base_patch16_gap(pretrained=False, **kwargs):
def vit_naflex_base_patch16_gap(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
"""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
global_pool='avg', reg_tokens=4, fc_norm=True, **kwargs)
model = _create_vision_transformer_flex(
'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
def vit_naflex_base_patch16_map(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
"""
model_args = dict(
@@ -1107,7 +1314,7 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
@register_model
def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
def vit_naflex_so150m2_patch16_reg1_gap(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
This model supports:
@@ -1117,31 +1324,43 @@ def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
"""
model_args = dict(
patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True)
qkv_bias=False, reg_tokens=1, global_pool='avg', fc_norm=True)
model = _create_vision_transformer_flex(
'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
def vit_naflex_base_patch16(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12,
global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
model = _create_vision_transformer_flex(
'vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_naflex_so400m_patch16_map(pretrained=False, **kwargs):
def vit_naflex_base_patch16_siglip(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
"""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, act_layer='gelu_tanh', global_pool='map')
model = _create_vision_transformer_flex(
'vit_naflex_base_patch16_siglip', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_naflex_so400m_patch16_siglip(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
"""ViT-SO400M with NaFlex functionality for variable aspect ratios and resolutions.
"""
model_args = dict(
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, init_values=1e-5,
global_pool='map', class_token=False, reg_tokens=1, act_layer='gelu_tanh')
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
act_layer='gelu_tanh', global_pool='map')
model = _create_vision_transformer_flex(
'vit_naflex_so400m_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
'vit_naflex_so400m_patch16_siglip', pretrained=pretrained, **dict(model_args, **kwargs))
return model
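The new registrations are reachable through the standard timm factory, e.g.:

    import timm

    model = timm.create_model('vit_naflex_so400m_patch16_siglip', pretrained=False)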