""" Vision Transformer (New) An improved version of the Vision Transformer with: 1. Encapsulated embedding and position encoding in a single module 2. Support for linear patch embedding on pre-patchified inputs 3. Support for NaFlex functionality (NaViT + FlexiViT) Based on: - Original Vision Transformer: https://arxiv.org/abs/2010.11929 - FlexiViT: https://arxiv.org/abs/2212.08013 - NaViT: https://arxiv.org/abs/2307.06304 Copyright 2025 """ import logging import math from collections import OrderedDict from functools import partial from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Final, Any, Literal import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert from timm.models._builder import build_model_with_cfg from timm.models._features import feature_take_indices from timm.models._features_fx import register_notrace_function, register_notrace_module from timm.models._registry import register_model, generate_default_cfgs from timm.models._manipulate import checkpoint_seq, named_apply from .vision_transformer import Block, global_pool_nlc _logger = logging.getLogger(__name__) def batch_patchify( x: torch.Tensor, patch_size: Tuple[int, int], pad: bool = True, ) -> Tuple[torch.Tensor, Tuple[int, int]]: B, C, H, W = x.shape ph, pw = patch_size # Ensure the image is divisible by patch size if pad and (H % ph != 0 or W % pw != 0): pad_h = (ph - H % ph) % ph pad_w = (pw - W % pw) % pw x = F.pad(x, (0, pad_w, 0, pad_h)) nh, nw = H // ph, W // pw patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C) # FIXME confirm we want 'channels last' in the patch channel layout, egg ph, ph, C instead of C, ph, hw return patches, (nh, nw) @register_notrace_module class FlexEmbeds(nn.Module): """ Na(Flex) Embedding module for Vision Transformers This module encapsulates the complete embedding process for Vision Transformers: 1. Patch embedding (via Conv2d or Linear) 2. Class and register token preparation 3. Position embedding addition 4. Pre-normalization (if requested) 5. Dropout application Also supports NaFlex functionality (NaViT + FlexiViT): - Variable aspect ratio and resolution via patch coordinates - Patch type indicators for handling padding tokens in attention - Flexible position embedding interpolation for arbitrary grid sizes Note: Only supports non-overlapping position and register tokens (i.e., position embeddings do not include class/register tokens) The patch embedding can be one of two types: 1. Conv2d-based (default): For standard image inputs [B, C, H, W] 2. Linear-based: For pre-patchified inputs [B, N, P*P*C] """ def __init__( self, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, embed_dim: int = 768, embed_layer: Optional[str] = None, # 'conv' or 'linear', default is 'linear' input_norm_layer: Optional[Type[nn.Module]] = None, proj_norm_layer: Optional[Type[nn.Module]] = None, final_norm_layer: Optional[Type[nn.Module]] = None, pos_embed: str = 'learned', pos_drop_rate: float = 0., patch_drop_rate: float = 0., class_token: bool = True, reg_tokens: int = 0, bias: bool = True, dynamic_img_pad: bool = False, pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14), pos_embed_interp_mode: str = 'bicubic', default_img_size: Union[int, Tuple[int, int]] = None, ): super().__init__() self.has_class_token = class_token self.num_reg_tokens = reg_tokens self.pos_embed_interp_mode = pos_embed_interp_mode self.patch_size = to_2tuple(patch_size) self.in_chans = in_chans self.embed_dim = embed_dim self.dynamic_img_pad = dynamic_img_pad # Calculate number of prefix tokens self.num_prefix_tokens = 1 if class_token else 0 self.num_prefix_tokens += reg_tokens # Create class and register tokens self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None # Calculate grid size and number of patches self.default_img_size: Optional[Tuple[int, int]] = None self.pos_embed_grid_size: Optional[ Tuple[int, int]] = None # Stores the grid size used for learned pos embed init if pos_embed_grid_size is None and default_img_size is not None: self.default_img_size = to_2tuple(default_img_size) self.pos_embed_grid_size = tuple([s // p for s, p in zip(self.default_img_size, self.patch_size)]) elif pos_embed_grid_size is not None: # Use provided pos_embed_grid_size for NaFlex mode self.pos_embed_grid_size = pos_embed_grid_size # Determine patch embedding type (linear or conv2d) if embed_layer == 'linear': # Create linear projection for pre-patchified inputs # Input dimension is patch_size^2 * in_chans patch_dim = self.patch_size[0] * self.patch_size[1] * in_chans self.norm_input = proj_norm_layer(patch_dim) if input_norm_layer else None self.proj = nn.Linear(patch_dim, embed_dim, bias=bias) self.flatten = False self.is_linear = True else: # Default to convolutional patch embedding for image inputs assert input_norm_layer is None self.norm_input = None self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias ) self.flatten = True self.is_linear = False # Create normalization layer after the projection self.norm_proj = proj_norm_layer(embed_dim) if proj_norm_layer else nn.Identity() # Create position embedding if needed - only for patches, never for prefix tokens if not pos_embed or pos_embed == 'none': self.pos_embed = None self.pos_embed_type = 'none' elif pos_embed == 'rope': self.pos_embed = None self.pos_embed_type = 'rope' # Rotary embeddings will be computed on-the-fly in the forward pass else: # Store position embedding in (1, H, W, dim) format for easier resizing if self.pos_embed_grid_size is not None: h, w = self.pos_embed_grid_size self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02) self.pos_embed_type = 'learned' else: raise ValueError("Cannot initialize position embeddings without grid_size. " "Please provide img_size or pos_embed_grid_size") # Pre-normalization layer (separate from the main norm layer) self.norm_final = final_norm_layer(embed_dim) if final_norm_layer is not None else nn.Identity() # Dropout layers self.pos_drop = nn.Dropout(p=pos_drop_rate) if patch_drop_rate > 0: from timm.layers.patch_dropout import PatchDropout self.patch_drop = PatchDropout( patch_drop_rate, num_prefix_tokens=self.num_prefix_tokens, ) else: self.patch_drop = nn.Identity() def feature_info(self, location): """Feature info utility method for feature extraction.""" return dict(num_chs=self.embed_dim, reduction=self.patch_size) def feat_ratio(self, as_scalar=True): """Return the feature reduction ratio (stride).""" if as_scalar: return max(self.patch_size) else: return self.patch_size def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: """ Get grid (feature) size for given image size taking account of dynamic padding. """ if self.dynamic_img_pad: return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1]) else: return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1] def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None): """Forward pass for combined embedding Args: x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C] patch_coord: Optional patch coordinates [B, N, 2] for NaFlex Returns: Embedded tensor with position encoding and class/register tokens applied If patch_type is provided, also returns attention mask """ # Apply patch embedding naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None grid_size: Optional[List[int]] = None B = x.shape[0] if self.is_linear: # Linear embedding path, works with NaFlex mode or standard 2D mode if patch_coord is not None: _assert(x.ndim == 3, 'Expecting patchified input with ndim == 3') # Pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches # Calculate the appropriate grid size from coords max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1 max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1 naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)] else: _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4') x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad) if self.norm_input is not None: x = self.norm_input(x) x = self.proj(x) else: assert x.ndim == 4, 'Convolutional input must be 4D' if self.dynamic_img_pad: H, W = x.shape[-2:] pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] x = F.pad(x, (0, pad_w, 0, pad_h)) x = self.proj(x) grid_size = x.shape[-2:] if self.flatten: x = x.flatten(2).transpose(1, 2) # NCHW -> NLC # Apply normalization after flattening x = self.norm_proj(x) if self.pos_embed_type == 'learned': if naflex_grid_sizes is not None: self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes) else: assert grid_size is not None self._apply_learned_pos_embed(x, grid_size=grid_size) elif self.pos_embed_type == 'rope': assert False, "ROPE not yet implemented" # Prepare and add class and register tokens to_cat = [] if self.cls_token is not None: to_cat.append(self.cls_token.expand(B, -1, -1)) if self.reg_token is not None: to_cat.append(self.reg_token.expand(B, -1, -1)) # Add tokens to the beginning if to_cat: x = torch.cat(to_cat + [x], dim=1) # Apply final pre-transformer normalization if specified x = self.norm_final(x) # Apply dropouts x = self.patch_drop(self.pos_drop(x)) return x #@torch.compiler.disable() def _apply_learned_naflex_pos_embed( self, x: torch.Tensor, naflex_grid_sizes: List[Tuple[int, int]], ): # Handle each batch element separately with its own grid size orig_h, orig_w = self.pos_embed.shape[1:3] pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float() # B,C,H,W def _interp(_size): if (_size[0] == orig_h) and (_size[1] == orig_w): pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1) else: pos_embed_flat = F.interpolate( pos_embed_nchw, size=_size, mode=self.pos_embed_interp_mode, align_corners=False, antialias=True, ).flatten(2).transpose(1, 2) return pos_embed_flat.to(dtype=x.dtype) # FIXME leaving alternative code commented here for now for comparisons # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {} # for i, s in enumerate(naflex_grid_sizes): # if s in pos_embed_cache: # pos_embed_flat = pos_embed_cache[s] # else: # pos_embed_flat = _interp(s) # pos_embed_cache[s] = pos_embed_flat # # seq_len = min(x.shape[1], pos_embed_flat.shape[1]) # x[i, :seq_len] += pos_embed_flat[0, :seq_len] # Determine unique grid sizes size_to_indices: Dict[Tuple[int, int], List[int]] = {} for bi, k in enumerate(naflex_grid_sizes): # k = h << 16 | w # FIXME can get jit compat with this size_to_indices.setdefault(k, []).append(bi) for k, batch_indices in size_to_indices.items(): # h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this # Interpolate only once for this (h, w) pos_embed_flat = _interp(k) seq_len = min(x.shape[1], pos_embed_flat.shape[1]) x[:, :seq_len].index_add_( 0, torch.as_tensor(batch_indices, device=x.device), pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1) ) def _apply_learned_pos_embed( self, x: torch.Tensor, grid_size: List[int], ): orig_h, orig_w = self.pos_embed.shape[1:3] if grid_size[0] == orig_h or grid_size[1] == orig_w: # No resize needed, just flatten pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1) else: # Resize if needed - directly using F.interpolate pos_embed_flat = F.interpolate( self.pos_embed.permute(0, 3, 1, 2).float(), # B,C,H,W size=grid_size, mode=self.pos_embed_interp_mode, align_corners=False, antialias=True, ).flatten(2).transpose(1, 2) pos_embed_flat = pos_embed_flat.to(dtype=x.dtype) x.add_(pos_embed_flat) @register_notrace_function def create_attention_mask( patch_valid: torch.Tensor, num_prefix_tokens: int = 0, symmetric: bool = True, q_len: Optional[int] = None, dtype: torch.dtype = torch.float32, ) -> Optional[torch.Tensor]: """Creates an attention mask from patch validity information. Supports two modes controlled by `symmetric`: 1. `symmetric=True` (default): Creates a symmetric mask of shape [B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if both token i and token j are valid. Suitable for standard self-attention. 2. `symmetric=False`: Creates a potentially non-square mask of shape [B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if the key/value token k is valid. Query token validity is not checked in the mask itself. Useful for cross-attention or specific self-attention implementations `q_len` can be specified. Used for NaFlex mode to handle variable token counts and padding tokens. Args: patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding. num_prefix_tokens: Number of prefix tokens (class token, register tokens) to prepend, which are always considered valid. symmetric: If True, create a symmetric mask. If False, create an expanded mask based only on key/value validity. q_len: Query sequence length override. Only used when `symmetric` is False. Defaults to the key/value sequence length (`kv_len`) if None. dtype: Dtype of the output attention mask (e.g., torch.float32). Returns: Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked). Shape is [B, 1, seq_len, seq_len] if symmetric=True, or [B, 1, q_len, kv_len] if symmetric=False. """ if patch_valid is None: return None patch_valid = patch_valid.bool() # Ensure boolean type B, N = patch_valid.shape kv_len = N # Initial key/value length is the number of patches # Prepend prefix tokens if any if num_prefix_tokens > 0: # Create prefix validity tensor on the same device/dtype base as patch_valid prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool) # Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N] patch_valid = torch.cat([prefix_valid, patch_valid], dim=1) kv_len += num_prefix_tokens # Update total key/value sequence length if symmetric: # Symmetric mask is True where BOTH query and key are valid mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1) mask_bool = mask_bool.unsqueeze(1) # Add head dimension: [B, 1, seq_len, seq_len] else: # Expanded mask q_len = q_len or kv_len mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len) # Create the float mask and apply masking using additive mask convention mask_float = torch.zeros_like(mask_bool, dtype=dtype) # Fill with negative infinity where mask_bool is False (masked positions) mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min) return mask_float @register_notrace_function def global_pool_naflex( x: torch.Tensor, patch_valid: Optional[torch.Tensor] = None, pool_type: str = 'token', num_prefix_tokens: int = 1, ): if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'): # Fall back to standard pooling x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens) return x # For NaFlex mode, we need to apply masked pooling to exclude padding tokens # Extract only the patch part of the mask (excluding prefix tokens) if num_prefix_tokens > 0: # Apply the mask to extract only valid tokens x = x[:, num_prefix_tokens:] # prefix tokens not included in pooling patch_valid_float = patch_valid.to(x.dtype) if pool_type == 'avg': # Compute masked average pooling, sum valid tokens and divide by count of valid tokens masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1) valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1) pooled = masked_sums / valid_counts return pooled elif pool_type == 'avgmax': # For avgmax, compute masked average and masked max masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1) valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1) masked_avg = masked_sums / valid_counts # For max pooling we set masked positions to large negative value masked_x = x.clone() masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min masked_max = masked_x.amax(dim=1) # Combine average and max return 0.5 * (masked_avg + masked_max) elif pool_type == 'max': # For max pooling we set masked positions to large negative value masked_x = x.clone() masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min return masked_x.amax(dim=1) else: assert False class VisionTransformerFlex(nn.Module): """ Vision Transformer (Na)Flex A flexible implementation of Vision Transformer with: 1. Encapsulated embedding and position encoding in a single module 2. Support for linear patch embedding on pre-patchified inputs 3. Support for variable sequence length / aspect ratio images (NaFlex) """ def __init__( self, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4., qkv_bias: bool = True, qk_norm: bool = False, proj_bias: bool = True, init_values: Optional[float] = None, class_token: bool = False, reg_tokens: int = 0, pos_embed: str = 'learn', pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16), pos_embed_interp_mode: str = 'bicubic', default_img_size: Union[int, Tuple[int, int]] = 256, dynamic_img_pad: bool = False, pre_norm: bool = False, final_norm: bool = True, fc_norm: Optional[bool] = None, num_classes: int = 1000, global_pool: str = 'map', drop_rate: float = 0., pos_drop_rate: float = 0., patch_drop_rate: float = 0., proj_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0., weight_init: str = '', fix_init: bool = True, embed_layer_type: str = 'linear', embed_norm_layer: Optional[LayerType] = None, norm_layer: Optional[LayerType] = None, act_layer: Optional[LayerType] = None, block_fn: Type[nn.Module] = Block, mlp_layer: Type[nn.Module] = Mlp, ) -> None: """ Args: patch_size: Patch size. in_chans: Number of image input channels. embed_dim: Transformer embedding dimension. depth: Depth of transformer. num_heads: Number of attention heads. mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: Enable bias for qkv projections if True. init_values: Layer-scale init values (layer-scale enabled if not None). class_token: Use class token. reg_tokens: Number of register tokens. pos_embed: Type of position embedding. pos_embed_grid_size: Size of position embedding grid. pos_embed_interp_mode: Interpolation mode for position embedding. default_img_size: Input image size. pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT). final_norm: Enable norm after transformer blocks, before head (standard in most ViT). fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'. num_classes: Number of classes for classification head. global_pool: Type of global pooling for final sequence (default: 'token'). drop_rate: Head dropout rate. pos_drop_rate: Position embedding dropout rate. attn_drop_rate: Attention dropout rate. drop_path_rate: Stochastic depth rate. weight_init: Weight initialization scheme. fix_init: Apply weight initialization fix (scaling w/ layer index). embed_layer_type: Patch embedding implementation (e.g., 'linear', 'conv'). embed_norm_layer: Normalization layer to use / override in patch embed module. norm_layer: Normalization layer. act_layer: MLP activation layer. block_fn: Transformer block layer. """ super().__init__() assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') assert class_token or global_pool != 'token' assert pos_embed in ('', 'none', 'learn') norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) embed_norm_layer = get_norm_layer(embed_norm_layer) act_layer = get_act_layer(act_layer) or nn.GELU self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models self.num_prefix_tokens = 1 if class_token else 0 self.num_prefix_tokens += reg_tokens self.num_reg_tokens = reg_tokens self.has_class_token = class_token self.grad_checkpointing = False # Initialize embedding module (includes patch, position embedding, and class/reg tokens) # VisionTransformerEmbeds is always used - handles both linear and conv embedding self.embeds = FlexEmbeds( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, embed_layer=embed_layer_type, proj_norm_layer=embed_norm_layer, final_norm_layer=norm_layer if pre_norm else None, pos_embed=pos_embed, pos_embed_grid_size=pos_embed_grid_size, pos_embed_interp_mode=pos_embed_interp_mode, pos_drop_rate=pos_drop_rate, patch_drop_rate=patch_drop_rate, class_token=class_token, reg_tokens=reg_tokens, default_img_size=default_img_size, dynamic_img_pad=dynamic_img_pad, bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) ) # Transformer blocks dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_norm=qk_norm, proj_bias=proj_bias, init_values=init_values, proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, mlp_layer=mlp_layer, ) for i in range(depth)]) # Feature info for downstream tasks patch_reduction = to_2tuple(patch_size) self.feature_info = [ dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_reduction) for i in range(depth)] self.norm = norm_layer(embed_dim) if final_norm and not fc_norm else nn.Identity() # Classifier Head if global_pool == 'map': self.attn_pool = AttentionPoolLatent( self.embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, act_layer=act_layer, ) else: self.attn_pool = None self.fc_norm = norm_layer(embed_dim) if final_norm and fc_norm else nn.Identity() self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() if weight_init != 'skip': self.init_weights(weight_init) if fix_init: self.fix_init_weight() def fix_init_weight(self): def rescale(param, _layer_id): param.div_(math.sqrt(2.0 * _layer_id)) for layer_id, layer in enumerate(self.blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) rescale(layer.mlp.fc2.weight.data, layer_id + 1) def init_weights(self, mode: str = '') -> None: assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. named_apply(get_init_weights_vit(mode, head_bias), self) @torch.jit.ignore() def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None: # Custom loading for the new model structure from .vision_transformer import _load_weights as _orig_load_weights def _load_weights_adapter(model, checkpoint_path, prefix=''): """Adapter function to handle the different model structure""" state_dict = torch.load(checkpoint_path, map_location='cpu') if isinstance(state_dict, dict) and 'state_dict' in state_dict: state_dict = state_dict['state_dict'] # Map original keys to new structure for k in list(state_dict.keys()): if k.startswith('cls_token'): state_dict['embeds.' + k] = state_dict.pop(k) elif k.startswith('reg_token'): state_dict['embeds.' + k] = state_dict.pop(k) elif k.startswith('pos_embed'): state_dict['embeds.' + k] = state_dict.pop(k) elif k.startswith('patch_embed'): state_dict['embeds.' + k[12:]] = state_dict.pop(k) return _orig_load_weights(model, state_dict, prefix) _load_weights_adapter(self, checkpoint_path, prefix) @torch.jit.ignore def no_weight_decay(self) -> Set: skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'} return skip_list @torch.jit.ignore def group_matcher(self, coarse: bool = False) -> Dict: return dict( stem=r'^embeds', # stem and embed blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] ) @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True) -> None: self.grad_checkpointing = enable if hasattr(self.embeds, 'patch_embed') and hasattr(self.embeds.patch_embed, 'set_grad_checkpointing'): self.embeds.patch_embed.set_grad_checkpointing(enable) @torch.jit.ignore def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') if global_pool == 'map' and self.attn_pool is None: assert False, "Cannot currently add attention pooling in reset_classifier()." elif global_pool != 'map' and self.attn_pool is not None: self.attn_pool = None # remove attention pooling self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_intermediates( self, x: torch.Tensor, indices: Optional[Union[int, List[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, output_dict: bool = False, patch_coord: Optional[torch.Tensor] = None, patch_valid: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]: """ Forward features that returns intermediates. Args: x: Input image tensor indices: Take last n blocks if int, all if None, select matching indices if sequence return_prefix_tokens: Return both prefix and spatial intermediate tokens norm: Apply norm layer to all intermediates stop_early: Stop iterating over blocks when last desired intermediate hit output_fmt: Shape of intermediate feature outputs intermediates_only: Only return intermediate features output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex mask: Optional attention mask Returns: A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing 'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix') """ # FIXME unfinished / untested assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' reshape = output_fmt == 'NCHW' intermediates = [] take_indices, max_index = feature_take_indices(len(self.blocks), indices) # Create attention mask if patch_type is provided and mask is not if mask is None and patch_valid is not None: mask = create_attention_mask(patch_valid, self.num_prefix_tokens, x.dtype) # Forward pass through embedding x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid) # Forward pass through blocks if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript blocks = self.blocks else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): x = blk(x, attn_mask=mask) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) # Process intermediates if self.num_prefix_tokens: # split prefix (e.g. class, distill) and spatial feature tokens prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] else: prefix_tokens = None if reshape: # reshape to BCHW output format grid_size = self.embeds.pos_embed_grid_size if hasattr(self.embeds, 'dynamic_feat_size') and len(x.shape) >= 4: _, height, width, _ = x.shape if len(x.shape) == 4 else (None, *x.shape[-3:-1], None) H, W = self.embeds.dynamic_feat_size((height, width)) else: H, W = grid_size intermediates = [y.reshape(y.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] # For dictionary output if output_dict: result_dict = {} # Intermediates are always included result_dict['image_intermediates'] = intermediates if prefix_tokens is not None and return_prefix_tokens: result_dict['image_intermediates_prefix'] = prefix_tokens # Only include features if not intermediates_only if not intermediates_only: x_final = self.norm(x) result_dict['image_features'] = x_final return result_dict # For non-dictionary output, maintain the original behavior if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: # return_prefix not support in torchscript due to poor type handling intermediates = list(zip(intermediates, prefix_tokens)) if intermediates_only: return intermediates x = self.norm(x) return x, intermediates def forward_features( self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None, patch_valid: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: if attn_mask is None and patch_valid is not None: attn_mask = create_attention_mask( patch_valid, num_prefix_tokens=self.num_prefix_tokens, dtype=x.dtype ) # Pass through embedding module with patch coordinate/type support x = self.embeds(x, patch_coord=patch_coord) # Apply transformer blocks with masked attention if mask provided if attn_mask is not None: # We need to apply blocks one by one with mask for blk in self.blocks: x = blk(x, attn_mask=attn_mask) elif self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) x = self.norm(x) return x def _pool( self, x: torch.Tensor, pool_type: Optional[str] = None, patch_valid: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.attn_pool is not None: # For attention pooling, we need to pass the mask for NaFlex models attn_mask = create_attention_mask( patch_valid, symmetric=False, q_len=1, dtype=x.dtype, ) x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask) return x pool_type = self.global_pool if pool_type is None else pool_type x = global_pool_naflex( x, patch_valid, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens, ) return x def forward_head( self, x: torch.Tensor, pre_logits: bool = False, patch_valid: Optional[torch.Tensor] = None, ) -> torch.Tensor: x = self._pool(x, patch_valid=patch_valid) x = self.fc_norm(x) x = self.head_drop(x) return x if pre_logits else self.head(x) def forward( self, x: Union[torch.Tensor, Dict[str, torch.Tensor]], patch_coord: Optional[torch.Tensor] = None, patch_valid: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with optional NaFlex support. Args: x: Input tensor [B, C, H, W] or pre-patchified tensor [B, N, P*P*C] or NaFlex dict patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex Returns: Model output tensor """ if isinstance(x, Dict): # Handle dictionary input from NaFlex collator patch_coord = x['patch_coord'] patch_valid = x['patch_valid'] patches = x['patches'] # DEBUG, reconstruct patches # for i in range(len(patches)): # patch = patches[i][patch_valid[i]] # h = (patch_coord[i, :, 0].max() + 1).item() # w = (patch_coord[i, :, 1].max() + 1).item() # patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3) # patch = patch.reshape(3, h*16, w*16) # from torchvision.utils import save_image # save_image(patch, f'patch_{i}.jpg', normalize=True) else: patches = x # Create attention mask if patch_type is provided attn_mask = create_attention_mask( patch_valid, num_prefix_tokens=self.num_prefix_tokens, dtype=patches.dtype, ) # Forward features with mask x = self.forward_features( patches, patch_coord=patch_coord, patch_valid=patch_valid, attn_mask=attn_mask, ) # Pass mask to forward_head for masked pooling x = self.forward_head( x, patch_valid=patch_valid, ) return x def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable: """Function imported from vision_transformer.py to maintain compatibility""" from .vision_transformer import init_weights_vit_jax, init_weights_vit_moco, init_weights_vit_timm if 'jax' in mode: return partial(init_weights_vit_jax, head_bias=head_bias) elif 'moco' in mode: return init_weights_vit_moco else: return init_weights_vit_timm def checkpoint_filter_fn(state_dict, model): """Handle state dict conversion from original ViT to the new version with combined embedding.""" from .vision_transformer import checkpoint_filter_fn as orig_filter_fn # Handle CombinedEmbed module pattern out_dict = {} for k, v in state_dict.items(): # Convert tokens and embeddings to combined_embed structure if k == 'pos_embed': # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C) if hasattr(model.embeds, 'pos_embed') and v.ndim == 3: num_cls_token = 0 num_reg_token = 0 if 'reg_token' in state_dict: num_reg_token = state_dict['reg_token'].shape[1] if 'cls_token' in state_dict: num_cls_token = state_dict['cls_token'].shape[1] num_prefix_tokens = num_cls_token + num_reg_token # Original format is (1, N, C), need to reshape to (1, H, W, C) num_patches = v.shape[1] num_patches_no_prefix = num_patches - num_prefix_tokens grid_size_no_prefix = math.sqrt(num_patches_no_prefix) grid_size = math.sqrt(num_patches) if (grid_size_no_prefix != grid_size and ( grid_size_no_prefix.is_integer() and not grid_size.is_integer())): # make a decision, did the pos_embed of the original include the prefix tokens? num_patches = num_patches_no_prefix cls_token_emb = v[:, 0:num_cls_token] if cls_token_emb.numel(): state_dict['cls_token'] += cls_token_emb reg_token_emb = v[:, num_cls_token:num_reg_token] if reg_token_emb.numel(): state_dict['reg_token'] += reg_token_emb v = v[:, num_prefix_tokens:] grid_size = grid_size_no_prefix grid_size = int(grid_size) # Check if it's a perfect square for a standard grid if grid_size * grid_size == num_patches: # Reshape from (1, N, C) to (1, H, W, C) v = v.reshape(1, grid_size, grid_size, v.shape[2]) else: # Not a square grid, we need to get the actual dimensions if hasattr(model.embeds.patch_embed, 'grid_size'): h, w = model.embeds.patch_embed.grid_size if h * w == num_patches: # We have the right dimensions v = v.reshape(1, h, w, v.shape[2]) else: # Dimensions don't match, use interpolation _logger.warning( f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. " f"Using default initialization and will resize in forward pass." ) # Keep v as is, the forward pass will handle resizing out_dict['embeds.pos_embed'] = v elif k == 'cls_token': out_dict['embeds.cls_token'] = v elif k == 'reg_token': out_dict['embeds.reg_token'] = v # Convert patch_embed.X to embeds.patch_embed.X elif k.startswith('patch_embed.'): suffix = k[12:] if suffix == 'proj.weight': # FIXME confirm patchify memory layout across use cases v = v.permute(0, 2, 3, 1).flatten(1) new_key = 'embeds.' + suffix out_dict[new_key] = v else: out_dict[k] = v return out_dict def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: return { 'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': None, 'crop_pct': 0.95, 'interpolation': 'bicubic', 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'embeds.proj', 'classifier': 'head', 'license': 'apache-2.0', **kwargs, } default_cfgs = generate_default_cfgs({ 'vit_naflex_base_patch16_gap': _cfg(), 'vit_naflex_base_patch16_map': _cfg(), 'vit_naflex_so400m_patch16_map': _cfg(), # sbb model testijg 'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg( hf_hub_id='timm/vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k', input_size=(3, 256, 256), crop_pct=0.95), 'vit_naflex_so150m2_patch16_reg1_gap.sbb_r256_e200_in12k_ft_in1k': _cfg( hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k', input_size=(3, 256, 256), crop_pct=1.0), 'vit_naflex_so150m2_patch16_reg1_gap.sbb_r384_e200_in12k_ft_in1k': _cfg( hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k', input_size=(3, 384, 384), crop_pct=1.0), 'vit_naflex_so150m2_patch16_reg1_gap.sbb_r448_e200_in12k_ft_in1k': _cfg( hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k', input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'), # traditional vit testing 'vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'), 'vit_naflex_base_patch8.augreg2_r224_in21k_ft_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'), }) def _create_vision_transformer_flex(variant, pretrained=False, **kwargs): model = build_model_with_cfg( VisionTransformerFlex, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, **kwargs, ) return model @register_model def vit_naflex_base_patch16_gap(pretrained=False, **kwargs): """ViT-New with NaFlex functionality for variable aspect ratios and resolutions. """ model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs) model = _create_vision_transformer_flex( 'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vit_naflex_base_patch16_map(pretrained=False, **kwargs): """ViT-New with NaFlex functionality for variable aspect ratios and resolutions. """ model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, global_pool='map', reg_tokens=1) model = _create_vision_transformer_flex( 'vit_naflex_base_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs): """ViT-New with NaFlex functionality for variable aspect ratios and resolutions. This model supports: 1. Variable aspect ratios and resolutions via patch coordinates 2. Position embedding interpolation for arbitrary grid sizes 3. Explicit patch coordinates and valid token masking """ model_args = dict( patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5, qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True) model = _create_vision_transformer_flex( 'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vit_naflex_base_patch16(pretrained: bool = False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14)) model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vit_naflex_so400m_patch16_map(pretrained=False, **kwargs): """ViT-SO400M with NaFlex functionality for variable aspect ratios and resolutions. """ model_args = dict( patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, init_values=1e-5, global_pool='map', class_token=False, reg_tokens=1, act_layer='gelu_tanh') model = _create_vision_transformer_flex( 'vit_naflex_so400m_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs)) return model