diff --git a/timm/data/naflex_transforms.py b/timm/data/naflex_transforms.py index fca0429d..d70255a8 100644 --- a/timm/data/naflex_transforms.py +++ b/timm/data/naflex_transforms.py @@ -749,11 +749,9 @@ def patchify_image( # Ensure the image is divisible by patch size if pad and (h % ph != 0 or w % pw != 0): - new_h = math.ceil(h / ph) * ph - new_w = math.ceil(w / pw) * pw - padded_img = torch.zeros(c, new_h, new_w, dtype=img.dtype) - padded_img[:, :h, :w] = img - img = padded_img + pad_h = (ph - h % ph) % ph # amount to add on bottom + pad_w = (pw - w % pw) % pw # amount to add on right + img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h)) c, h, w = img.shape # Calculate number of patches in each dimension diff --git a/timm/models/vision_transformer_flex.py b/timm/models/vision_transformer_flex.py index 3c52ce0b..393b7c0b 100644 --- a/timm/models/vision_transformer_flex.py +++ b/timm/models/vision_transformer_flex.py @@ -24,7 +24,16 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert +from timm.layers import ( + AttentionPoolLatent, + Mlp, + to_2tuple, + get_act_layer, + get_norm_layer, + LayerNorm, + LayerType, + _assert, +) from timm.models._builder import build_model_with_cfg from timm.models._features import feature_take_indices from timm.models._features_fx import register_notrace_function, register_notrace_module @@ -59,27 +68,46 @@ def batch_patchify( @register_notrace_module class FlexEmbeds(nn.Module): - """ Na(Flex) Embedding module for Vision Transformers + """NaFlex Embedding module for Vision Transformers. + + This module encapsulates the complete embedding process for Vision Transformers, + supporting both standard and NaFlex (NaViT + FlexiViT) functionality: - This module encapsulates the complete embedding process for Vision Transformers: 1. Patch embedding (via Conv2d or Linear) 2. Class and register token preparation - 3. Position embedding addition + 3. Position embedding addition with interpolation support 4. Pre-normalization (if requested) 5. Dropout application - Also supports NaFlex functionality (NaViT + FlexiViT): + NaFlex capabilities include: - Variable aspect ratio and resolution via patch coordinates - Patch type indicators for handling padding tokens in attention - Flexible position embedding interpolation for arbitrary grid sizes - - Note: Only supports non-overlapping position and register tokens - (i.e., position embeddings do not include class/register tokens) + - Support for factorized position embeddings The patch embedding can be one of two types: - 1. Conv2d-based (default): For standard image inputs [B, C, H, W] - 2. 
Linear-based: For pre-patchified inputs [B, N, P*P*C]
+      - Conv2d-based (default): For standard image inputs [B, C, H, W]
+      - Linear-based: For pre-patchified inputs [B, N, P*P*C]
+    Args:
+        patch_size: Size of patches for patch embedding
+        in_chans: Number of input image channels
+        embed_dim: Dimensionality of patch embedding
+        embed_layer: Type of embedding layer ('conv' or 'linear')
+        input_norm_layer: Normalization layer applied to input (linear mode only)
+        proj_norm_layer: Normalization layer applied after projection
+        final_norm_layer: Final normalization layer before output
+        pos_embed: Type of position embedding ('learned', 'factorized', 'rope', 'none')
+        pos_drop_rate: Dropout rate for position embeddings
+        patch_drop_rate: Dropout rate for patch tokens
+        class_token: Whether to include a class token
+        reg_tokens: Number of register tokens to include
+        bias: Whether to use bias in projection layers
+        dynamic_img_pad: Whether to enable dynamic padding for variable resolution
+        pos_embed_grid_size: Grid size for position embedding initialization
+        pos_embed_interp_mode: Interpolation mode for position embedding resizing
+        pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation
+        default_img_size: Default image size for position embedding grid calculation
    """

    def __init__(
@@ -100,12 +128,14 @@ class FlexEmbeds(nn.Module):
            dynamic_img_pad: bool = False,
            pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14),
            pos_embed_interp_mode: str = 'bicubic',
+            pos_embed_ar_preserving: bool = False,
            default_img_size: Union[int, Tuple[int, int]] = None,
    ):
        super().__init__()
        self.has_class_token = class_token
        self.num_reg_tokens = reg_tokens
        self.pos_embed_interp_mode = pos_embed_interp_mode
+        self.pos_embed_ar_preserving = pos_embed_ar_preserving
        self.patch_size = to_2tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
@@ -153,22 +183,28 @@ class FlexEmbeds(nn.Module):
        self.norm_proj = proj_norm_layer(embed_dim) if proj_norm_layer else nn.Identity()

        # Create position embedding if needed - only for patches, never for prefix tokens
+        if pos_embed in ('factorized', 'learned') and self.pos_embed_grid_size is None:
+            raise ValueError(
+                "Cannot initialize position embeddings without grid_size. "
+                "Please provide img_size or pos_embed_grid_size.")
+        self.pos_embed: Optional[torch.Tensor] = None
+        self.pos_embed_y: Optional[torch.Tensor] = None
+        self.pos_embed_x: Optional[torch.Tensor] = None
        if not pos_embed or pos_embed == 'none':
-            self.pos_embed = None
            self.pos_embed_type = 'none'
        elif pos_embed == 'rope':
-            self.pos_embed = None
            self.pos_embed_type = 'rope'
            # Rotary embeddings will be computed on-the-fly in the forward pass
+        elif pos_embed == 'factorized':
+            h, w = self.pos_embed_grid_size
+            self.pos_embed_type = 'factorized'
+            self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim) * .02)
+            self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim) * .02)
        else:
            # Store position embedding in (1, H, W, dim) format for easier resizing
-            if self.pos_embed_grid_size is not None:
-                h, w = self.pos_embed_grid_size
-                self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02)
-                self.pos_embed_type = 'learned'
-            else:
-                raise ValueError("Cannot initialize position embeddings without grid_size. 
" - "Please provide img_size or pos_embed_grid_size") + h, w = self.pos_embed_grid_size + self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02) + self.pos_embed_type = 'learned' # Pre-normalization layer (separate from the main norm layer) self.norm_final = final_norm_layer(embed_dim) if final_norm_layer is not None else nn.Identity() @@ -184,35 +220,57 @@ class FlexEmbeds(nn.Module): else: self.patch_drop = nn.Identity() - def feature_info(self, location): - """Feature info utility method for feature extraction.""" + def feature_info(self, location) -> Dict[str, Any]: + """Get feature information for feature extraction. + + Args: + location: Feature extraction location identifier + + Returns: + Dictionary containing feature channel count and reduction factor + """ return dict(num_chs=self.embed_dim, reduction=self.patch_size) - def feat_ratio(self, as_scalar=True): - """Return the feature reduction ratio (stride).""" + def feat_ratio(self, as_scalar: bool = True) -> Union[int, Tuple[int, int]]: + """Get the feature reduction ratio (stride) of the patch embedding. + + Args: + as_scalar: Whether to return the maximum dimension as a scalar + + Returns: + Feature reduction ratio as scalar or tuple + """ if as_scalar: return max(self.patch_size) else: return self.patch_size def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: - """ Get grid (feature) size for given image size taking account of dynamic padding. + """Calculate grid (feature) size for given image size. + + Takes into account dynamic padding when enabled. + + Args: + img_size: Input image size as (height, width) + + Returns: + Grid size as (grid_height, grid_width) """ if self.dynamic_img_pad: return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1]) else: return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1] - def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None): - """Forward pass for combined embedding + def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass for patch embedding with position encoding. Args: - x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C] - patch_coord: Optional patch coordinates [B, N, 2] for NaFlex + x: Input tensor [B, C, H, W] for conv mode or [B, N, P*P*C] for pre-patchified linear mode + patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode Returns: - Embedded tensor with position encoding and class/register tokens applied - If patch_type is provided, also returns attention mask + Embedded tensor with position encoding and class/register tokens applied. + Shape: [B, num_prefix_tokens + N, embed_dim] """ # Apply patch embedding naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None @@ -259,6 +317,9 @@ class FlexEmbeds(nn.Module): else: assert grid_size is not None self._apply_learned_pos_embed(x, grid_size=grid_size) + elif self.pos_embed_type == 'factorized': + if naflex_grid_sizes is not None: + self._apply_factorized_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes) elif self.pos_embed_type == 'rope': assert False, "ROPE not yet implemented" @@ -285,22 +346,38 @@ class FlexEmbeds(nn.Module): self, x: torch.Tensor, naflex_grid_sizes: List[Tuple[int, int]], - ): + ) -> None: + """Apply learned position embeddings to NaFlex batch in-place. + + Interpolates learned position embeddings for each sample in the batch + based on their individual grid sizes. 
+ + Args: + x: Input tensor to add position embeddings to + naflex_grid_sizes: List of (height, width) grid sizes for each batch element + """ # Handle each batch element separately with its own grid size orig_h, orig_w = self.pos_embed.shape[1:3] pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float() # B,C,H,W - def _interp(_size): - if (_size[0] == orig_h) and (_size[1] == orig_w): + def _interp2d(size): + """ + Return a flattened positional-embedding grid at an arbitrary spatial resolution. + + Converts the learned 2-D table stored in NCHW format (pos_embed_nchw) into + a (1, H*W, C) sequence that matches the requested size. + """ + if (size[0] == orig_h) and (size[1] == orig_w): pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1) else: + _interp_size = to_2tuple(max(size)) if self.pos_embed_ar_preserving else size pos_embed_flat = F.interpolate( pos_embed_nchw, - size=_size, + size=_interp_size, mode=self.pos_embed_interp_mode, align_corners=False, antialias=True, - ).flatten(2).transpose(1, 2) + )[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2) return pos_embed_flat.to(dtype=x.dtype) # FIXME leaving alternative code commented here for now for comparisons @@ -315,7 +392,7 @@ class FlexEmbeds(nn.Module): # seq_len = min(x.shape[1], pos_embed_flat.shape[1]) # x[i, :seq_len] += pos_embed_flat[0, :seq_len] - # Determine unique grid sizes + # Determine unique grid sizes to avoid duplicate interpolation size_to_indices: Dict[Tuple[int, int], List[int]] = {} for bi, k in enumerate(naflex_grid_sizes): # k = h << 16 | w # FIXME can get jit compat with this @@ -324,7 +401,7 @@ class FlexEmbeds(nn.Module): for k, batch_indices in size_to_indices.items(): # h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this # Interpolate only once for this (h, w) - pos_embed_flat = _interp(k) + pos_embed_flat = _interp2d(k) seq_len = min(x.shape[1], pos_embed_flat.shape[1]) x[:, :seq_len].index_add_( 0, @@ -336,7 +413,15 @@ class FlexEmbeds(nn.Module): self, x: torch.Tensor, grid_size: List[int], - ): + ) -> None: + """Apply learned position embeddings to standard batch in-place. + + Interpolates learned position embeddings to match the specified grid size. + + Args: + x: Input tensor to add position embeddings to + grid_size: Target grid size as [height, width] + """ orig_h, orig_w = self.pos_embed.shape[1:3] if grid_size[0] == orig_h or grid_size[1] == orig_w: # No resize needed, just flatten @@ -354,6 +439,63 @@ class FlexEmbeds(nn.Module): x.add_(pos_embed_flat) + def _apply_factorized_naflex_pos_embed( + self, + x: torch.Tensor, + naflex_grid_sizes: List[Tuple[int, int]], + ) -> None: + """Apply factorized position embeddings to NaFlex batch in-place. + + Uses separate Y and X position embedding tables that are interpolated + and combined for each sample's grid size. 
+ + Args: + x: Input tensor to add position embeddings to + naflex_grid_sizes: List of (height, width) grid sizes for each batch element + """ + assert len(naflex_grid_sizes) == x.size(0) # one (H,W) per sample + + # Handle each batch element separately with its own grid size + orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1] + + # bucket samples that share the same (H,W) so we build each grid once + size_to_indices: Dict[Tuple[int, int], List[int]] = {} + for bi, k in enumerate(naflex_grid_sizes): + size_to_indices.setdefault(k, []).append(bi) + + def _interp1d(table: torch.Tensor, length: int) -> torch.Tensor: + """ + Resample a 1-D positional-embedding table to specified length + and return it in (1, L, C) layout, dtype matching x. + """ + return F.interpolate( + table.permute(0, 2, 1).float(), # (1,C,L) → (1,C,L_out) + size=length, + mode='linear', + align_corners=False, + ).permute(0, 2, 1).to(dtype=x.dtype) # → (1,L_out,C) + + for k, batch_indices in size_to_indices.items(): + target_h, target_w = k + if self.pos_embed_ar_preserving: + len_y = len_x = max(target_h, target_w) + else: + len_y, len_x = target_h, target_w + + pe_y = _interp1d(self.pos_embed_y, len_y)[:, :target_h] # (1,H,C) + pe_x = _interp1d(self.pos_embed_x, len_x)[:, :target_w] # (1,W,C) + + # Broadcast, add and flatten to sequence layout (row major) + pos = pe_y.unsqueeze(2) + pe_x.unsqueeze(1) # (1,H,W,C) + pos = pos.flatten(1, 2) + + seq_len = min(x.shape[1], pos.shape[1]) + x[:, :seq_len].index_add_( + 0, + torch.as_tensor(batch_indices, device=x.device), + pos[:, :seq_len].expand(len(batch_indices), -1, -1) + ) + @register_notrace_function def create_attention_mask( @@ -430,7 +572,21 @@ def global_pool_naflex( patch_valid: Optional[torch.Tensor] = None, pool_type: str = 'token', num_prefix_tokens: int = 1, -): +) -> torch.Tensor: + """Global pooling with NaFlex support for masked tokens. + + Applies global pooling while respecting patch validity masks to exclude + padding tokens from pooling operations. + + Args: + x: Input tensor with shape [B, N, C] + patch_valid: Optional validity mask for patches [B, N-num_prefix_tokens] + pool_type: Type of pooling ('token', 'avg', 'avgmax', 'max') + num_prefix_tokens: Number of prefix tokens (class/register) to exclude from masking + + Returns: + Pooled tensor with shape [B, C] + """ if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'): # Fall back to standard pooling x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens) @@ -472,12 +628,16 @@ def global_pool_naflex( class VisionTransformerFlex(nn.Module): - """ Vision Transformer (Na)Flex + """Vision Transformer with NaFlex support for flexible input handling. - A flexible implementation of Vision Transformer with: - 1. Encapsulated embedding and position encoding in a single module - 2. Support for linear patch embedding on pre-patchified inputs - 3. 
Support for variable sequence length / aspect ratio images (NaFlex) + A flexible implementation of Vision Transformer that supports: + - Standard image classification with various pooling strategies + - NaFlex functionality for variable aspect ratios and resolutions + - Linear patch embedding for pre-patchified inputs + - Multiple position embedding strategies (learned, factorized, rope) + - Comprehensive attention masking for efficient batch processing + - Encapsulated embedding and position encoding in FlexEmbeds module + - Compatible with standard ViT checkpoints through checkpoint filtering """ def __init__( @@ -494,9 +654,10 @@ class VisionTransformerFlex(nn.Module): init_values: Optional[float] = None, class_token: bool = False, reg_tokens: int = 0, - pos_embed: str = 'learn', + pos_embed: str = 'learned', pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16), pos_embed_interp_mode: str = 'bicubic', + pos_embed_ar_preserving: bool = False, default_img_size: Union[int, Tuple[int, int]] = 256, dynamic_img_pad: bool = False, pre_norm: bool = False, @@ -519,44 +680,52 @@ class VisionTransformerFlex(nn.Module): block_fn: Type[nn.Module] = Block, mlp_layer: Type[nn.Module] = Mlp, ) -> None: - """ + """Initialize VisionTransformerFlex model. + Args: - patch_size: Patch size. - in_chans: Number of image input channels. - embed_dim: Transformer embedding dimension. - depth: Depth of transformer. - num_heads: Number of attention heads. - mlp_ratio: Ratio of mlp hidden dim to embedding dim. - qkv_bias: Enable bias for qkv projections if True. - init_values: Layer-scale init values (layer-scale enabled if not None). - class_token: Use class token. - reg_tokens: Number of register tokens. - pos_embed: Type of position embedding. - pos_embed_grid_size: Size of position embedding grid. - pos_embed_interp_mode: Interpolation mode for position embedding. - default_img_size: Input image size. - pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT). - final_norm: Enable norm after transformer blocks, before head (standard in most ViT). - fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'. - num_classes: Number of classes for classification head. - global_pool: Type of global pooling for final sequence (default: 'token'). - drop_rate: Head dropout rate. - pos_drop_rate: Position embedding dropout rate. - attn_drop_rate: Attention dropout rate. - drop_path_rate: Stochastic depth rate. - weight_init: Weight initialization scheme. - fix_init: Apply weight initialization fix (scaling w/ layer index). - embed_layer_type: Patch embedding implementation (e.g., 'linear', 'conv'). - embed_norm_layer: Normalization layer to use / override in patch embed module. - norm_layer: Normalization layer. - act_layer: MLP activation layer. - block_fn: Transformer block layer. 
+            patch_size: Size of patches for patch embedding
+            in_chans: Number of input image channels
+            embed_dim: Transformer embedding dimension
+            depth: Number of transformer blocks
+            num_heads: Number of attention heads
+            mlp_ratio: Ratio of MLP hidden dimension to embedding dimension
+            qkv_bias: Whether to use bias in query, key, value projections
+            qk_norm: Whether to normalize query and key projections
+            proj_bias: Whether to use bias in linear projections
+            init_values: Layer-scale init values (layer-scale enabled if not None)
+            class_token: Whether to include a class token
+            reg_tokens: Number of register tokens to include
+            pos_embed: Type of position embedding
+            pos_embed_grid_size: Grid size for position embedding initialization
+            pos_embed_interp_mode: Interpolation mode for position embedding resizing
+            pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation
+            default_img_size: Default image size for position embedding grid calculation
+            dynamic_img_pad: Whether to enable dynamic padding for variable resolution
+            pre_norm: Whether to apply normalization after embeddings, before transformer blocks (standard in CLIP ViT)
+            final_norm: Whether to apply final normalization before classifier
+            fc_norm: Whether to move the final norm after pooling instead of before (if None, enabled when global_pool == 'avg')
+            num_classes: Number of classification classes
+            global_pool: Type of global pooling for final sequence
+            drop_rate: Dropout rate for classifier
+            pos_drop_rate: Dropout rate for position embeddings
+            patch_drop_rate: Dropout rate for patch tokens
+            proj_drop_rate: Dropout rate for linear projections
+            attn_drop_rate: Dropout rate for attention weights
+            drop_path_rate: Stochastic depth drop rate
+            weight_init: Weight initialization scheme
+            fix_init: Apply weight initialization fix (scaling w/ layer index)
+            embed_layer_type: Type of embedding layer ('conv' or 'linear')
+            embed_norm_layer: Normalization layer for embeddings
+            norm_layer: Normalization layer for transformer blocks
+            act_layer: Activation layer for MLP blocks
+            block_fn: Transformer block implementation class
+            mlp_layer: MLP implementation class
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
        assert class_token or global_pool != 'token'
-        assert pos_embed in ('', 'none', 'learn')
-        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+        assert pos_embed in ('', 'none', 'learned', 'factorized')
+        norm_layer = get_norm_layer(norm_layer) or LayerNorm
        embed_norm_layer = get_norm_layer(embed_norm_layer)
        act_layer = get_act_layer(act_layer) or nn.GELU

@@ -581,6 +750,7 @@ class VisionTransformerFlex(nn.Module):
            pos_embed=pos_embed,
            pos_embed_grid_size=pos_embed_grid_size,
            pos_embed_interp_mode=pos_embed_interp_mode,
+            pos_embed_ar_preserving=pos_embed_ar_preserving,
            pos_drop_rate=pos_drop_rate,
            patch_drop_rate=patch_drop_rate,
            class_token=class_token,
@@ -638,8 +808,9 @@ class VisionTransformerFlex(nn.Module):
        if fix_init:
            self.fix_init_weight()

-    def fix_init_weight(self):
-        def rescale(param, _layer_id):
+    def fix_init_weight(self) -> None:
+        """Apply initialization weight fix with layer-wise scaling."""
+        def rescale(param: torch.Tensor, _layer_id: int) -> None:
            param.div_(math.sqrt(2.0 * _layer_id))

        for layer_id, layer in enumerate(self.blocks):
@@ -647,6 +818,11 @@
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def init_weights(self, mode: str = '') -> None:
+        """Initialize model weights according to specified scheme. 
+ + Args: + mode: Initialization mode ('jax', 'jax_nlhb', 'moco', or '') + """ assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. named_apply(get_init_weights_vit(mode, head_bias), self) @@ -679,11 +855,24 @@ class VisionTransformerFlex(nn.Module): @torch.jit.ignore def no_weight_decay(self) -> Set: + """Get set of parameter names that should not have weight decay applied. + + Returns: + Set of parameter names to skip during weight decay + """ skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'} return skip_list @torch.jit.ignore def group_matcher(self, coarse: bool = False) -> Dict: + """Get parameter group matcher for optimizer parameter grouping. + + Args: + coarse: Whether to use coarse-grained grouping + + Returns: + Dictionary mapping group names to regex patterns + """ return dict( stem=r'^embeds', # stem and embed blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] @@ -691,15 +880,31 @@ class VisionTransformerFlex(nn.Module): @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing for memory efficiency. + + Args: + enable: Whether to enable gradient checkpointing + """ self.grad_checkpointing = enable if hasattr(self.embeds, 'patch_embed') and hasattr(self.embeds.patch_embed, 'set_grad_checkpointing'): self.embeds.patch_embed.set_grad_checkpointing(enable) @torch.jit.ignore def get_classifier(self) -> nn.Module: + """Get the classification head module. + + Returns: + Classification head module + """ return self.head - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classification head with new number of classes and pooling. 
+ + Args: + num_classes: Number of classes for new classification head + global_pool: Optional new global pooling type + """ self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') @@ -762,7 +967,7 @@ class VisionTransformerFlex(nn.Module): blocks = self.blocks else: blocks = self.blocks[:max_index + 1] - + for i, blk in enumerate(blocks): x = blk(x, attn_mask=mask) if i in take_indices: @@ -785,7 +990,7 @@ class VisionTransformerFlex(nn.Module): H, W = self.embeds.dynamic_feat_size((height, width)) else: H, W = grid_size - intermediates = [y.reshape(y.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous() + intermediates = [y.reshape(y.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] # For dictionary output @@ -795,14 +1000,14 @@ class VisionTransformerFlex(nn.Module): result_dict['image_intermediates'] = intermediates if prefix_tokens is not None and return_prefix_tokens: result_dict['image_intermediates_prefix'] = prefix_tokens - + # Only include features if not intermediates_only if not intermediates_only: x_final = self.norm(x) result_dict['image_features'] = x_final - + return result_dict - + # For non-dictionary output, maintain the original behavior if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: # return_prefix not support in torchscript due to poor type handling @@ -832,7 +1037,7 @@ class VisionTransformerFlex(nn.Module): # Pass through embedding module with patch coordinate/type support x = self.embeds(x, patch_coord=patch_coord) - + # Apply transformer blocks with masked attention if mask provided if attn_mask is not None: # We need to apply blocks one by one with mask @@ -842,7 +1047,7 @@ class VisionTransformerFlex(nn.Module): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) - + x = self.norm(x) return x @@ -862,7 +1067,7 @@ class VisionTransformerFlex(nn.Module): ) x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask) return x - + pool_type = self.global_pool if pool_type is None else pool_type x = global_pool_naflex( @@ -891,12 +1096,12 @@ class VisionTransformerFlex(nn.Module): patch_valid: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with optional NaFlex support. 
-
+
        Args:
            x: Input tensor [B, C, H, W] or pre-patchified tensor [B, N, P*P*C] or NaFlex dict
            patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode
            patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
-
+
        Returns:
            Model output tensor
        """
@@ -944,7 +1149,7 @@ class VisionTransformerFlex(nn.Module):
def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
    """Function imported from vision_transformer.py to maintain compatibility"""
    from .vision_transformer import init_weights_vit_jax, init_weights_vit_moco, init_weights_vit_timm
-
+
    if 'jax' in mode:
        return partial(init_weights_vit_jax, head_bias=head_bias)
    elif 'moco' in mode:
@@ -953,7 +1158,7 @@ def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
        return init_weights_vit_timm


-def checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict: Dict[str, Any], model: VisionTransformerFlex) -> Dict[str, Any]:
    """Handle state dict conversion from original ViT to the new version with combined embedding."""
    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn

@@ -1049,7 +1254,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
default_cfgs = generate_default_cfgs({
    'vit_naflex_base_patch16_gap': _cfg(),
    'vit_naflex_base_patch16_map': _cfg(),
-    'vit_naflex_so400m_patch16_map': _cfg(),
+
+    'vit_naflex_base_patch16_siglip': _cfg(),
+    'vit_naflex_so400m_patch16_siglip': _cfg(),

    # sbb model testing
    'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
@@ -1073,7 +1280,7 @@ default_cfgs = generate_default_cfgs({
})


-def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
+def _create_vision_transformer_flex(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
    model = build_model_with_cfg(
        VisionTransformerFlex, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
@@ -1083,19 +1290,19 @@


@register_model
-def vit_naflex_base_patch16_gap(pretrained=False, **kwargs):
+def vit_naflex_base_patch16_gap(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
    """
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
-        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
+        global_pool='avg', reg_tokens=4, fc_norm=True)
    model = _create_vision_transformer_flex(
        'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
-def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
+def vit_naflex_base_patch16_map(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
    """
    model_args = dict(
@@ -1107,7 +1314,7 @@


@register_model
-def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
+def vit_naflex_so150m2_patch16_reg1_gap(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions. 
This model supports:
@@ -1117,31 +1324,43 @@
    """
    model_args = dict(
        patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
-        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True)
+        qkv_bias=False, reg_tokens=1, global_pool='avg', fc_norm=True)
    model = _create_vision_transformer_flex(
        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
-def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
+def vit_naflex_base_patch16(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token',
        class_token=True, pos_embed_grid_size=(14, 14))
-    model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_vision_transformer_flex(
+        'vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
-def vit_naflex_so400m_patch16_map(pretrained=False, **kwargs):
+def vit_naflex_base_patch16_siglip(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
+    """ViT-Base w/ SigLIP-style MAP pooling and gelu_tanh activation, with NaFlex support
+    for variable aspect ratios and resolutions.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, act_layer='gelu_tanh', global_pool='map')
+    model = _create_vision_transformer_flex(
+        'vit_naflex_base_patch16_siglip', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_naflex_so400m_patch16_siglip(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
    """ViT-SO400M with NaFlex functionality for variable aspect ratios and resolutions.
    """
    model_args = dict(
-        patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, init_values=1e-5,
-        global_pool='map', class_token=False, reg_tokens=1, act_layer='gelu_tanh')
+        patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        act_layer='gelu_tanh', global_pool='map')
    model = _create_vision_transformer_flex(
-        'vit_naflex_so400m_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_naflex_so400m_patch16_siglip', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
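
Notes on the changes (sketches below are illustrative, not part of the patch).

The patchify_image hunk at the top replaces an explicitly allocated zero buffer with torch.nn.functional.pad. A minimal sketch of the equivalence between the two forms; pad_to_patch_multiple is a hypothetical helper name, not a function in timm:

    import math
    import torch
    import torch.nn.functional as F

    def pad_to_patch_multiple(img: torch.Tensor, ph: int, pw: int) -> torch.Tensor:
        # Pad bottom/right so H and W become multiples of the patch size,
        # mirroring the F.pad form used in the patchify_image hunk.
        c, h, w = img.shape
        pad_h = (ph - h % ph) % ph  # rows added at the bottom (0 when already divisible)
        pad_w = (pw - w % pw) % pw  # cols added at the right
        return F.pad(img, (0, pad_w, 0, pad_h))  # (left, right, top, bottom)

    img = torch.randn(3, 50, 70)
    out = pad_to_patch_multiple(img, 16, 16)
    # Same sizes the removed math.ceil-based code produced
    assert out.shape == (3, math.ceil(50 / 16) * 16, math.ceil(70 / 16) * 16)  # (3, 64, 80)
    assert torch.equal(out[:, :50, :70], img)  # original pixels untouched
    assert out[:, 50:, :].eq(0).all() and out[:, :, 70:].eq(0).all()  # zero fill, as before

The guard in the diff (`if pad and (h % ph != 0 or w % pw != 0)`) means the outer `% ph` / `% pw` is defensive; it keeps the computation correct even for already-divisible sizes.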
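
The factorized path stores two 1-D tables (pos_embed_y, pos_embed_x) instead of a full HxW grid, then broadcasts their sum per sample. A standalone sketch of that construction, following the shapes in _apply_factorized_naflex_pos_embed (the AR-preserving branch is omitted; interp1d here is a local stand-in for the module's _interp1d closure):

    import torch
    import torch.nn.functional as F

    embed_dim = 8
    # Learned 1-D tables, as created in FlexEmbeds when pos_embed='factorized'
    pos_embed_y = torch.randn(1, 14, embed_dim) * .02  # (1, H_orig, C)
    pos_embed_x = torch.randn(1, 14, embed_dim) * .02  # (1, W_orig, C)

    def interp1d(table: torch.Tensor, length: int) -> torch.Tensor:
        # (1, L, C) -> (1, C, L), linear resample along L, back to (1, length, C)
        return F.interpolate(
            table.permute(0, 2, 1).float(),
            size=length,
            mode='linear',
            align_corners=False,
        ).permute(0, 2, 1)

    target_h, target_w = 10, 20  # one sample's patch grid
    pe_y = interp1d(pos_embed_y, target_h)  # (1, H, C)
    pe_x = interp1d(pos_embed_x, target_w)  # (1, W, C)
    # Broadcast-add to a full grid, then flatten row-major to sequence layout,
    # matching how patches are ordered by their (y, x) coordinates.
    pos = (pe_y.unsqueeze(2) + pe_x.unsqueeze(1)).flatten(1, 2)  # (1, H*W, C)
    assert pos.shape == (1, target_h * target_w, embed_dim)

The factorization costs O((H+W)*C) parameters instead of O(H*W*C), which is what makes per-sample resampling to arbitrary NaFlex grid sizes cheap.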
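
A usage sketch for the NaFlex forward path, assuming a timm checkout with this patch applied: the entry point name, the embed_layer_type='linear' pre-patchified mode, and the [B, N, 2] (y, x) coordinate layout are taken from the signatures and docstrings above; the grid sizes are arbitrary:

    import torch
    import timm

    # Pre-patchified (linear) embedding; 10 classes just for the demo head
    model = timm.create_model(
        'vit_naflex_base_patch16_gap',
        pretrained=False,
        num_classes=10,
        embed_layer_type='linear',
    ).eval()

    B, P, C = 2, 16, 3
    max_len = 144  # padded sequence length shared across the batch
    patches = torch.randn(B, max_len, P * P * C)          # [B, N, P*P*C]
    coord = torch.zeros(B, max_len, 2, dtype=torch.long)  # (y, x) per patch
    valid = torch.zeros(B, max_len, dtype=torch.bool)     # 1=patch, 0=padding

    # Sample 0 is a 12x12 grid (fills the sequence), sample 1 an 8x12 grid
    for i, (gh, gw) in enumerate([(12, 12), (8, 12)]):
        yy, xx = torch.meshgrid(torch.arange(gh), torch.arange(gw), indexing='ij')
        n = gh * gw
        coord[i, :n, 0] = yy.flatten()
        coord[i, :n, 1] = xx.flatten()
        valid[i, :n] = True

    with torch.no_grad():
        logits = model(patches, patch_coord=coord, patch_valid=valid)
    print(logits.shape)  # torch.Size([2, 10])

Per global_pool_naflex above, the validity mask keeps sample 1's padded positions out of the 'avg' pooling.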