mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Move naflex global pool into one fn that can be marked notrace
This commit is contained in:
parent
2ad75e8023
commit
162f49295e
@ -424,6 +424,53 @@ def create_attention_mask(
|
||||
return mask_float
|
||||
|
||||
|
||||
@register_notrace_function
|
||||
def global_pool_naflex(
|
||||
x: torch.Tensor,
|
||||
patch_valid: Optional[torch.Tensor] = None,
|
||||
pool_type: str = 'token',
|
||||
num_prefix_tokens: int = 1,
|
||||
):
|
||||
if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
|
||||
# Fall back to standard pooling
|
||||
x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens)
|
||||
return x
|
||||
|
||||
# For NaFlex mode, we need to apply masked pooling to exclude padding tokens
|
||||
# Extract only the patch part of the mask (excluding prefix tokens)
|
||||
if num_prefix_tokens > 0:
|
||||
# Apply the mask to extract only valid tokens
|
||||
x = x[:, num_prefix_tokens:] # prefix tokens not included in pooling
|
||||
|
||||
patch_valid_float = patch_valid.to(x.dtype)
|
||||
if pool_type == 'avg':
|
||||
# Compute masked average pooling, sum valid tokens and divide by count of valid tokens
|
||||
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
|
||||
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
|
||||
pooled = masked_sums / valid_counts
|
||||
return pooled
|
||||
elif pool_type == 'avgmax':
|
||||
# For avgmax, compute masked average and masked max
|
||||
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
|
||||
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
|
||||
masked_avg = masked_sums / valid_counts
|
||||
|
||||
# For max pooling we set masked positions to large negative value
|
||||
masked_x = x.clone()
|
||||
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
|
||||
masked_max = masked_x.amax(dim=1)
|
||||
|
||||
# Combine average and max
|
||||
return 0.5 * (masked_avg + masked_max)
|
||||
elif pool_type == 'max':
|
||||
# For max pooling we set masked positions to large negative value
|
||||
masked_x = x.clone()
|
||||
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
|
||||
return masked_x.amax(dim=1)
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
class VisionTransformerFlex(nn.Module):
|
||||
""" Vision Transformer (Na)Flex
|
||||
|
||||
@ -817,38 +864,13 @@ class VisionTransformerFlex(nn.Module):
|
||||
return x
|
||||
|
||||
pool_type = self.global_pool if pool_type is None else pool_type
|
||||
|
||||
# Handle padding mask for average pooling
|
||||
if patch_valid is not None and pool_type in ('avg', 'avgmax'):
|
||||
# For NaFlex mode, we need to apply masked pooling to exclude padding tokens
|
||||
# Extract only the patch part of the mask (excluding prefix tokens)
|
||||
if self.num_prefix_tokens > 0:
|
||||
# Apply the mask to extract only valid tokens
|
||||
x = x[:, self.num_prefix_tokens:] # prefix tokens not included in pooling
|
||||
|
||||
patch_valid_float = patch_valid.to(x.dtype)
|
||||
if pool_type == 'avg':
|
||||
# Compute masked average pooling, sum valid tokens and divide by count of valid tokens
|
||||
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
|
||||
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
|
||||
pooled = masked_sums / valid_counts
|
||||
return pooled
|
||||
elif pool_type == 'avgmax':
|
||||
# For avgmax, compute masked average and masked max
|
||||
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
|
||||
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
|
||||
masked_avg = masked_sums / valid_counts
|
||||
|
||||
# For max pooling we set masked positions to large negative value
|
||||
masked_x = x.clone()
|
||||
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
|
||||
masked_max = masked_x.max(dim=1)[0]
|
||||
|
||||
# Combine average and max
|
||||
return 0.5 * (masked_avg + masked_max)
|
||||
|
||||
# Fall back to standard pooling
|
||||
x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
|
||||
x = global_pool_naflex(
|
||||
x,
|
||||
patch_valid,
|
||||
pool_type=pool_type,
|
||||
num_prefix_tokens=self.num_prefix_tokens,
|
||||
)
|
||||
return x
|
||||
|
||||
def forward_head(
|
||||
@ -897,14 +919,11 @@ class VisionTransformerFlex(nn.Module):
|
||||
patches = x
|
||||
|
||||
# Create attention mask if patch_type is provided
|
||||
if patch_valid is not None:
|
||||
attn_mask = create_attention_mask(
|
||||
patch_valid,
|
||||
num_prefix_tokens=self.num_prefix_tokens,
|
||||
dtype=patches.dtype
|
||||
)
|
||||
else:
|
||||
attn_mask = None
|
||||
attn_mask = create_attention_mask(
|
||||
patch_valid,
|
||||
num_prefix_tokens=self.num_prefix_tokens,
|
||||
dtype=patches.dtype,
|
||||
)
|
||||
|
||||
# Forward features with mask
|
||||
x = self.forward_features(
|
||||
|
Loading…
x
Reference in New Issue
Block a user