diff --git a/timm/models/vision_transformer_flex.py b/timm/models/vision_transformer_flex.py
index 3398dc37..9b509dab 100644
--- a/timm/models/vision_transformer_flex.py
+++ b/timm/models/vision_transformer_flex.py
@@ -424,6 +424,66 @@ def create_attention_mask(
     return mask_float
 
 
+@register_notrace_function
+def global_pool_naflex(
+        x: torch.Tensor,
+        patch_valid: Optional[torch.Tensor] = None,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+) -> torch.Tensor:
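+    """Global pooling for NaFlex sequences, excluding padding tokens from the reduction.
+
+    Falls back to global_pool_nlc when patch_valid is None or pool_type is not a masked reduction.
+
+    Args:
+        x: Token tensor of shape (B, N, C), prefix tokens first.
+        patch_valid: Optional boolean mask of valid (non-padding) patch tokens.
+        pool_type: Pooling type ('token', 'avg', 'avgmax', 'max', ...).
+        num_prefix_tokens: Number of prefix (class / register) tokens excluded from pooling.
+
+    Returns:
+        Pooled token tensor.
+    """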
+    if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
+        # Fall back to standard pooling
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens)
+        return x
+
+    # NaFlex mode: apply masked pooling so padding tokens are excluded from the reduction
+    # patch_valid covers only the patch tokens (prefix tokens are not part of the mask)
+    if num_prefix_tokens > 0:
+        # Drop prefix (e.g. class) tokens, they are not included in pooling
+        x = x[:, num_prefix_tokens:]
+
+    patch_valid_float = patch_valid.to(x.dtype)
+    if pool_type == 'avg':
+        # Masked average pooling: sum valid tokens and divide by the count of valid tokens
+        masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+        valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
+        pooled = masked_sums / valid_counts
+        return pooled
+    elif pool_type == 'avgmax':
+        # For avgmax, compute masked average and masked max
+        masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+        valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
+        masked_avg = masked_sums / valid_counts
+
+        # For the max component, set masked positions to a large negative value
+        masked_x = x.clone()
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        masked_max = masked_x.amax(dim=1)
+
+        # Combine average and max
+        return 0.5 * (masked_avg + masked_max)
+    elif pool_type == 'max':
+        # For max pooling, set masked positions to a large negative value
+        masked_x = x.clone()
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+        return masked_x.amax(dim=1)
+    else:
+        assert False, f'Unsupported pool_type: {pool_type}'
+
+
 class VisionTransformerFlex(nn.Module):
     """ Vision Transformer (Na)Flex
 
@@ -817,38 +877,13 @@ class VisionTransformerFlex(nn.Module):
             return x
         
         pool_type = self.global_pool if pool_type is None else pool_type
-        
-        # Handle padding mask for average pooling
-        if patch_valid is not None and pool_type in ('avg', 'avgmax'):
-            # For NaFlex mode, we need to apply masked pooling to exclude padding tokens
-            # Extract only the patch part of the mask (excluding prefix tokens)
-            if self.num_prefix_tokens > 0:
-                # Apply the mask to extract only valid tokens
-                x = x[:, self.num_prefix_tokens:]  # prefix tokens not included in pooling
 
-            patch_valid_float = patch_valid.to(x.dtype)
-            if pool_type == 'avg':
-                # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
-                masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
-                valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
-                pooled = masked_sums / valid_counts
-                return pooled
-            elif pool_type == 'avgmax':
-                # For avgmax, compute masked average and masked max
-                masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
-                valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
-                masked_avg = masked_sums / valid_counts
-
-                # For max pooling we set masked positions to large negative value
-                masked_x = x.clone()
-                masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
-                masked_max = masked_x.max(dim=1)[0]
-
-                # Combine average and max
-                return 0.5 * (masked_avg + masked_max)
-
-        # Fall back to standard pooling
-        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        x = global_pool_naflex(
+            x,
+            patch_valid,
+            pool_type=pool_type,
+            num_prefix_tokens=self.num_prefix_tokens,
+        )
         return x
 
     def forward_head(
@@ -897,14 +932,11 @@ class VisionTransformerFlex(nn.Module):
             patches = x
 
-        # Create attention mask if patch_type is provided
+        # Create attention mask if patch_valid is provided
-        if patch_valid is not None:
-            attn_mask = create_attention_mask(
-                patch_valid,
-                num_prefix_tokens=self.num_prefix_tokens,
-                dtype=patches.dtype
-            )
-        else:
-            attn_mask = None
+        attn_mask = create_attention_mask(
+            patch_valid,
+            num_prefix_tokens=self.num_prefix_tokens,
+            dtype=patches.dtype,
+        )
 
         # Forward features with mask
         x = self.forward_features(