Fix issue with MAP attention mask when patch_valid is None

Ross Wightman 2025-05-24 12:37:39 -07:00
parent d7d3538335
commit 2ad75e8023


@@ -362,7 +362,7 @@ def create_attention_mask(
     symmetric: bool = True,
     q_len: Optional[int] = None,
     dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
+) -> Optional[torch.Tensor]:
     """Creates an attention mask from patch validity information.
 
     Supports two modes controlled by `symmetric`:
@@ -392,6 +392,9 @@ def create_attention_mask(
         Shape is [B, 1, seq_len, seq_len] if symmetric=True,
         or [B, 1, q_len, kv_len] if symmetric=False.
     """
+    if patch_valid is None:
+        return None
+
     patch_valid = patch_valid.bool()  # Ensure boolean type
     B, N = patch_valid.shape
     kv_len = N  # Initial key/value length is the number of patches
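For context, here is a minimal runnable sketch of the patched function. Only the early-return guard, the Optional return type, the shown parameters, and the first three body lines come from this diff; the mask construction that follows is a simplified illustration of the shapes described in the docstring, not timm's exact implementation, and the real function may take additional arguments not shown in the hunks.

from typing import Optional

import torch


def create_attention_mask(
        patch_valid: Optional[torch.Tensor],
        symmetric: bool = True,
        q_len: Optional[int] = None,
        dtype: torch.dtype = torch.float32,
) -> Optional[torch.Tensor]:
    """Create an additive attention mask from patch validity; None means no mask."""
    # The fix: with no patch_valid there is no padding to mask out, so return
    # None and let downstream attention (e.g. the MAP pooling head) run unmasked.
    if patch_valid is None:
        return None

    patch_valid = patch_valid.bool()  # Ensure boolean type
    B, N = patch_valid.shape
    kv_len = N  # Initial key/value length is the number of patches

    if symmetric:
        # [B, 1, N, N]: query i may attend key j only when both patches are valid.
        mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
    else:
        # [B, 1, q_len, kv_len]: every query row masks the same invalid keys.
        q = q_len if q_len is not None else N
        mask_bool = patch_valid.view(B, 1, 1, kv_len).expand(B, 1, q, kv_len)

    # Convert boolean mask to an additive float mask
    # (0 = keep, most-negative finite value = drop).
    mask = torch.zeros(mask_bool.shape, dtype=dtype, device=patch_valid.device)
    mask.masked_fill_(~mask_bool, torch.finfo(dtype).min)
    return mask


# The guard makes the no-padding path trivial:
assert create_attention_mask(None) is None
valid = torch.tensor([[True, True, False]])
print(create_attention_mask(valid, symmetric=False, q_len=1).shape)
# -> torch.Size([1, 1, 1, 3]), with a large negative value at the padded key

Returning None also composes cleanly with callers: F.scaled_dot_product_attention treats attn_mask=None as unmasked attention, so the MAP head can pass the result through without a branch of its own.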