Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Optimizations for pos embed resize, merge different mask helper fns
This commit is contained in:
Parent: ea728f67fa
Commit: c527c37969
@@ -297,24 +297,24 @@ class FlexEmbeds(nn.Module):
             size_to_indices[k].append(bi)

         # Handle each batch element separately with its own grid size
+        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2)  # B,C,H,W
         for k, batch_indices in size_to_indices.items():
             h, w = k
             #h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
             # Interpolate only once for this (h, w)
             if (h == orig_h) and (w == orig_w):
-                pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
+                pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
-                pos_embed_resized = F.interpolate(
-                    self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
+                pos_embed_flat = F.interpolate(
+                    pos_embed_nchw,
                     size=(h, w),
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
-                )
-                pos_embed_flat = pos_embed_resized.permute(0, 2, 3, 1).reshape(h * w, -1)
+                ).flatten(2).transpose(1, 2)

-            seq_len = min(x.shape[1], pos_embed_flat.shape[0])
-            x[batch_indices, :seq_len].add_(pos_embed_flat[:seq_len])
+            seq_len = min(x.shape[1], pos_embed_flat.shape[1])
+            x[batch_indices, :seq_len].add_(pos_embed_flat[:, :seq_len])

     def _apply_learned_pos_embed(
             self,
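The reworked loop above computes the NCHW view of the learned table once (pos_embed_nchw) and reuses it for every unique (h, w), and the resized table now stays batched as [1, h*w, C] instead of being flattened to [h*w, C]; that is why the indexing changes from pos_embed_flat[:seq_len] to pos_embed_flat[:, :seq_len]. A minimal standalone shape check of that flow, with made-up sizes and 'bicubic' standing in for self.pos_embed_interp_mode:

# Standalone sketch of the resize path above (not the timm implementation), using
# a dummy 4x4 learned grid with C=8 and two arbitrary target grid sizes.
import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 4, 4, 8)              # stored as [1, H, W, C]
pos_embed_nchw = pos_embed.permute(0, 3, 1, 2)   # [1, C, H, W], computed once outside the loop

for (h, w) in [(4, 4), (6, 3)]:                  # unique target grid sizes in the batch
    if (h, w) == (4, 4):
        pos_embed_flat = pos_embed.reshape(1, 4 * 4, -1)   # no resize needed
    else:
        pos_embed_flat = F.interpolate(
            pos_embed_nchw,
            size=(h, w),
            mode='bicubic',                      # stand-in for self.pos_embed_interp_mode
            align_corners=False,
            antialias=True,
        ).flatten(2).transpose(1, 2)             # [1, h*w, C]
    print(h, w, pos_embed_flat.shape)            # [1, 16, 8] then [1, 18, 8]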
@@ -322,106 +322,84 @@ class FlexEmbeds(nn.Module):
             grid_size: List[int],
     ):
         orig_h, orig_w = self.pos_embed.shape[1:3]
-        if grid_size[0] != orig_h or grid_size[1] != orig_w:
+        if grid_size[0] == orig_h and grid_size[1] == orig_w:
+            # No resize needed, just flatten
+            pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
+        else:
             # Resize if needed - directly using F.interpolate
-            pos_embed = F.interpolate(
+            pos_embed_flat = F.interpolate(
                 self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
                 size=grid_size,
                 mode=self.pos_embed_interp_mode,
                 align_corners=False,
                 antialias=True,
-            )
-            # Convert back and flatten
-            pos_embed = pos_embed.permute(0, 2, 3, 1)
-            pos_embed = pos_embed.reshape(1, grid_size[0] * grid_size[1], -1)
-        else:
-            # No resize needed, just flatten
-            pos_embed = self.pos_embed.reshape(1, orig_h * orig_w, -1)
-
-        x.add_(pos_embed)
+            ).flatten(2).transpose(1, 2)
+
+        x.add_(pos_embed_flat)

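Both resize paths now replace the permute(0, 2, 3, 1) + reshape sequence with flatten(2).transpose(1, 2); for a [1, C, H, W] tensor the two produce identical values, which the short check below confirms (dummy sizes, not library code):

# Equivalence check for the flatten refactor, using arbitrary dummy sizes.
import torch

t = torch.randn(1, 8, 5, 7)                              # [1, C, H, W], e.g. output of F.interpolate
old_path = t.permute(0, 2, 3, 1).reshape(1, 5 * 7, -1)   # old: NCHW -> NHWC -> [1, H*W, C]
new_path = t.flatten(2).transpose(1, 2)                  # new: [1, C, H*W] -> [1, H*W, C]
assert torch.equal(old_path, new_path)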
 @register_notrace_function
 def create_attention_mask(
-        patch_valid: torch.Tensor,
-        num_prefix_tokens: int = 0,
-        dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
-    """Create attention mask from patch type information.
-
-    Used for NaFlex mode to handle variable token counts and padding tokens.
-
-    Args:
-        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
-        num_prefix_tokens: Number of prefix tokens (class token, register tokens)
-        dtype: Dtype of the attention mask
-
-    Returns:
-        Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
-        or None if patch_type is None
-    """
-    patch_valid = patch_valid.to(torch.bool)
-    B = patch_valid.shape[0]
-
-    if num_prefix_tokens > 0:
-        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
-        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
-
-    mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
-    mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
-
-    return mask_float
-
-
-@register_notrace_function
-def create_attention_mask2(
         patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
+        symmetric: bool = True,
         q_len: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
-) -> Optional[torch.Tensor]:
-    """Create expanded attention mask from patch validity info.
+) -> torch.Tensor:
+    """Creates an attention mask from patch validity information.
+
+    Supports two modes controlled by `symmetric`:
+    1. `symmetric=True` (default): Creates a symmetric mask of shape
+       [B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if
+       both token i and token j are valid. Suitable for standard self-attention.
+    2. `symmetric=False`: Creates a potentially non-square mask of shape
+       [B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if
+       the key/value token k is valid. Query token validity is not checked
+       in the mask itself. Useful for cross-attention or specific self-attention
+       implementations where `q_len` can be specified.

     Used for NaFlex mode to handle variable token counts and padding tokens.

     Args:
-        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
+        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding.
         num_prefix_tokens: Number of prefix tokens (class token, register tokens)
-        q_len: Length override for query sequence
-        dtype: Dtype of the attention mask
+            to prepend, which are always considered valid.
+        symmetric: If True, create a symmetric mask.
+            If False, create an expanded mask based only on key/value validity.
+        q_len: Query sequence length override. Only used when `symmetric` is False.
+            Defaults to the key/value sequence length (`kv_len`) if None.
+        dtype: Dtype of the output attention mask (e.g., torch.float32).

     Returns:
-        Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
-        or None if patch_type is None
+        Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked).
+        Shape is [B, 1, seq_len, seq_len] if symmetric=True,
+        or [B, 1, q_len, kv_len] if symmetric=False.
     """
-    patch_valid = patch_valid.bool()
-    B, kv_len = patch_valid.shape
+    patch_valid = patch_valid.bool()  # Ensure boolean type
+    B, N = patch_valid.shape
+    kv_len = N  # Initial key/value length is the number of patches

+    # Prepend prefix tokens if any
     if num_prefix_tokens > 0:
-        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
+        # Create prefix validity tensor on the same device/dtype base as patch_valid
+        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
+        # Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N]
         patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
-        kv_len = patch_valid.shape[1]
+        kv_len += num_prefix_tokens  # Update total key/value sequence length

-    q_len = q_len if q_len is not None else kv_len
-
-    mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len).to(dtype)
+    if symmetric:
+        # Symmetric mask is True where BOTH query and key are valid
+        mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)
+        mask_bool = mask_bool.unsqueeze(1)  # Add head dimension: [B, 1, seq_len, seq_len]
+    else:
+        # Expanded mask
+        q_len = q_len or kv_len
+        mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len)
+
+    # Create the float mask and apply masking using additive mask convention
     mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
-    return mask_float
-
-
-@register_notrace_function
-def create_pool_mask(
-        patch_valid: torch.Tensor,
-        dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
-    patch_valid = patch_valid.bool()
-    mask_bool = patch_valid[:, None, None, :]
-    mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
-
+    # Fill with negative infinity where mask_bool is False (masked positions)
+    mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min)
     return mask_float

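The merged helper replaces the old create_attention_mask / create_attention_mask2 / create_pool_mask trio with a single function switched by `symmetric`. A quick shape check of both modes follows; the function body is condensed from the new code above so the snippet runs standalone, and the batch contents are toy values:

# Shape/usage check for the merged mask helper (condensed from the diff above, toy sizes only).
from typing import Optional
import torch

def create_attention_mask(patch_valid, num_prefix_tokens=0, symmetric=True,
                          q_len: Optional[int] = None, dtype=torch.float32):
    patch_valid = patch_valid.bool()
    B, N = patch_valid.shape
    kv_len = N + num_prefix_tokens
    if num_prefix_tokens > 0:
        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
    if symmetric:
        # pair (i, j) allowed only if both tokens are valid
        mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
    else:
        # only key/value validity is checked, query length can be overridden
        q_len = q_len or kv_len
        mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len)
    mask_float = torch.zeros_like(mask_bool, dtype=dtype)
    return mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min)

patch_valid = torch.tensor([[True, True, False],    # sample 0: one padded patch
                            [True, True, True]])    # sample 1: all valid
print(create_attention_mask(patch_valid, num_prefix_tokens=1).shape)       # [2, 1, 4, 4]
print(create_attention_mask(patch_valid, symmetric=False, q_len=1).shape)  # [2, 1, 1, 3]
print(create_attention_mask(patch_valid, num_prefix_tokens=1)[0, 0])       # last row/col masked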
@@ -809,7 +787,12 @@ class VisionTransformerFlex(nn.Module):
     ) -> torch.Tensor:
         if self.attn_pool is not None:
             # For attention pooling, we need to pass the mask for NaFlex models
-            attn_mask = create_pool_mask(patch_valid, dtype=x.dtype)
+            attn_mask = create_attention_mask(
+                patch_valid,
+                symmetric=False,
+                q_len=1,
+                dtype=x.dtype,
+            )
             x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
             return x

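For attention pooling the query side is a single learned latent, so only key/value validity matters and a [B, 1, 1, kv_len] additive mask is enough; it broadcasts across heads and the one query position. A small sketch of that broadcast using PyTorch's scaled_dot_product_attention (head count and dims are illustrative, and this is not timm's pooling module):

# Minimal sketch: a [B, 1, 1, kv_len] additive mask broadcasting in SDPA.
import torch
import torch.nn.functional as F

B, heads, kv_len, head_dim = 2, 4, 9, 16
q = torch.randn(B, heads, 1, head_dim)             # one pooling query per sample
k = torch.randn(B, heads, kv_len, head_dim)
v = torch.randn(B, heads, kv_len, head_dim)

patch_valid = torch.ones(B, kv_len, dtype=torch.bool)
patch_valid[0, -3:] = False                        # pretend the last 3 tokens are padding

attn_mask = torch.zeros(B, 1, 1, kv_len)
attn_mask.masked_fill_(~patch_valid[:, None, None, :], torch.finfo(attn_mask.dtype).min)

pooled = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(pooled.shape)                                # [B, heads, 1, head_dim]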
@@ -839,7 +822,7 @@ class VisionTransformerFlex(nn.Module):

             # For max pooling with mask
             masked_x = x.clone()
-            masked_x[~patch_valid] = -1e4  # torch.finfo(masked_x.dtype).min
+            masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
             masked_max = masked_x.max(dim=1)[0]

             # Combine average and max
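The change above swaps the hard-coded -1e4 fill for the dtype's true minimum so a padded token can never win the max, including under reduced-precision dtypes. A self-contained sketch of masked average plus max pooling in that style (the averaging step and the 0.5 weighting are illustrative assumptions, only the masked-max fill mirrors the diff):

# Sketch of masked avg+max pooling over valid patch tokens (illustrative, toy sizes).
import torch

x = torch.randn(2, 5, 8)                           # [B, N, C] token features
patch_valid = torch.tensor([[True, True, True, False, False],
                            [True, True, True, True, True]])

# Masked mean: zero out padding, divide by the per-sample valid count
valid = patch_valid.unsqueeze(-1).to(x.dtype)      # [B, N, 1]
masked_avg = (x * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)

# Masked max: fill padding with the dtype minimum so it never wins the max
masked_x = x.clone()
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
masked_max = masked_x.max(dim=1)[0]

pooled = 0.5 * (masked_avg + masked_max)           # combine average and max
print(pooled.shape)                                # [2, 8]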
@@ -876,9 +859,7 @@ class VisionTransformerFlex(nn.Module):
         Returns:
             Model output tensor
         """
-        if isinstance(x, torch.Tensor):
-            patches = x
-        else:
+        if isinstance(x, Dict):
             # Handle dictionary input from NaFlex collator
             patch_coord = x['patch_coord']
             patch_valid = x['patch_valid']
@@ -893,6 +874,8 @@ class VisionTransformerFlex(nn.Module):
             # patch = patch.reshape(3, h*16, w*16)
             # from torchvision.utils import save_image
             # save_image(patch, f'patch_{i}.jpg', normalize=True)
+        else:
+            patches = x

         # Create attention mask if patch_type is provided
         if patch_valid is not None:
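The forward pass now treats a dict as the NaFlex-collator case and falls through to `patches = x` for plain tensors. A hedged sketch of the kind of batch dict that branch expects; the 'patches' key name and all sizes are assumptions for illustration, only 'patch_coord' and 'patch_valid' appear in the diff:

# Hypothetical NaFlex-style batch dict (toy sizes; 'patches' key is an assumption).
import torch

B, max_tokens, patch_dim = 2, 9, 3 * 16 * 16       # patch_dim = C * patch_h * patch_w
batch = {
    'patches': torch.randn(B, max_tokens, patch_dim),                 # flattened patch pixels, padded
    'patch_coord': torch.zeros(B, max_tokens, 2, dtype=torch.long),   # (y, x) grid coords per patch
    'patch_valid': torch.ones(B, max_tokens, dtype=torch.bool),       # False where padded
}
batch['patch_valid'][0, -4:] = False               # first sample uses fewer patches

# Dispatch mirrors the diff: dict input carries coords/validity, plain tensors pass through
if isinstance(batch, dict):
    patch_coord = batch['patch_coord']
    patch_valid = batch['patch_valid']
print(patch_valid.sum(dim=1))                      # valid token counts per sample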