mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix issue w/ MAP attention mask and no patch_valid
This commit is contained in:
parent
d7d3538335
commit
2ad75e8023
@ -362,7 +362,7 @@ def create_attention_mask(
|
|||||||
symmetric: bool = True,
|
symmetric: bool = True,
|
||||||
q_len: Optional[int] = None,
|
q_len: Optional[int] = None,
|
||||||
dtype: torch.dtype = torch.float32,
|
dtype: torch.dtype = torch.float32,
|
||||||
) -> torch.Tensor:
|
) -> Optional[torch.Tensor]:
|
||||||
"""Creates an attention mask from patch validity information.
|
"""Creates an attention mask from patch validity information.
|
||||||
|
|
||||||
Supports two modes controlled by `symmetric`:
|
Supports two modes controlled by `symmetric`:
|
||||||
@ -392,6 +392,9 @@ def create_attention_mask(
|
|||||||
Shape is [B, 1, seq_len, seq_len] if symmetric=True,
|
Shape is [B, 1, seq_len, seq_len] if symmetric=True,
|
||||||
or [B, 1, q_len, kv_len] if symmetric=False.
|
or [B, 1, q_len, kv_len] if symmetric=False.
|
||||||
"""
|
"""
|
||||||
|
if patch_valid is None:
|
||||||
|
return None
|
||||||
|
|
||||||
patch_valid = patch_valid.bool() # Ensure boolean type
|
patch_valid = patch_valid.bool() # Ensure boolean type
|
||||||
B, N = patch_valid.shape
|
B, N = patch_valid.shape
|
||||||
kv_len = N # Initial key/value length is the number of patches
|
kv_len = N # Initial key/value length is the number of patches
|
||||||
|
Loading…
x
Reference in New Issue
Block a user