Optimizations for pos embed resize, merge different mask helper fns

Ross Wightman 2025-04-21 14:05:18 -07:00
parent ea728f67fa
commit c527c37969


@@ -297,24 +297,24 @@ class FlexEmbeds(nn.Module):
size_to_indices[k].append(bi)
# Handle each batch element separately with its own grid size
pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2) # B,C,H,W
for k, batch_indices in size_to_indices.items():
h, w = k
#h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
# Interpolate only once for this (h, w)
if (h == orig_h) and (w == orig_w):
pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
else:
pos_embed_resized = F.interpolate(
self.pos_embed.permute(0, 3, 1, 2), # B,C,H,W
pos_embed_flat = F.interpolate(
pos_embed_nchw,
size=(h, w),
mode=self.pos_embed_interp_mode,
align_corners=False,
antialias=True,
)
pos_embed_flat = pos_embed_resized.permute(0, 2, 3, 1).reshape(h * w, -1)
).flatten(2).transpose(1, 2)
seq_len = min(x.shape[1], pos_embed_flat.shape[0])
x[batch_indices, :seq_len].add_(pos_embed_flat[:seq_len])
seq_len = min(x.shape[1], pos_embed_flat.shape[1])
x[batch_indices, :seq_len] += pos_embed_flat[:, :seq_len]
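The loop above now hoists the NCHW permute out of the loop and interpolates the learned position embedding once per unique (h, w) grid size in the batch rather than once per sample. A self-contained sketch of that grouping optimization follows; the function name, the bicubic mode, and the shapes are illustrative assumptions, not taken from this diff, and the same-size fast path is omitted for brevity.

```python
# Sketch of the "interpolate once per unique grid size" optimization above.
# Function name, interp mode, and shapes are illustrative, not from the diff.
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F


def add_pos_embed_grouped(
        x: torch.Tensor,                           # [B, N, C] padded token sequences
        pos_embed: torch.Tensor,                   # [1, H, W, C] learned embedding grid
        grid_sizes: List[Tuple[int, int]],         # per-sample (h, w) patch grid sizes
) -> torch.Tensor:
    # Group batch indices by target grid size so each size is resized only once
    size_to_indices: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for bi, hw in enumerate(grid_sizes):
        size_to_indices[hw].append(bi)

    pos_embed_nchw = pos_embed.permute(0, 3, 1, 2)  # [1, C, H, W], computed once
    for (h, w), indices in size_to_indices.items():
        pos_embed_flat = F.interpolate(
            pos_embed_nchw,
            size=(h, w),
            mode='bicubic',
            align_corners=False,
            antialias=True,
        ).flatten(2).transpose(1, 2)                # [1, h*w, C]
        seq_len = min(x.shape[1], pos_embed_flat.shape[1])
        x[indices, :seq_len] += pos_embed_flat[:, :seq_len]
    return x
```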
def _apply_learned_pos_embed(
self,
@@ -322,106 +322,84 @@ class FlexEmbeds(nn.Module):
grid_size: List[int],
):
orig_h, orig_w = self.pos_embed.shape[1:3]
if grid_size[0] != orig_h or grid_size[1] != orig_w:
if grid_size[0] == orig_h and grid_size[1] == orig_w:
# No resize needed, just flatten
pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
else:
# Resize if needed - directly using F.interpolate
pos_embed = F.interpolate(
pos_embed_flat = F.interpolate(
self.pos_embed.permute(0, 3, 1, 2), # B,C,H,W
size=grid_size,
mode=self.pos_embed_interp_mode,
align_corners=False,
antialias=True,
)
# Convert back and flatten
pos_embed = pos_embed.permute(0, 2, 3, 1)
pos_embed = pos_embed.reshape(1, grid_size[0] * grid_size[1], -1)
).flatten(2).transpose(1, 2)
else:
# No resize needed, just flatten
pos_embed = self.pos_embed.reshape(1, orig_h * orig_w, -1)
x.add_(pos_embed)
x.add_(pos_embed_flat)
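The rewrite also collapses the permute + reshape pair into `.flatten(2).transpose(1, 2)` applied directly to the interpolated NCHW tensor. A quick standalone check that the two layouts agree (sizes below are arbitrary example values):

```python
# The two flattening paths produce identical [1, h*w, C] tensors.
import torch

nchw = torch.randn(1, 384, 12, 16)                            # [1, C, h, w], example sizes
old_path = nchw.permute(0, 2, 3, 1).reshape(1, 12 * 16, -1)   # permute to NHWC, then reshape
new_path = nchw.flatten(2).transpose(1, 2)                    # flatten spatial dims, channels last
assert torch.equal(old_path, new_path)
```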
@register_notrace_function
def create_attention_mask(
patch_valid: torch.Tensor,
num_prefix_tokens: int = 0,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Create attention mask from patch type information.
Used for NaFlex mode to handle variable token counts and padding tokens.
Args:
patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
num_prefix_tokens: Number of prefix tokens (class token, register tokens)
dtype: Dtype of the attention mask
Returns:
Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
or None if patch_type is None
"""
patch_valid = patch_valid.to(torch.bool)
B = patch_valid.shape[0]
if num_prefix_tokens > 0:
prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
mask_float = torch.zeros_like(mask_bool, dtype=dtype)
mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
return mask_float
@register_notrace_function
def create_attention_mask2(
patch_valid: torch.Tensor,
num_prefix_tokens: int = 0,
symmetric: bool = True,
q_len: Optional[int] = None,
dtype: torch.dtype = torch.float32,
) -> Optional[torch.Tensor]:
"""Create expanded attention mask from patch validity info.
) -> torch.Tensor:
"""Creates an attention mask from patch validity information.
Supports two modes controlled by `symmetric`:
1. `symmetric=True` (default): Creates a symmetric mask of shape
[B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if
both token i and token j are valid. Suitable for standard self-attention.
2. `symmetric=False`: Creates a potentially non-square mask of shape
[B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if
the key/value token k is valid. Query token validity is not checked
in the mask itself. Useful for cross-attention or specific self-attention
implementations where `q_len` can be specified.
Used for NaFlex mode to handle variable token counts and padding tokens.
Args:
patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding.
num_prefix_tokens: Number of prefix tokens (class token, register tokens)
q_len: Length override for query sequence
dtype: Dtype of the attention mask
to prepend, which are always considered valid.
symmetric: If True, create a symmetric mask.
If False, create an expanded mask based only on key/value validity.
q_len: Query sequence length override. Only used when `symmetric` is False.
Defaults to the key/value sequence length (`kv_len`) if None.
dtype: Dtype of the output attention mask (e.g., torch.float32).
Returns:
Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
or None if patch_type is None
Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked).
Shape is [B, 1, seq_len, seq_len] if symmetric=True,
or [B, 1, q_len, kv_len] if symmetric=False.
"""
patch_valid = patch_valid.bool()
B, kv_len = patch_valid.shape
patch_valid = patch_valid.bool() # Ensure boolean type
B, N = patch_valid.shape
kv_len = N # Initial key/value length is the number of patches
# Prepend prefix tokens if any
if num_prefix_tokens > 0:
prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
# Create prefix validity tensor on the same device/dtype base as patch_valid
prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
# Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N]
patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
kv_len = patch_valid.shape[1]
kv_len += num_prefix_tokens # Update total key/value sequence length
q_len = q_len if q_len is not None else kv_len
if symmetric:
# Symmetric mask is True where BOTH query and key are valid
mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)
mask_bool = mask_bool.unsqueeze(1) # Add head dimension: [B, 1, seq_len, seq_len]
else:
# Expanded mask
q_len = q_len or kv_len
mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len)
# Create the float mask and apply masking using additive mask convention
mask_float = torch.zeros_like(mask_bool, dtype=dtype)
mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
return mask_float
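With the two earlier helpers folded into the single `create_attention_mask` above, both mask flavors come from one call site. A brief usage sketch, assuming the merged helper as defined above is in scope:

```python
# Usage sketch for the merged helper above (assumes it is importable / in scope).
import torch

patch_valid = torch.tensor([[True, True, False],
                            [True, False, False]])   # [B=2, N=3], False = padding

# Symmetric self-attention mask with one prefix token prepended: [B, 1, N+1, N+1]
sym = create_attention_mask(patch_valid, num_prefix_tokens=1, symmetric=True)
print(sym.shape)      # torch.Size([2, 1, 4, 4]); 0 where q and k are both valid, else dtype min

# Non-symmetric mask keyed only on k/v validity, e.g. a single pooling query: [B, 1, 1, N]
pool = create_attention_mask(patch_valid, symmetric=False, q_len=1)
print(pool.shape)     # torch.Size([2, 1, 1, 3])
```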
@register_notrace_function
def create_pool_mask(
patch_valid:torch.Tensor,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
patch_valid = patch_valid.bool()
mask_bool = patch_valid[:, None, None, :]
mask_float = torch.zeros_like(mask_bool, dtype=dtype)
mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
# Fill with negative infinity where mask_bool is False (masked positions)
mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min)
return mask_float
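`create_pool_mask` produces the same additive mask as the merged helper called with `symmetric=False, q_len=1`, which is exactly the substitution made in the attention-pool hunk below. A small equivalence check, assuming both functions above are in scope:

```python
# create_pool_mask(patch_valid) matches create_attention_mask(..., symmetric=False, q_len=1).
import torch

patch_valid = torch.tensor([[True, True, False, False],
                            [True, True, True, False]])   # [B=2, N=4]

old_mask = create_pool_mask(patch_valid, dtype=torch.float32)                 # [B, 1, 1, N]
new_mask = create_attention_mask(
    patch_valid, symmetric=False, q_len=1, dtype=torch.float32)               # [B, 1, 1, N]
assert torch.equal(old_mask, new_mask)
```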
@@ -809,7 +787,12 @@ class VisionTransformerFlex(nn.Module):
) -> torch.Tensor:
if self.attn_pool is not None:
# For attention pooling, we need to pass the mask for NaFlex models
attn_mask = create_pool_mask(patch_valid, dtype=x.dtype)
attn_mask = create_attention_mask(
patch_valid,
symmetric=False,
q_len=1,
dtype=x.dtype,
)
x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
return x
@@ -839,7 +822,7 @@ class VisionTransformerFlex(nn.Module):
# For max pooling with mask
masked_x = x.clone()
masked_x[~patch_valid] = -1e4 # torch.finfo(masked_x.dtype).min
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
masked_max = masked_x.max(dim=1)[0]
# Combine average and max
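In the avg+max pooling path, padding tokens are now filled with the dtype minimum (rather than a hard-coded -1e4) so they can never win the max reduction. A standalone sketch of masked average/max pooling; the function name and the 0.5/0.5 combination are illustrative assumptions:

```python
# Illustrative masked avg+max pooling over valid tokens only (names are not from the diff).
import torch


def masked_avgmax_pool(x: torch.Tensor, patch_valid: torch.Tensor) -> torch.Tensor:
    # x: [B, N, C] token features; patch_valid: [B, N] bool, False for padding
    valid = patch_valid.unsqueeze(-1).to(x.dtype)                 # [B, N, 1]
    masked_avg = (x * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)

    masked_x = x.clone()
    masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min      # padding can never be the max
    masked_max = masked_x.max(dim=1)[0]

    # Combine average and max
    return 0.5 * (masked_avg + masked_max)
```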
@@ -876,9 +859,7 @@ class VisionTransformerFlex(nn.Module):
Returns:
Model output tensor
"""
if isinstance(x, torch.Tensor):
patches = x
else:
if isinstance(x, Dict):
# Handle dictionary input from NaFlex collator
patch_coord = x['patch_coord']
patch_valid = x['patch_valid']
@@ -893,6 +874,8 @@ class VisionTransformerFlex(nn.Module):
# patch = patch.reshape(3, h*16, w*16)
# from torchvision.utils import save_image
# save_image(patch, f'patch_{i}.jpg', normalize=True)
else:
patches = x
# Create attention mask if patch_type is provided
if patch_valid is not None:
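The forward path accepts either a plain image tensor or the dict produced by the NaFlex collator. Only `patch_coord` and `patch_valid` are visible in this hunk, so the key holding the packed patch pixels (`patches` below) and the example shapes are assumptions:

```python
# Sketch of the two input forms handled by the isinstance branch above.
# Only 'patch_coord' and 'patch_valid' appear in the diff; 'patches' and all shapes are assumed.
import torch

B, N, P, C_in = 2, 64, 16, 3   # batch, max patches per sample, patch size, channels (examples)

naflex_input = {
    'patches': torch.randn(B, N, P * P * C_in),              # packed, padded patch pixels
    'patch_coord': torch.zeros(B, N, 2, dtype=torch.long),   # per-patch grid coordinates
    'patch_valid': torch.ones(B, N, dtype=torch.bool),       # False marks padding patches
}

dense_input = torch.randn(B, C_in, 224, 224)                 # plain image tensor path
```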