diff --git a/timm/models/vision_transformer_flex.py b/timm/models/vision_transformer_flex.py
index 3398dc37..9b509dab 100644
--- a/timm/models/vision_transformer_flex.py
+++ b/timm/models/vision_transformer_flex.py
@@ -424,6 +424,53 @@ def create_attention_mask(
     return mask_float
 
 
+@register_notrace_function
+def global_pool_naflex(
+        x: torch.Tensor,
+        patch_valid: Optional[torch.Tensor] = None,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+):
+    if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
+        # Fall back to standard pooling
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens)
+        return x
+
+    # For NaFlex mode, we need to apply masked pooling to exclude padding tokens
+    # Extract only the patch part of the mask (excluding prefix tokens)
+    if num_prefix_tokens > 0:
+        # Apply the mask to extract only valid tokens
+        x = x[:, num_prefix_tokens:]  # prefix tokens not included in pooling
+
+    patch_valid_float = patch_valid.to(x.dtype)
+    if pool_type == 'avg':
+        # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
+        masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+        valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
+        pooled = masked_sums / valid_counts
+        return pooled
+    elif pool_type == 'avgmax':
+        # For avgmax, compute masked average and masked max
+        masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+        valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
+        masked_avg = masked_sums / valid_counts
+
+        # For max pooling we set masked positions to large negative value
+        masked_x = x.clone()
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        masked_max = masked_x.amax(dim=1)
+
+        # Combine average and max
+        return 0.5 * (masked_avg + masked_max)
+    elif pool_type == 'max':
+        # For max pooling we set masked positions to large negative value
+        masked_x = x.clone()
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        return masked_x.amax(dim=1)
+    else:
+        assert False
+
+
 class VisionTransformerFlex(nn.Module):
     """ Vision Transformer (Na)Flex
 
@@ -817,38 +864,13 @@ class VisionTransformerFlex(nn.Module):
             return x
 
         pool_type = self.global_pool if pool_type is None else pool_type
-
-        # Handle padding mask for average pooling
-        if patch_valid is not None and pool_type in ('avg', 'avgmax'):
-            # For NaFlex mode, we need to apply masked pooling to exclude padding tokens
-            # Extract only the patch part of the mask (excluding prefix tokens)
-            if self.num_prefix_tokens > 0:
-                # Apply the mask to extract only valid tokens
-                x = x[:, self.num_prefix_tokens:]  # prefix tokens not included in pooling
-            patch_valid_float = patch_valid.to(x.dtype)
-            if pool_type == 'avg':
-                # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
-                masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
-                valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
-                pooled = masked_sums / valid_counts
-                return pooled
-            elif pool_type == 'avgmax':
-                # For avgmax, compute masked average and masked max
-                masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
-                valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
-                masked_avg = masked_sums / valid_counts
-
-                # For max pooling we set masked positions to large negative value
-                masked_x = x.clone()
-                masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
-                masked_max = masked_x.max(dim=1)[0]
-
-                # Combine average and max
-                return 0.5 * (masked_avg + masked_max)
-
-        # Fall back to standard pooling
-        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        x = global_pool_naflex(
+            x,
+            patch_valid,
+            pool_type=pool_type,
+            num_prefix_tokens=self.num_prefix_tokens,
+        )
         return x
 
     def forward_head(
@@ -897,14 +919,11 @@ class VisionTransformerFlex(nn.Module):
         patches = x
 
         # Create attention mask if patch_type is provided
-        if patch_valid is not None:
-            attn_mask = create_attention_mask(
-                patch_valid,
-                num_prefix_tokens=self.num_prefix_tokens,
-                dtype=patches.dtype
-            )
-        else:
-            attn_mask = None
+        attn_mask = create_attention_mask(
+            patch_valid,
+            num_prefix_tokens=self.num_prefix_tokens,
+            dtype=patches.dtype,
+        )
 
         # Forward features with mask
         x = self.forward_features(
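
Aside (not part of the diff): a minimal sketch of the masked average pooling that the new global_pool_naflex performs for pool_type='avg', written in plain PyTorch with made-up shapes and values. It mirrors the diff's sum-over-valid-tokens / clamped-count logic; the toy tensors and the final sanity check are illustrative only.

import torch

# Toy batch: 2 sequences of 5 patch tokens with 4 channels (prefix tokens already dropped)
B, N, C = 2, 5, 4
x = torch.randn(B, N, C)

# patch_valid marks real patches (True) vs. padding (False), as in the NaFlex mask
patch_valid = torch.tensor([
    [True, True, True, False, False],
    [True, True, True, True, True],
])

valid = patch_valid.to(x.dtype)                      # (B, N)
masked_sum = (x * valid.unsqueeze(-1)).sum(dim=1)    # padded tokens contribute zero
valid_count = valid.sum(dim=1, keepdim=True).clamp(min=1)
pooled = masked_sum / valid_count                    # (B, C) masked average

# Sanity check: matches averaging only the 3 valid tokens of the first sample
assert torch.allclose(pooled[0], x[0, :3].mean(dim=0), atol=1e-6)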