Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Move naflex global pool into one fn that can be marked notrace
This commit is contained in:
parent 2ad75e8023
commit 162f49295e
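The change pulls the masked NaFlex pooling out of VisionTransformerFlex.pool() into a standalone global_pool_naflex() decorated with @register_notrace_function, so that timm's FX feature extraction can record the data-dependent masking as a single leaf call instead of tracing through it. As a rough illustration of that idea, a sketch using plain torch.fx.wrap rather than timm's registration helper (masked_mean and Head are invented names, not code from this repo):

import torch
import torch.fx

def masked_mean(x, mask):
    # Python-level branching on `mask` would break symbolic tracing,
    # which is why such helpers get marked as "notrace" leaves.
    if mask is None:
        return x.mean(dim=1)
    m = mask.to(x.dtype).unsqueeze(-1)
    return (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)

torch.fx.wrap('masked_mean')  # leave this function untraced, similar in spirit to register_notrace_function

class Head(torch.nn.Module):
    def forward(self, x, mask):
        return masked_mean(x, mask)

gm = torch.fx.symbolic_trace(Head())
print(gm.graph)  # masked_mean shows up as a single opaque call_function node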
@@ -424,6 +424,53 @@ def create_attention_mask(
     return mask_float
 
 
+@register_notrace_function
+def global_pool_naflex(
+        x: torch.Tensor,
+        patch_valid: Optional[torch.Tensor] = None,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+):
+    if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
+        # Fall back to standard pooling
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens)
+        return x
+
+    # For NaFlex mode, we need to apply masked pooling to exclude padding tokens
+    # Extract only the patch part of the mask (excluding prefix tokens)
+    if num_prefix_tokens > 0:
+        # Apply the mask to extract only valid tokens
+        x = x[:, num_prefix_tokens:]  # prefix tokens not included in pooling
+
+    patch_valid_float = patch_valid.to(x.dtype)
+    if pool_type == 'avg':
+        # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
+        masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+        valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
+        pooled = masked_sums / valid_counts
+        return pooled
+    elif pool_type == 'avgmax':
+        # For avgmax, compute masked average and masked max
+        masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+        valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
+        masked_avg = masked_sums / valid_counts
+
+        # For max pooling we set masked positions to large negative value
+        masked_x = x.clone()
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        masked_max = masked_x.amax(dim=1)
+
+        # Combine average and max
+        return 0.5 * (masked_avg + masked_max)
+    elif pool_type == 'max':
+        # For max pooling we set masked positions to large negative value
+        masked_x = x.clone()
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        return masked_x.amax(dim=1)
+    else:
+        assert False
+
+
 class VisionTransformerFlex(nn.Module):
     """ Vision Transformer (Na)Flex
 
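As a quick sanity check on the masked-average branch added above, the following self-contained snippet re-implements the same math on made-up shapes (it does not import the new helper):

import torch

# Two sequences of 5 patch tokens with 4 channels; sample 0 carries 2 padding tokens.
x = torch.randn(2, 5, 4)
patch_valid = torch.tensor([
    [True, True, True, False, False],
    [True, True, True, True, True],
])

valid = patch_valid.to(x.dtype)                      # (B, N)
masked_sums = (x * valid.unsqueeze(-1)).sum(dim=1)   # padded tokens contribute zero
valid_counts = valid.sum(dim=1, keepdim=True).clamp(min=1)
pooled = masked_sums / valid_counts                  # mean over valid tokens only

# Sample 0 must match a plain mean over its 3 valid tokens.
assert torch.allclose(pooled[0], x[0, :3].mean(dim=0), atol=1e-6)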
@@ -817,38 +864,13 @@ class VisionTransformerFlex(nn.Module):
             return x
 
         pool_type = self.global_pool if pool_type is None else pool_type
 
-        # Handle padding mask for average pooling
-        if patch_valid is not None and pool_type in ('avg', 'avgmax'):
-            # For NaFlex mode, we need to apply masked pooling to exclude padding tokens
-            # Extract only the patch part of the mask (excluding prefix tokens)
-            if self.num_prefix_tokens > 0:
-                # Apply the mask to extract only valid tokens
-                x = x[:, self.num_prefix_tokens:]  # prefix tokens not included in pooling
-
-            patch_valid_float = patch_valid.to(x.dtype)
-            if pool_type == 'avg':
-                # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
-                masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
-                valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
-                pooled = masked_sums / valid_counts
-                return pooled
-            elif pool_type == 'avgmax':
-                # For avgmax, compute masked average and masked max
-                masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
-                valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
-                masked_avg = masked_sums / valid_counts
-
-                # For max pooling we set masked positions to large negative value
-                masked_x = x.clone()
-                masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
-                masked_max = masked_x.max(dim=1)[0]
-
-                # Combine average and max
-                return 0.5 * (masked_avg + masked_max)
-
-        # Fall back to standard pooling
-        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        x = global_pool_naflex(
+            x,
+            patch_valid,
+            pool_type=pool_type,
+            num_prefix_tokens=self.num_prefix_tokens,
+        )
         return x
 
     def forward_head(
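The max and avgmax branches that this call now routes through rely on filling padded positions with the dtype's most negative value before reducing; a tiny standalone illustration with invented values:

import torch

x = torch.tensor([[[1.0, -2.0],
                   [3.0,  0.5],
                   [9.0,  9.0]]])                 # (B=1, N=3, C=2); the last token is padding
patch_valid = torch.tensor([[True, True, False]])

masked_x = x.clone()
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min   # padding can never win the max
masked_max = masked_x.amax(dim=1)                          # tensor([[3.0, 0.5]]), the padded 9s are ignored

valid = patch_valid.to(x.dtype)
masked_avg = (x * valid.unsqueeze(-1)).sum(dim=1) / valid.sum(dim=1, keepdim=True).clamp(min=1)
avgmax = 0.5 * (masked_avg + masked_max)                   # the 'avgmax' combination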
@@ -897,14 +919,11 @@ class VisionTransformerFlex(nn.Module):
         patches = x
 
         # Create attention mask if patch_type is provided
-        if patch_valid is not None:
-            attn_mask = create_attention_mask(
-                patch_valid,
-                num_prefix_tokens=self.num_prefix_tokens,
-                dtype=patches.dtype
-            )
-        else:
-            attn_mask = None
+        attn_mask = create_attention_mask(
+            patch_valid,
+            num_prefix_tokens=self.num_prefix_tokens,
+            dtype=patches.dtype,
+        )
 
         # Forward features with mask
         x = self.forward_features(
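With the if/else removed at the call site, create_attention_mask is expected to handle patch_valid=None itself, presumably returning no mask. A hedged sketch of what building an additive attention mask from patch validity typically looks like; the real function in this file may differ in shape handling and naming:

import torch
from typing import Optional

def make_attn_mask(patch_valid: Optional[torch.Tensor], num_prefix_tokens: int, dtype: torch.dtype):
    # Hypothetical stand-in, not the repo's create_attention_mask.
    if patch_valid is None:
        return None
    B, N = patch_valid.shape
    if num_prefix_tokens:
        prefix = patch_valid.new_ones(B, num_prefix_tokens)           # prefix tokens are always valid
        patch_valid = torch.cat([prefix, patch_valid], dim=1)
    # A query/key pair is allowed only if both tokens are valid.
    pair = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)       # (B, L, L)
    mask_float = torch.zeros(pair.shape, dtype=dtype)
    mask_float = mask_float.masked_fill(~pair, torch.finfo(dtype).min)
    return mask_float.unsqueeze(1)                                    # (B, 1, L, L), broadcast over heads

mask = make_attn_mask(torch.tensor([[True, True, False]]), num_prefix_tokens=1, dtype=torch.float32)
print(mask.shape)   # torch.Size([1, 1, 4, 4])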