Move naflex global pool into one fn that can be marked notrace

This commit is contained in:
Ross Wightman 2025-05-24 14:06:12 -07:00
parent 2ad75e8023
commit 162f49295e

View File

@ -424,6 +424,53 @@ def create_attention_mask(
return mask_float
@register_notrace_function
def global_pool_naflex(
x: torch.Tensor,
patch_valid: Optional[torch.Tensor] = None,
pool_type: str = 'token',
num_prefix_tokens: int = 1,
):
if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
# Fall back to standard pooling
x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens)
return x
# For NaFlex mode, we need to apply masked pooling to exclude padding tokens
# Extract only the patch part of the mask (excluding prefix tokens)
if num_prefix_tokens > 0:
# Apply the mask to extract only valid tokens
x = x[:, num_prefix_tokens:] # prefix tokens not included in pooling
patch_valid_float = patch_valid.to(x.dtype)
if pool_type == 'avg':
# Compute masked average pooling, sum valid tokens and divide by count of valid tokens
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
pooled = masked_sums / valid_counts
return pooled
elif pool_type == 'avgmax':
# For avgmax, compute masked average and masked max
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
masked_avg = masked_sums / valid_counts
# For max pooling we set masked positions to large negative value
masked_x = x.clone()
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
masked_max = masked_x.amax(dim=1)
# Combine average and max
return 0.5 * (masked_avg + masked_max)
elif pool_type == 'max':
# For max pooling we set masked positions to large negative value
masked_x = x.clone()
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
return masked_x.amax(dim=1)
else:
assert False
class VisionTransformerFlex(nn.Module):
""" Vision Transformer (Na)Flex
@ -817,38 +864,13 @@ class VisionTransformerFlex(nn.Module):
return x
pool_type = self.global_pool if pool_type is None else pool_type
# Handle padding mask for average pooling
if patch_valid is not None and pool_type in ('avg', 'avgmax'):
# For NaFlex mode, we need to apply masked pooling to exclude padding tokens
# Extract only the patch part of the mask (excluding prefix tokens)
if self.num_prefix_tokens > 0:
# Apply the mask to extract only valid tokens
x = x[:, self.num_prefix_tokens:] # prefix tokens not included in pooling
patch_valid_float = patch_valid.to(x.dtype)
if pool_type == 'avg':
# Compute masked average pooling, sum valid tokens and divide by count of valid tokens
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
pooled = masked_sums / valid_counts
return pooled
elif pool_type == 'avgmax':
# For avgmax, compute masked average and masked max
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
masked_avg = masked_sums / valid_counts
# For max pooling we set masked positions to large negative value
masked_x = x.clone()
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
masked_max = masked_x.max(dim=1)[0]
# Combine average and max
return 0.5 * (masked_avg + masked_max)
# Fall back to standard pooling
x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
x = global_pool_naflex(
x,
patch_valid,
pool_type=pool_type,
num_prefix_tokens=self.num_prefix_tokens,
)
return x
def forward_head(
@ -897,14 +919,11 @@ class VisionTransformerFlex(nn.Module):
patches = x
# Create attention mask if patch_type is provided
if patch_valid is not None:
attn_mask = create_attention_mask(
patch_valid,
num_prefix_tokens=self.num_prefix_tokens,
dtype=patches.dtype
)
else:
attn_mask = None
attn_mask = create_attention_mask(
patch_valid,
num_prefix_tokens=self.num_prefix_tokens,
dtype=patches.dtype,
)
# Forward features with mask
x = self.forward_features(