diff --git a/timm/models/vision_transformer_flex.py b/timm/models/vision_transformer_flex.py
index 7ddc3377..93c2c892 100644
--- a/timm/models/vision_transformer_flex.py
+++ b/timm/models/vision_transformer_flex.py
@@ -297,24 +297,24 @@ class FlexEmbeds(nn.Module):
             size_to_indices[k].append(bi)
 
         # Handle each batch element separately with its own grid size
+        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2)  # B,C,H,W
         for k, batch_indices in size_to_indices.items():
             h, w = k
             #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
             # Interpolate only once for this (h, w)
             if (h == orig_h) and (w == orig_w):
-                pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
+                pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
-                pos_embed_resized = F.interpolate(
-                    self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
+                pos_embed_flat = F.interpolate(
+                    pos_embed_nchw,
                     size=(h, w),
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
-                )
-                pos_embed_flat = pos_embed_resized.permute(0, 2, 3, 1).reshape(h * w, -1)
+                ).flatten(2).transpose(1, 2)
 
-            seq_len = min(x.shape[1], pos_embed_flat.shape[0])
-            x[batch_indices, :seq_len].add_(pos_embed_flat[:seq_len])
+            seq_len = min(x.shape[1], pos_embed_flat.shape[1])
+            x[batch_indices, :seq_len].add_(pos_embed_flat[:, :seq_len])
 
     def _apply_learned_pos_embed(
             self,
@@ -322,106 +322,84 @@ class FlexEmbeds(nn.Module):
             grid_size: List[int],
     ):
         orig_h, orig_w = self.pos_embed.shape[1:3]
-        if grid_size[0] != orig_h or grid_size[1] != orig_w:
+        if grid_size[0] == orig_h and grid_size[1] == orig_w:
+            # No resize needed, just flatten
+            pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
+        else:
             # Resize if needed - directly using F.interpolate
-            pos_embed = F.interpolate(
+            pos_embed_flat = F.interpolate(
                 self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
                 size=grid_size,
                 mode=self.pos_embed_interp_mode,
                 align_corners=False,
                 antialias=True,
-            )
-            # Convert back and flatten
-            pos_embed = pos_embed.permute(0, 2, 3, 1)
-            pos_embed = pos_embed.reshape(1, grid_size[0] * grid_size[1], -1)
+            ).flatten(2).transpose(1, 2)
-        else:
-            # No resize needed, just flatten
-            pos_embed = self.pos_embed.reshape(1, orig_h * orig_w, -1)
-
-        x.add_(pos_embed)
+        x.add_(pos_embed_flat)
 
 
 @register_notrace_function
 def create_attention_mask(
-        patch_valid: torch.Tensor,
-        num_prefix_tokens: int = 0,
-        dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
-    """Create attention mask from patch type information.
-
-    Used for NaFlex mode to handle variable token counts and padding tokens.
-
-    Args:
-        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
-        num_prefix_tokens: Number of prefix tokens (class token, register tokens)
-        dtype: Dtype of the attention mask
-
-    Returns:
-        Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
-        or None if patch_type is None
-    """
-    patch_valid = patch_valid.to(torch.bool)
-    B = patch_valid.shape[0]
-
-    if num_prefix_tokens > 0:
-        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
-        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
-
-    mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
-    mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
-
-    return mask_float
-
-
-@register_notrace_function
-def create_attention_mask2(
         patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
+        symmetric: bool = True,
         q_len: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
-) -> Optional[torch.Tensor]:
-    """Create expanded attention mask from patch validity info.
+) -> torch.Tensor:
+    """Creates an attention mask from patch validity information.
+
+    Supports two modes controlled by `symmetric`:
+    1. `symmetric=True` (default): Creates a symmetric mask of shape
+        [B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if
+        both token i and token j are valid. Suitable for standard self-attention.
+    2. `symmetric=False`: Creates a potentially non-square mask of shape
+        [B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if
+        the key/value token k is valid. Query token validity is not checked
+        in the mask itself. Useful for cross-attention or specific self-attention
+        implementations where `q_len` can be specified.
 
     Used for NaFlex mode to handle variable token counts and padding tokens.
 
     Args:
-        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
+        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding.
         num_prefix_tokens: Number of prefix tokens (class token, register tokens)
-        q_len: Length override for query sequence
-        dtype: Dtype of the attention mask
+            to prepend, which are always considered valid.
+        symmetric: If True, create a symmetric mask.
+            If False, create an expanded mask based only on key/value validity.
+        q_len: Query sequence length override. Only used when `symmetric` is False.
+            Defaults to the key/value sequence length (`kv_len`) if None.
+        dtype: Dtype of the output attention mask (e.g., torch.float32).
 
     Returns:
-        Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
-        or None if patch_type is None
+        Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked).
+        Shape is [B, 1, seq_len, seq_len] if symmetric=True,
+        or [B, 1, q_len, kv_len] if symmetric=False.
     """
-    patch_valid = patch_valid.bool()
-    B, kv_len = patch_valid.shape
+    patch_valid = patch_valid.bool()  # Ensure boolean type
+    B, N = patch_valid.shape
+    kv_len = N  # Initial key/value length is the number of patches
 
+    # Prepend prefix tokens if any
    if num_prefix_tokens > 0:
-        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
+        # Create prefix validity tensor on the same device as patch_valid
+        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
+        # Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N]
         patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
-        kv_len = patch_valid.shape[1]
+        kv_len += num_prefix_tokens  # Update total key/value sequence length
 
-    q_len = q_len if q_len is not None else kv_len
+    if symmetric:
+        # Symmetric mask is True where BOTH query and key are valid
+        mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)
+        mask_bool = mask_bool.unsqueeze(1)  # Add head dimension: [B, 1, seq_len, seq_len]
+    else:
+        # Expanded mask
+        q_len = q_len or kv_len
+        mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len)
 
-    mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len).to(dtype)
+    # Create the float mask and apply masking using additive mask convention
     mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
-
-    return mask_float
-
-
-@register_notrace_function
-def create_pool_mask(
-        patch_valid:torch.Tensor,
-        dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
-    patch_valid = patch_valid.bool()
-    mask_bool = patch_valid[:, None, None, :]
-    mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
+    # Fill with negative infinity where mask_bool is False (masked positions)
+    mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min)
 
     return mask_float
 
 
@@ -809,7 +787,12 @@ class VisionTransformerFlex(nn.Module):
     ) -> torch.Tensor:
         if self.attn_pool is not None:
             # For attention pooling, we need to pass the mask for NaFlex models
-            attn_mask = create_pool_mask(patch_valid, dtype=x.dtype)
+            attn_mask = create_attention_mask(
+                patch_valid,
+                symmetric=False,
+                q_len=1,
+                dtype=x.dtype,
+            )
             x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
             return x
 
@@ -839,7 +822,7 @@ class VisionTransformerFlex(nn.Module):
 
         # For max pooling with mask
         masked_x = x.clone()
-        masked_x[~patch_valid] = -1e4  # torch.finfo(masked_x.dtype).min
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
         masked_max = masked_x.max(dim=1)[0]
 
         # Combine average and max
@@ -876,9 +859,7 @@ class VisionTransformerFlex(nn.Module):
         Returns:
            Model output tensor
         """
-        if isinstance(x, torch.Tensor):
-            patches = x
-        else:
+        if isinstance(x, Dict):
             # Handle dictionary input from NaFlex collator
             patch_coord = x['patch_coord']
             patch_valid = x['patch_valid']
@@ -893,6 +874,8 @@ class VisionTransformerFlex(nn.Module):
             #     patch = patch.reshape(3, h*16, w*16)
             #     from torchvision.utils import save_image
             #     save_image(patch, f'patch_{i}.jpg', normalize=True)
+        else:
+            patches = x
 
         # Create attention mask if patch_type is provided
         if patch_valid is not None:
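Note (illustrative, not part of the patch): the refactor replaces the old permute + reshape flattening of the interpolated position embedding with `.flatten(2).transpose(1, 2)`. A quick standalone check of that equivalence, with made-up sizes and an assumed bicubic interpolation mode:

```python
import torch
import torch.nn.functional as F

# Hypothetical learned pos embed stored as [B, H, W, C], as in FlexEmbeds
pos_embed = torch.randn(1, 16, 16, 384)
pos_embed_nchw = pos_embed.permute(0, 3, 1, 2)  # [B, C, H, W] for F.interpolate

new_h, new_w = 12, 20
resized = F.interpolate(
    pos_embed_nchw,
    size=(new_h, new_w),
    mode='bicubic',  # stand-in for self.pos_embed_interp_mode
    align_corners=False,
    antialias=True,
)

# Old path: permute back to NHWC, then reshape to [1, H*W, C]
old_style = resized.permute(0, 2, 3, 1).reshape(1, new_h * new_w, -1)
# New path: flatten spatial dims, then swap token and channel dims
new_style = resized.flatten(2).transpose(1, 2)
assert torch.equal(old_style, new_style)  # identical values and [1, H*W, C] layout
```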
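Note (illustrative, not part of the patch): the consolidated `create_attention_mask` covers both the symmetric self-attention case and the non-symmetric case previously handled by `create_pool_mask` (attention pooling now calls it with `symmetric=False, q_len=1`). A minimal standalone sketch mirroring the patched logic, using a hypothetical helper name `make_mask` and made-up inputs:

```python
import torch

def make_mask(patch_valid, num_prefix_tokens=0, symmetric=True, q_len=None, dtype=torch.float32):
    # Mirrors the shapes/semantics of the new create_attention_mask
    patch_valid = patch_valid.bool()
    B, N = patch_valid.shape
    kv_len = N + num_prefix_tokens
    if num_prefix_tokens > 0:
        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
    if symmetric:
        # [B, 1, seq_len, seq_len]: both query and key must be valid
        mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
    else:
        # [B, 1, q_len, kv_len]: only key/value validity is checked
        q_len = q_len or kv_len
        mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len)
    mask = torch.zeros_like(mask_bool, dtype=dtype)
    return mask.masked_fill_(~mask_bool, torch.finfo(dtype).min)

# Two images, 3 patch slots each; the second/third slots are padding in places
patch_valid = torch.tensor([[True, True, False], [True, False, False]])

# Symmetric mask for self-attention blocks, with one class token prepended
self_attn_mask = make_mask(patch_valid, num_prefix_tokens=1, symmetric=True)
assert self_attn_mask.shape == (2, 1, 4, 4)

# Non-symmetric mask for attention pooling: a single query attends over patch tokens
pool_mask = make_mask(patch_valid, symmetric=False, q_len=1)
assert pool_mask.shape == (2, 1, 1, 3)
assert pool_mask[0, 0, 0, 2] == torch.finfo(torch.float32).min  # padded key masked out
```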