Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add siglip2-compatible naflex encoders. Add support for factorized pos embeds and an 'aspect preserving' mode to FlexEmbeds. Some more docstrings and typing.
This commit is contained in:
parent: b7ced7c40c
commit: 72858c193c
@@ -749,11 +749,9 @@ def patchify_image(
 
     # Ensure the image is divisible by patch size
     if pad and (h % ph != 0 or w % pw != 0):
-        new_h = math.ceil(h / ph) * ph
-        new_w = math.ceil(w / pw) * pw
-        padded_img = torch.zeros(c, new_h, new_w, dtype=img.dtype)
-        padded_img[:, :h, :w] = img
-        img = padded_img
+        pad_h = (ph - h % ph) % ph  # amount to add on bottom
+        pad_w = (pw - w % pw) % pw  # amount to add on right
+        img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h))
         c, h, w = img.shape
 
     # Calculate number of patches in each dimension
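For context, a minimal standalone sketch of the bottom/right padding introduced in this hunk (illustrative only; pad_to_patch_multiple is a hypothetical helper, not part of timm):

import torch
import torch.nn.functional as F

def pad_to_patch_multiple(img: torch.Tensor, patch_size=(16, 16)) -> torch.Tensor:
    """Pad a [C, H, W] image on the bottom/right so H and W become multiples of the patch size."""
    ph, pw = patch_size
    _, h, w = img.shape
    pad_h = (ph - h % ph) % ph  # rows added at the bottom
    pad_w = (pw - w % pw) % pw  # columns added at the right
    # F.pad pads the last two dims as (left, right, top, bottom)
    return F.pad(img, (0, pad_w, 0, pad_h))

x = torch.randn(3, 250, 333)
print(pad_to_patch_multiple(x).shape)  # torch.Size([3, 256, 336])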
@@ -24,7 +24,16 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert
+from timm.layers import (
+    AttentionPoolLatent,
+    Mlp,
+    to_2tuple,
+    get_act_layer,
+    get_norm_layer,
+    LayerNorm,
+    LayerType,
+    _assert,
+)
 from timm.models._builder import build_model_with_cfg
 from timm.models._features import feature_take_indices
 from timm.models._features_fx import register_notrace_function, register_notrace_module
@@ -59,27 +68,46 @@ def batch_patchify(
 @register_notrace_module
 class FlexEmbeds(nn.Module):
-    """ Na(Flex) Embedding module for Vision Transformers
+    """NaFlex Embedding module for Vision Transformers.
 
-    This module encapsulates the complete embedding process for Vision Transformers,
-    supporting both standard and NaFlex (NaViT + FlexiViT) functionality:
-
+    This module encapsulates the complete embedding process for Vision Transformers:
     1. Patch embedding (via Conv2d or Linear)
     2. Class and register token preparation
-    3. Position embedding addition
+    3. Position embedding addition with interpolation support
     4. Pre-normalization (if requested)
     5. Dropout application
 
-    Also supports NaFlex functionality (NaViT + FlexiViT):
+    NaFlex capabilities include:
     - Variable aspect ratio and resolution via patch coordinates
     - Patch type indicators for handling padding tokens in attention
     - Flexible position embedding interpolation for arbitrary grid sizes
-
-    Note: Only supports non-overlapping position and register tokens
-    (i.e., position embeddings do not include class/register tokens)
+    - Support for factorized position embeddings
 
     The patch embedding can be one of two types:
-    1. Conv2d-based (default): For standard image inputs [B, C, H, W]
-    2. Linear-based: For pre-patchified inputs [B, N, P*P*C]
+    - Conv2d-based (default): For standard image inputs [B, C, H, W]
+    - Linear-based: For pre-patchified inputs [B, N, P*P*C]
+
+    Args:
+        patch_size: Size of patches for patch embedding
+        in_chans: Number of input image channels
+        embed_dim: Dimensionality of patch embedding
+        embed_layer: Type of embedding layer ('conv' or 'linear')
+        input_norm_layer: Normalization layer applied to input (linear mode only)
+        proj_norm_layer: Normalization layer applied after projection
+        final_norm_layer: Final normalization layer before output
+        pos_embed: Type of position embedding ('learned', 'factorized', 'rope', 'none')
+        pos_drop_rate: Dropout rate for position embeddings
+        patch_drop_rate: Dropout rate for patch tokens
+        class_token: Whether to include a class token
+        reg_tokens: Number of register tokens to include
+        bias: Whether to use bias in projection layers
+        dynamic_img_pad: Whether to enable dynamic padding for variable resolution
+        pos_embed_grid_size: Grid size for position embedding initialization
+        pos_embed_interp_mode: Interpolation mode for position embedding resizing
+        pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation
+        default_img_size: Default image size for position embedding grid calculation
     """
 
     def __init__(
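The docstring above distinguishes Conv2d-based and Linear-based patch embedding. A self-contained sketch of the two equivalent token layouts (plain PyTorch, not the module's own code; the exact pixel/channel flattening order used by timm may differ):

import torch
import torch.nn as nn

B, C, H, W, P, D = 2, 3, 224, 224, 16, 768
img = torch.randn(B, C, H, W)

# Conv2d route: a stride-P convolution produces the patch grid directly
conv_proj = nn.Conv2d(C, D, kernel_size=P, stride=P)
tokens_conv = conv_proj(img).flatten(2).transpose(1, 2)                # [B, N, D]

# Linear route: pre-patchify to [B, N, P*P*C], then project each flattened patch
patches = img.unfold(2, P, P).unfold(3, P, P)                          # [B, C, H/P, W/P, P, P]
patches = patches.permute(0, 2, 3, 1, 4, 5).flatten(1, 2).flatten(2)   # [B, N, C*P*P]
linear_proj = nn.Linear(C * P * P, D)
tokens_linear = linear_proj(patches)                                   # [B, N, D]
print(tokens_conv.shape, tokens_linear.shape)                          # both [2, 196, 768]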
@@ -100,12 +128,14 @@ class FlexEmbeds(nn.Module):
             dynamic_img_pad: bool = False,
             pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14),
             pos_embed_interp_mode: str = 'bicubic',
+            pos_embed_ar_preserving: bool = False,
             default_img_size: Union[int, Tuple[int, int]] = None,
     ):
         super().__init__()
         self.has_class_token = class_token
         self.num_reg_tokens = reg_tokens
         self.pos_embed_interp_mode = pos_embed_interp_mode
+        self.pos_embed_ar_preserving = pos_embed_ar_preserving
         self.patch_size = to_2tuple(patch_size)
         self.in_chans = in_chans
         self.embed_dim = embed_dim
@@ -153,22 +183,28 @@ class FlexEmbeds(nn.Module):
         self.norm_proj = proj_norm_layer(embed_dim) if proj_norm_layer else nn.Identity()
 
         # Create position embedding if needed - only for patches, never for prefix tokens
+        if pos_embed in ('factorized', 'learned') and self.pos_embed_grid_size is None:
+            raise ValueError(
+                "Cannot initialize position embeddings without grid_size."
+                "Please provide img_size or pos_embed_grid_size.")
+        self.pos_embed: Optional[torch.Tensor] = None
+        self.pos_embed_y: Optional[torch.Tensor] = None
+        self.pos_embed_x: Optional[torch.Tensor] = None
         if not pos_embed or pos_embed == 'none':
             self.pos_embed = None
             self.pos_embed_type = 'none'
         elif pos_embed == 'rope':
             self.pos_embed = None
             self.pos_embed_type = 'rope'
             # Rotary embeddings will be computed on-the-fly in the forward pass
+        elif pos_embed == 'factorized':
+            h, w = self.pos_embed_grid_size
+            self.pos_embed_type = 'factorized'
+            self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim) * .02)
+            self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim) * .02)
         else:
             # Store position embedding in (1, H, W, dim) format for easier resizing
-            if self.pos_embed_grid_size is not None:
             h, w = self.pos_embed_grid_size
             self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02)
             self.pos_embed_type = 'learned'
-            else:
-                raise ValueError("Cannot initialize position embeddings without grid_size. "
-                                 "Please provide img_size or pos_embed_grid_size")
 
         # Pre-normalization layer (separate from the main norm layer)
         self.norm_final = final_norm_layer(embed_dim) if final_norm_layer is not None else nn.Identity()
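A standalone illustration of the factorized position embedding added above: one learned table per axis, combined by broadcasting instead of storing a full H*W table (illustrative values only):

import torch

embed_dim, grid_h, grid_w = 768, 16, 16
pos_embed_y = torch.randn(1, grid_h, embed_dim) * .02   # (1, H, C)
pos_embed_x = torch.randn(1, grid_w, embed_dim) * .02   # (1, W, C)

# Broadcast-add to a (1, H, W, C) grid, then flatten to the (1, H*W, C) token sequence
pos = pos_embed_y.unsqueeze(2) + pos_embed_x.unsqueeze(1)
pos = pos.flatten(1, 2)
print(pos.shape)  # torch.Size([1, 256, 768])

# Parameter count: (H + W) * C for factorized vs H * W * C for the full learned table
print((grid_h + grid_w) * embed_dim, grid_h * grid_w * embed_dim)  # 24576 196608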
@@ -184,35 +220,57 @@ class FlexEmbeds(nn.Module):
         else:
             self.patch_drop = nn.Identity()
 
-    def feature_info(self, location):
-        """Feature info utility method for feature extraction."""
+    def feature_info(self, location) -> Dict[str, Any]:
+        """Get feature information for feature extraction.
+
+        Args:
+            location: Feature extraction location identifier
+
+        Returns:
+            Dictionary containing feature channel count and reduction factor
+        """
         return dict(num_chs=self.embed_dim, reduction=self.patch_size)
 
-    def feat_ratio(self, as_scalar=True):
-        """Return the feature reduction ratio (stride)."""
+    def feat_ratio(self, as_scalar: bool = True) -> Union[int, Tuple[int, int]]:
+        """Get the feature reduction ratio (stride) of the patch embedding.
+
+        Args:
+            as_scalar: Whether to return the maximum dimension as a scalar
+
+        Returns:
+            Feature reduction ratio as scalar or tuple
+        """
         if as_scalar:
             return max(self.patch_size)
         else:
             return self.patch_size
 
     def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
-        """ Get grid (feature) size for given image size taking account of dynamic padding.
+        """Calculate grid (feature) size for given image size.
+
+        Takes into account dynamic padding when enabled.
+
+        Args:
+            img_size: Input image size as (height, width)
+
+        Returns:
+            Grid size as (grid_height, grid_width)
         """
         if self.dynamic_img_pad:
             return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
         else:
             return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
 
-    def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
-        """Forward pass for combined embedding
+    def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Forward pass for patch embedding with position encoding.
 
         Args:
-            x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C]
-            patch_coord: Optional patch coordinates [B, N, 2] for NaFlex
+            x: Input tensor [B, C, H, W] for conv mode or [B, N, P*P*C] for pre-patchified linear mode
+            patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode
 
         Returns:
-            Embedded tensor with position encoding and class/register tokens applied
-            If patch_type is provided, also returns attention mask
+            Embedded tensor with position encoding and class/register tokens applied.
+            Shape: [B, num_prefix_tokens + N, embed_dim]
         """
         # Apply patch embedding
         naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
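A quick check of the grid-size arithmetic in dynamic_feat_size above (assuming a 16x16 patch):

import math

img_size, patch = (250, 333), (16, 16)
# dynamic_img_pad=True: round up, the padded border becomes an extra row/column of patches
print(math.ceil(img_size[0] / patch[0]), math.ceil(img_size[1] / patch[1]))  # 16 21
# dynamic_img_pad=False: floor, pixels beyond the last full patch are dropped
print(img_size[0] // patch[0], img_size[1] // patch[1])                      # 15 20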
@@ -259,6 +317,9 @@ class FlexEmbeds(nn.Module):
             else:
                 assert grid_size is not None
                 self._apply_learned_pos_embed(x, grid_size=grid_size)
+        elif self.pos_embed_type == 'factorized':
+            if naflex_grid_sizes is not None:
+                self._apply_factorized_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
         elif self.pos_embed_type == 'rope':
             assert False, "ROPE not yet implemented"
 
@@ -285,22 +346,38 @@ class FlexEmbeds(nn.Module):
             self,
             x: torch.Tensor,
             naflex_grid_sizes: List[Tuple[int, int]],
-    ):
+    ) -> None:
+        """Apply learned position embeddings to NaFlex batch in-place.
+
+        Interpolates learned position embeddings for each sample in the batch
+        based on their individual grid sizes.
+
+        Args:
+            x: Input tensor to add position embeddings to
+            naflex_grid_sizes: List of (height, width) grid sizes for each batch element
+        """
         # Handle each batch element separately with its own grid size
         orig_h, orig_w = self.pos_embed.shape[1:3]
         pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
 
-        def _interp(_size):
-            if (_size[0] == orig_h) and (_size[1] == orig_w):
+        def _interp2d(size):
+            """
+            Return a flattened positional-embedding grid at an arbitrary spatial resolution.
+
+            Converts the learned 2-D table stored in NCHW format (pos_embed_nchw) into
+            a (1, H*W, C) sequence that matches the requested size.
+            """
+            if (size[0] == orig_h) and (size[1] == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
+                _interp_size = to_2tuple(max(size)) if self.pos_embed_ar_preserving else size
                 pos_embed_flat = F.interpolate(
                     pos_embed_nchw,
-                    size=_size,
+                    size=_interp_size,
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
-                ).flatten(2).transpose(1, 2)
+                )[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2)
             return pos_embed_flat.to(dtype=x.dtype)
 
         # FIXME leaving alternative code commented here for now for comparisons
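A standalone sketch of the 'aspect preserving' interpolation used in _interp2d above: resize the square table to a square of side max(H, W), then crop to (H, W), so both axes share one scale factor (illustrative shapes only):

import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 14, 14, 768)                    # learned (1, H, W, C) table
pos_nchw = pos_embed.permute(0, 3, 1, 2).float()           # (1, C, 14, 14)

target_h, target_w = 12, 24
side = max(target_h, target_w)                             # 24
resized = F.interpolate(
    pos_nchw, size=(side, side), mode='bicubic', align_corners=False, antialias=True)
cropped = resized[:, :, :target_h, :target_w]              # (1, C, 12, 24)
flat = cropped.flatten(2).transpose(1, 2)                  # (1, H*W, C)
print(flat.shape)  # torch.Size([1, 288, 768])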
@@ -315,7 +392,7 @@ class FlexEmbeds(nn.Module):
         #     seq_len = min(x.shape[1], pos_embed_flat.shape[1])
         #     x[i, :seq_len] += pos_embed_flat[0, :seq_len]
 
-        # Determine unique grid sizes
+        # Determine unique grid sizes to avoid duplicate interpolation
         size_to_indices: Dict[Tuple[int, int], List[int]] = {}
         for bi, k in enumerate(naflex_grid_sizes):
             # k = h << 16 | w  # FIXME can get jit compat with this
@@ -324,7 +401,7 @@ class FlexEmbeds(nn.Module):
         for k, batch_indices in size_to_indices.items():
             # h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
             # Interpolate only once for this (h, w)
-            pos_embed_flat = _interp(k)
+            pos_embed_flat = _interp2d(k)
             seq_len = min(x.shape[1], pos_embed_flat.shape[1])
             x[:, :seq_len].index_add_(
                 0,
@@ -336,7 +413,15 @@ class FlexEmbeds(nn.Module):
             self,
             x: torch.Tensor,
             grid_size: List[int],
-    ):
+    ) -> None:
+        """Apply learned position embeddings to standard batch in-place.
+
+        Interpolates learned position embeddings to match the specified grid size.
+
+        Args:
+            x: Input tensor to add position embeddings to
+            grid_size: Target grid size as [height, width]
+        """
         orig_h, orig_w = self.pos_embed.shape[1:3]
         if grid_size[0] == orig_h or grid_size[1] == orig_w:
             # No resize needed, just flatten
@@ -354,6 +439,63 @@ class FlexEmbeds(nn.Module):
 
         x.add_(pos_embed_flat)
 
+    def _apply_factorized_naflex_pos_embed(
+            self,
+            x: torch.Tensor,
+            naflex_grid_sizes: List[Tuple[int, int]],
+    ) -> None:
+        """Apply factorized position embeddings to NaFlex batch in-place.
+
+        Uses separate Y and X position embedding tables that are interpolated
+        and combined for each sample's grid size.
+
+        Args:
+            x: Input tensor to add position embeddings to
+            naflex_grid_sizes: List of (height, width) grid sizes for each batch element
+        """
+        assert len(naflex_grid_sizes) == x.size(0)  # one (H,W) per sample
+
+        # Handle each batch element separately with its own grid size
+        orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1]
+
+        # bucket samples that share the same (H,W) so we build each grid once
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
+        for bi, k in enumerate(naflex_grid_sizes):
+            size_to_indices.setdefault(k, []).append(bi)
+
+        def _interp1d(table: torch.Tensor, length: int) -> torch.Tensor:
+            """
+            Resample a 1-D positional-embedding table to specified length
+            and return it in (1, L, C) layout, dtype matching x.
+            """
+            return F.interpolate(
+                table.permute(0, 2, 1).float(),  # (1,C,L) → (1,C,L_out)
+                size=length,
+                mode='linear',
+                align_corners=False,
+            ).permute(0, 2, 1).to(dtype=x.dtype)  # → (1,L_out,C)
+
+        for k, batch_indices in size_to_indices.items():
+            target_h, target_w = k
+            if self.pos_embed_ar_preserving:
+                len_y = len_x = max(target_h, target_w)
+            else:
+                len_y, len_x = target_h, target_w
+
+            pe_y = _interp1d(self.pos_embed_y, len_y)[:, :target_h]  # (1,H,C)
+            pe_x = _interp1d(self.pos_embed_x, len_x)[:, :target_w]  # (1,W,C)
+
+            # Broadcast, add and flatten to sequence layout (row major)
+            pos = pe_y.unsqueeze(2) + pe_x.unsqueeze(1)  # (1,H,W,C)
+            pos = pos.flatten(1, 2)
+
+            seq_len = min(x.shape[1], pos.shape[1])
+            x[:, :seq_len].index_add_(
+                0,
+                torch.as_tensor(batch_indices, device=x.device),
+                pos[:, :seq_len].expand(len(batch_indices), -1, -1)
+            )
 
 
 @register_notrace_function
 def create_attention_mask(
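The bucketing plus index_add_ pattern above can be illustrated in isolation: samples that share a grid size reuse one position grid, which is added only to those rows of the batch (toy shapes, not the real embedding sizes):

import torch

x = torch.zeros(4, 6, 2)                      # (B, N, C) token batch
pos = torch.arange(12.).reshape(1, 6, 2)      # one shared (1, N, C) position grid
batch_indices = [1, 3]                        # samples that use this grid size

x.index_add_(0, torch.tensor(batch_indices), pos.expand(len(batch_indices), -1, -1))
print(x[1].equal(pos[0]), x[0].abs().sum().item())  # True 0.0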
@@ -430,7 +572,21 @@ def global_pool_naflex(
         patch_valid: Optional[torch.Tensor] = None,
         pool_type: str = 'token',
         num_prefix_tokens: int = 1,
-):
+) -> torch.Tensor:
+    """Global pooling with NaFlex support for masked tokens.
+
+    Applies global pooling while respecting patch validity masks to exclude
+    padding tokens from pooling operations.
+
+    Args:
+        x: Input tensor with shape [B, N, C]
+        patch_valid: Optional validity mask for patches [B, N-num_prefix_tokens]
+        pool_type: Type of pooling ('token', 'avg', 'avgmax', 'max')
+        num_prefix_tokens: Number of prefix tokens (class/register) to exclude from masking
+
+    Returns:
+        Pooled tensor with shape [B, C]
+    """
     if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
         # Fall back to standard pooling
         x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens)
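A self-contained sketch of the masked pooling behavior documented above (plain PyTorch, not the library function): padding tokens are excluded from the average by weighting with the validity mask:

import torch

B, N, C = 2, 5, 4
x = torch.randn(B, N, C)
patch_valid = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1]], dtype=torch.bool)   # (B, N) validity mask

mask = patch_valid.unsqueeze(-1).float()                          # (B, N, 1)
pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)     # (B, C), padding excluded
print(pooled.shape)  # torch.Size([2, 4])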
@@ -472,12 +628,16 @@ def global_pool_naflex(
 
 
 class VisionTransformerFlex(nn.Module):
-    """ Vision Transformer (Na)Flex
+    """Vision Transformer with NaFlex support for flexible input handling.
 
-    A flexible implementation of Vision Transformer with:
-    1. Encapsulated embedding and position encoding in a single module
-    2. Support for linear patch embedding on pre-patchified inputs
-    3. Support for variable sequence length / aspect ratio images (NaFlex)
+    A flexible implementation of Vision Transformer that supports:
+    - Standard image classification with various pooling strategies
+    - NaFlex functionality for variable aspect ratios and resolutions
+    - Linear patch embedding for pre-patchified inputs
+    - Multiple position embedding strategies (learned, factorized, rope)
+    - Comprehensive attention masking for efficient batch processing
+    - Encapsulated embedding and position encoding in FlexEmbeds module
+    - Compatible with standard ViT checkpoints through checkpoint filtering
     """
 
     def __init__(
@@ -494,9 +654,10 @@ class VisionTransformerFlex(nn.Module):
             init_values: Optional[float] = None,
             class_token: bool = False,
             reg_tokens: int = 0,
-            pos_embed: str = 'learn',
+            pos_embed: str = 'learned',
             pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16),
             pos_embed_interp_mode: str = 'bicubic',
+            pos_embed_ar_preserving: bool = False,
             default_img_size: Union[int, Tuple[int, int]] = 256,
             dynamic_img_pad: bool = False,
             pre_norm: bool = False,
@@ -519,44 +680,52 @@ class VisionTransformerFlex(nn.Module):
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
     ) -> None:
-        """
+        """Initialize VisionTransformerFlex model.
+
         Args:
-            patch_size: Patch size.
-            in_chans: Number of image input channels.
-            embed_dim: Transformer embedding dimension.
-            depth: Depth of transformer.
-            num_heads: Number of attention heads.
-            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
-            qkv_bias: Enable bias for qkv projections if True.
-            init_values: Layer-scale init values (layer-scale enabled if not None).
-            class_token: Use class token.
-            reg_tokens: Number of register tokens.
-            pos_embed: Type of position embedding.
-            pos_embed_grid_size: Size of position embedding grid.
-            pos_embed_interp_mode: Interpolation mode for position embedding.
-            default_img_size: Input image size.
-            pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
-            final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
-            fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
-            num_classes: Number of classes for classification head.
-            global_pool: Type of global pooling for final sequence (default: 'token').
-            drop_rate: Head dropout rate.
-            pos_drop_rate: Position embedding dropout rate.
-            attn_drop_rate: Attention dropout rate.
-            drop_path_rate: Stochastic depth rate.
-            weight_init: Weight initialization scheme.
-            fix_init: Apply weight initialization fix (scaling w/ layer index).
-            embed_layer_type: Patch embedding implementation (e.g., 'linear', 'conv').
-            embed_norm_layer: Normalization layer to use / override in patch embed module.
-            norm_layer: Normalization layer.
-            act_layer: MLP activation layer.
-            block_fn: Transformer block layer.
+            patch_size: Size of patches for patch embedding
+            in_chans: Number of input image channels
+            embed_dim: Transformer embedding dimension
+            depth: Number of transformer blocks
+            num_heads: Number of attention heads
+            mlp_ratio: Ratio of MLP hidden dimension to embedding dimension
+            qkv_bias: Whether to use bias in query, key, value projections
+            qk_norm: Whether to normalize query and key projections
+            proj_bias: Whether to use bias in linear projections
+            init_values: Layer-scale init values (layer-scale enabled if not None)
+            class_token: Whether to include a class token
+            reg_tokens: Number of register tokens to include
+            pos_embed: Type of position embedding
+            pos_embed_grid_size: Grid size for position embedding initialization
+            pos_embed_interp_mode: Interpolation mode for position embedding resizing
+            pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation
+            default_img_size: Default image size for position embedding grid calculation
+            dynamic_img_pad: Whether to enable dynamic padding for variable resolution
+            pre_norm: Whether to apply normalization before attention/MLP layers
+            final_norm: Whether to apply final normalization before classifier
+            fc_norm: Whether to normalize features before final classifier
+            num_classes: Number of classification classes
+            global_pool: Type of global pooling for final sequence
+            drop_rate: Dropout rate for classifier
+            pos_drop_rate: Dropout rate for position embeddings
+            patch_drop_rate: Dropout rate for patch tokens
+            proj_drop_rate: Dropout rate for linear projections
+            attn_drop_rate: Dropout rate for attention weights
+            drop_path_rate: Stochastic depth drop rate
+            weight_init: Weight initialization scheme
+            fix_init: Apply weight initialization fix (scaling w/ layer index)
+            embed_layer_type: Type of embedding layer ('conv' or 'linear')
+            embed_norm_layer: Normalization layer for embeddings
+            norm_layer: Normalization layer for transformer blocks
+            act_layer: Activation layer for MLP blocks
+            block_fn: Transformer block implementation class
+            mlp_layer: MLP implementation class
         """
         super().__init__()
         assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         assert class_token or global_pool != 'token'
-        assert pos_embed in ('', 'none', 'learn')
-        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+        assert pos_embed in ('', 'none', 'learned', 'factorized')
+        norm_layer = get_norm_layer(norm_layer) or LayerNorm
         embed_norm_layer = get_norm_layer(embed_norm_layer)
         act_layer = get_act_layer(act_layer) or nn.GELU
 
@@ -581,6 +750,7 @@ class VisionTransformerFlex(nn.Module):
             pos_embed=pos_embed,
             pos_embed_grid_size=pos_embed_grid_size,
             pos_embed_interp_mode=pos_embed_interp_mode,
+            pos_embed_ar_preserving=pos_embed_ar_preserving,
             pos_drop_rate=pos_drop_rate,
             patch_drop_rate=patch_drop_rate,
             class_token=class_token,
@@ -638,8 +808,9 @@ class VisionTransformerFlex(nn.Module):
         if fix_init:
             self.fix_init_weight()
 
-    def fix_init_weight(self):
-        def rescale(param, _layer_id):
+    def fix_init_weight(self) -> None:
+        """Apply initialization weight fix with layer-wise scaling."""
+        def rescale(param: torch.Tensor, _layer_id: int) -> None:
             param.div_(math.sqrt(2.0 * _layer_id))
 
         for layer_id, layer in enumerate(self.blocks):
@@ -647,6 +818,11 @@ class VisionTransformerFlex(nn.Module):
             rescale(layer.mlp.fc2.weight.data, layer_id + 1)
 
     def init_weights(self, mode: str = '') -> None:
+        """Initialize model weights according to specified scheme.
+
+        Args:
+            mode: Initialization mode ('jax', 'jax_nlhb', 'moco', or '')
+        """
         assert mode in ('jax', 'jax_nlhb', 'moco', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
         named_apply(get_init_weights_vit(mode, head_bias), self)
@@ -679,11 +855,24 @@ class VisionTransformerFlex(nn.Module):
 
     @torch.jit.ignore
     def no_weight_decay(self) -> Set:
+        """Get set of parameter names that should not have weight decay applied.
+
+        Returns:
+            Set of parameter names to skip during weight decay
+        """
         skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'}
         return skip_list
 
     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict:
+        """Get parameter group matcher for optimizer parameter grouping.
+
+        Args:
+            coarse: Whether to use coarse-grained grouping
+
+        Returns:
+            Dictionary mapping group names to regex patterns
+        """
         return dict(
             stem=r'^embeds',  # stem and embed
             blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
@@ -691,15 +880,31 @@ class VisionTransformerFlex(nn.Module):
 
     @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
+        """Enable or disable gradient checkpointing for memory efficiency.
+
+        Args:
+            enable: Whether to enable gradient checkpointing
+        """
         self.grad_checkpointing = enable
         if hasattr(self.embeds, 'patch_embed') and hasattr(self.embeds.patch_embed, 'set_grad_checkpointing'):
             self.embeds.patch_embed.set_grad_checkpointing(enable)
 
     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
+        """Get the classification head module.
+
+        Returns:
+            Classification head module
+        """
         return self.head
 
-    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
+        """Reset the classification head with new number of classes and pooling.
+
+        Args:
+            num_classes: Number of classes for new classification head
+            global_pool: Optional new global pooling type
+        """
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
@@ -953,7 +1158,7 @@ def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
     return init_weights_vit_timm
 
 
-def checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict: Dict[str, Any], model: VisionTransformerFlex) -> Dict[str, Any]:
     """Handle state dict conversion from original ViT to the new version with combined embedding."""
     from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
 
@@ -1049,7 +1254,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
 default_cfgs = generate_default_cfgs({
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
-    'vit_naflex_so400m_patch16_map': _cfg(),
+
+    'vit_naflex_base_patch16_siglip': _cfg(),
+    'vit_naflex_so400m_patch16_siglip': _cfg(),
 
     # sbb model testing
     'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
@@ -1073,7 +1280,7 @@ default_cfgs = generate_default_cfgs({
 })
 
 
-def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
+def _create_vision_transformer_flex(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
     model = build_model_with_cfg(
         VisionTransformerFlex, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
@@ -1083,19 +1290,19 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
 
 
 @register_model
-def vit_naflex_base_patch16_gap(pretrained=False, **kwargs):
+def vit_naflex_base_patch16_gap(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
-        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
+        global_pool='avg', reg_tokens=4, fc_norm=True, **kwargs)
     model = _create_vision_transformer_flex(
         'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
+def vit_naflex_base_patch16_map(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
@@ -1107,7 +1314,7 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
 
 
 @register_model
-def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
+def vit_naflex_so150m2_patch16_reg1_gap(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
 
     This model supports:
@@ -1117,31 +1324,43 @@ def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
     """
     model_args = dict(
         patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
-        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True)
+        qkv_bias=False, reg_tokens=1, global_pool='avg', fc_norm=True)
     model = _create_vision_transformer_flex(
         'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
+def vit_naflex_base_patch16(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12,
         global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
-    model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_vision_transformer_flex(
+        'vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vit_naflex_so400m_patch16_map(pretrained=False, **kwargs):
+def vit_naflex_base_patch16_siglip(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, act_layer='gelu_tanh', global_pool='map')
     model = _create_vision_transformer_flex(
+        'vit_naflex_base_patch16_siglip', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
+
+
+@register_model
+def vit_naflex_so400m_patch16_siglip(pretrained: bool = False, **kwargs) -> VisionTransformerFlex:
+    """ViT-SO400M with NaFlex functionality for variable aspect ratios and resolutions.
+    """
+    model_args = dict(
-        patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, init_values=1e-5,
-        global_pool='map', class_token=False, reg_tokens=1, act_layer='gelu_tanh')
+        patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        act_layer='gelu_tanh', global_pool='map')
+    model = _create_vision_transformer_flex(
-        'vit_naflex_so400m_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_naflex_so400m_patch16_siglip', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
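For reference, a hedged usage sketch of the newly registered variants (assumes a timm build that includes this commit; registration alone does not imply pretrained weights):

import timm

# Instantiate one of the new naflex siglip variants with random weights
model = timm.create_model('vit_naflex_so400m_patch16_siglip', pretrained=False)
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M params')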