From 2ad75e802369c9db7a04184e488dacf544cc39b5 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 24 May 2025 12:37:39 -0700
Subject: [PATCH] Fix issue w/ MAP attention mask and no patch_valid

---
 timm/models/vision_transformer_flex.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/timm/models/vision_transformer_flex.py b/timm/models/vision_transformer_flex.py
index 94c014a8..3398dc37 100644
--- a/timm/models/vision_transformer_flex.py
+++ b/timm/models/vision_transformer_flex.py
@@ -362,7 +362,7 @@ def create_attention_mask(
         symmetric: bool = True,
         q_len: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
+) -> Optional[torch.Tensor]:
     """Creates an attention mask from patch validity information.

     Supports two modes controlled by `symmetric`:
@@ -392,6 +392,9 @@ def create_attention_mask(
         Shape is [B, 1, seq_len, seq_len] if symmetric=True,
         or [B, 1, q_len, kv_len] if symmetric=False.
     """
+    if patch_valid is None:
+        return None
+
     patch_valid = patch_valid.bool()  # Ensure boolean type
     B, N = patch_valid.shape
     kv_len = N  # Initial key/value length is the number of patches
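
Note: for context, the sketch below illustrates the guard this patch adds.
Before the change, calling create_attention_mask with patch_valid=None (as
happens when MAP attention pooling is used without a validity mask) crashed
on the `patch_valid.bool()` call; now the function returns None, which
downstream attention treats as "no mask". The helper name
create_attention_mask_sketch, the num_prefix_tokens handling, and the
additive-mask construction are assumptions inferred from the signature and
docstring visible in the hunks, not the actual timm implementation.

    # Minimal sketch of the patched behavior, not the timm implementation.
    from typing import Optional

    import torch


    def create_attention_mask_sketch(
            patch_valid: Optional[torch.Tensor],  # [B, N] per-patch validity, or None
            num_prefix_tokens: int = 0,           # hypothetical: prefix tokens (e.g. cls) always valid
            dtype: torch.dtype = torch.float32,
    ) -> Optional[torch.Tensor]:
        if patch_valid is None:
            # The fix: with no validity info there is nothing to mask, so
            # return None instead of crashing on `patch_valid.bool()` below.
            return None

        patch_valid = patch_valid.bool()  # Ensure boolean type
        B, N = patch_valid.shape
        if num_prefix_tokens:
            prefix_valid = patch_valid.new_ones(B, num_prefix_tokens)
            patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)

        # Symmetric additive mask: 0.0 where both query and key tokens are
        # valid, a large negative value elsewhere. Shape [B, 1, S, S].
        pair_valid = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(-2)
        mask = torch.zeros_like(pair_valid, dtype=dtype)
        mask.masked_fill_(~pair_valid, torch.finfo(dtype).min)
        return mask.unsqueeze(1)


    # Before the patch the first call raised AttributeError; now both work:
    assert create_attention_mask_sketch(None) is None
    m = create_attention_mask_sketch(torch.ones(2, 196, dtype=torch.bool))
    assert m.shape == (2, 1, 196, 196)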