mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Optimizations for pos embed resize, merge different mask helper fns
This commit is contained in:
parent
ea728f67fa
commit
c527c37969
@ -297,24 +297,24 @@ class FlexEmbeds(nn.Module):
|
||||
size_to_indices[k].append(bi)
|
||||
|
||||
# Handle each batch element separately with its own grid size
|
||||
pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2) # B,C,H,W
|
||||
for k, batch_indices in size_to_indices.items():
|
||||
h, w = k
|
||||
#h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
|
||||
# Interpolate only once for this (h, w)
|
||||
if (h == orig_h) and (w == orig_w):
|
||||
pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
|
||||
pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
|
||||
else:
|
||||
pos_embed_resized = F.interpolate(
|
||||
self.pos_embed.permute(0, 3, 1, 2), # B,C,H,W
|
||||
pos_embed_flat = F.interpolate(
|
||||
pos_embed_nchw,
|
||||
size=(h, w),
|
||||
mode=self.pos_embed_interp_mode,
|
||||
align_corners=False,
|
||||
antialias=True,
|
||||
)
|
||||
pos_embed_flat = pos_embed_resized.permute(0, 2, 3, 1).reshape(h * w, -1)
|
||||
).flatten(2).transpose(1, 2)
|
||||
|
||||
seq_len = min(x.shape[1], pos_embed_flat.shape[0])
|
||||
x[batch_indices, :seq_len].add_(pos_embed_flat[:seq_len])
|
||||
seq_len = min(x.shape[1], pos_embed_flat.shape[1])
|
||||
x[batch_indices, :seq_len].add_(pos_embed_flat[:, :seq_len])
|
||||
|
||||
def _apply_learned_pos_embed(
|
||||
self,
|
||||
@ -322,106 +322,84 @@ class FlexEmbeds(nn.Module):
|
||||
grid_size: List[int],
|
||||
):
|
||||
orig_h, orig_w = self.pos_embed.shape[1:3]
|
||||
if grid_size[0] != orig_h or grid_size[1] != orig_w:
|
||||
if grid_size[0] == orig_h or grid_size[1] == orig_w:
|
||||
# No resize needed, just flatten
|
||||
pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
|
||||
else:
|
||||
# Resize if needed - directly using F.interpolate
|
||||
pos_embed = F.interpolate(
|
||||
pos_embed_flat = F.interpolate(
|
||||
self.pos_embed.permute(0, 3, 1, 2), # B,C,H,W
|
||||
size=grid_size,
|
||||
mode=self.pos_embed_interp_mode,
|
||||
align_corners=False,
|
||||
antialias=True,
|
||||
)
|
||||
# Convert back and flatten
|
||||
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
||||
pos_embed = pos_embed.reshape(1, grid_size[0] * grid_size[1], -1)
|
||||
).flatten(2).transpose(1, 2)
|
||||
|
||||
else:
|
||||
# No resize needed, just flatten
|
||||
pos_embed = self.pos_embed.reshape(1, orig_h * orig_w, -1)
|
||||
|
||||
x.add_(pos_embed)
|
||||
x.add_(pos_embed_flat)
|
||||
|
||||
|
||||
@register_notrace_function
|
||||
def create_attention_mask(
|
||||
patch_valid: torch.Tensor,
|
||||
num_prefix_tokens: int = 0,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""Create attention mask from patch type information.
|
||||
|
||||
Used for NaFlex mode to handle variable token counts and padding tokens.
|
||||
|
||||
Args:
|
||||
patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
|
||||
num_prefix_tokens: Number of prefix tokens (class token, register tokens)
|
||||
dtype: Dtype of the attention mask
|
||||
|
||||
Returns:
|
||||
Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
|
||||
or None if patch_type is None
|
||||
"""
|
||||
patch_valid = patch_valid.to(torch.bool)
|
||||
B = patch_valid.shape[0]
|
||||
|
||||
if num_prefix_tokens > 0:
|
||||
prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
|
||||
patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
|
||||
|
||||
mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
|
||||
mask_float = torch.zeros_like(mask_bool, dtype=dtype)
|
||||
mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
|
||||
|
||||
return mask_float
|
||||
|
||||
|
||||
@register_notrace_function
|
||||
def create_attention_mask2(
|
||||
patch_valid: torch.Tensor,
|
||||
num_prefix_tokens: int = 0,
|
||||
symmetric: bool = True,
|
||||
q_len: Optional[int] = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Create expanded attention mask from patch validity info.
|
||||
) -> torch.Tensor:
|
||||
"""Creates an attention mask from patch validity information.
|
||||
|
||||
Supports two modes controlled by `symmetric`:
|
||||
1. `symmetric=True` (default): Creates a symmetric mask of shape
|
||||
[B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if
|
||||
both token i and token j are valid. Suitable for standard self-attention.
|
||||
2. `symmetric=False`: Creates a potentially non-square mask of shape
|
||||
[B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if
|
||||
the key/value token k is valid. Query token validity is not checked
|
||||
in the mask itself. Useful for cross-attention or specific self-attention
|
||||
implementations `q_len` can be specified.
|
||||
|
||||
Used for NaFlex mode to handle variable token counts and padding tokens.
|
||||
|
||||
Args:
|
||||
patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
|
||||
patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding.
|
||||
num_prefix_tokens: Number of prefix tokens (class token, register tokens)
|
||||
q_len: Length override for query sequence
|
||||
dtype: Dtype of the attention mask
|
||||
to prepend, which are always considered valid.
|
||||
symmetric: If True, create a symmetric mask.
|
||||
If False, create an expanded mask based only on key/value validity.
|
||||
q_len: Query sequence length override. Only used when `symmetric` is False.
|
||||
Defaults to the key/value sequence length (`kv_len`) if None.
|
||||
dtype: Dtype of the output attention mask (e.g., torch.float32).
|
||||
|
||||
Returns:
|
||||
Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
|
||||
or None if patch_type is None
|
||||
Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked).
|
||||
Shape is [B, 1, seq_len, seq_len] if symmetric=True,
|
||||
or [B, 1, q_len, kv_len] if symmetric=False.
|
||||
"""
|
||||
patch_valid = patch_valid.bool()
|
||||
B, kv_len = patch_valid.shape
|
||||
patch_valid = patch_valid.bool() # Ensure boolean type
|
||||
B, N = patch_valid.shape
|
||||
kv_len = N # Initial key/value length is the number of patches
|
||||
|
||||
# Prepend prefix tokens if any
|
||||
if num_prefix_tokens > 0:
|
||||
prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
|
||||
# Create prefix validity tensor on the same device/dtype base as patch_valid
|
||||
prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
|
||||
# Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N]
|
||||
patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
|
||||
kv_len = patch_valid.shape[1]
|
||||
kv_len += num_prefix_tokens # Update total key/value sequence length
|
||||
|
||||
q_len = q_len if q_len is not None else kv_len
|
||||
if symmetric:
|
||||
# Symmetric mask is True where BOTH query and key are valid
|
||||
mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)
|
||||
mask_bool = mask_bool.unsqueeze(1) # Add head dimension: [B, 1, seq_len, seq_len]
|
||||
else:
|
||||
# Expanded mask
|
||||
q_len = q_len or kv_len
|
||||
mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len)
|
||||
|
||||
mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len).to(dtype)
|
||||
# Create the float mask and apply masking using additive mask convention
|
||||
mask_float = torch.zeros_like(mask_bool, dtype=dtype)
|
||||
mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
|
||||
|
||||
return mask_float
|
||||
|
||||
|
||||
@register_notrace_function
|
||||
def create_pool_mask(
|
||||
patch_valid:torch.Tensor,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
patch_valid = patch_valid.bool()
|
||||
mask_bool = patch_valid[:, None, None, :]
|
||||
mask_float = torch.zeros_like(mask_bool, dtype=dtype)
|
||||
mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
|
||||
# Fill with negative infinity where mask_bool is False (masked positions)
|
||||
mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min)
|
||||
|
||||
return mask_float
|
||||
|
||||
@ -809,7 +787,12 @@ class VisionTransformerFlex(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
if self.attn_pool is not None:
|
||||
# For attention pooling, we need to pass the mask for NaFlex models
|
||||
attn_mask = create_pool_mask(patch_valid, dtype=x.dtype)
|
||||
attn_mask = create_attention_mask(
|
||||
patch_valid,
|
||||
symmetric=False,
|
||||
q_len=1,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
@ -839,7 +822,7 @@ class VisionTransformerFlex(nn.Module):
|
||||
|
||||
# For max pooling with mask
|
||||
masked_x = x.clone()
|
||||
masked_x[~patch_valid] = -1e4 # torch.finfo(masked_x.dtype).min
|
||||
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
|
||||
masked_max = masked_x.max(dim=1)[0]
|
||||
|
||||
# Combine average and max
|
||||
@ -876,9 +859,7 @@ class VisionTransformerFlex(nn.Module):
|
||||
Returns:
|
||||
Model output tensor
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
patches = x
|
||||
else:
|
||||
if isinstance(x, Dict):
|
||||
# Handle dictionary input from NaFlex collator
|
||||
patch_coord = x['patch_coord']
|
||||
patch_valid = x['patch_valid']
|
||||
@ -893,6 +874,8 @@ class VisionTransformerFlex(nn.Module):
|
||||
# patch = patch.reshape(3, h*16, w*16)
|
||||
# from torchvision.utils import save_image
|
||||
# save_image(patch, f'patch_{i}.jpg', normalize=True)
|
||||
else:
|
||||
patches = x
|
||||
|
||||
# Create attention mask if patch_type is provided
|
||||
if patch_valid is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user