From 6675590264cba62b9d37437e4300180280de11d5 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 8 Apr 2025 09:35:34 -0700
Subject: [PATCH] Fix ParallelThingsBlock w/ attn_mask

---
 timm/models/vision_transformer.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 636b1902..c870e6c2 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -332,7 +332,17 @@ class ParallelThingsBlock(nn.Module):
         ])))
 
     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + torch.stack([attn(x, attn_mask=attn_mask) for attn in self.attns]).sum(dim=0)
+        if attn_mask is not None:
+            attn_out = []
+            for attn in self.attns:
+                x_attn = attn.norm(x)
+                x_attn = attn.attn(x_attn, attn_mask=attn_mask)
+                x_attn = attn.ls(x_attn)
+                x_attn = attn.drop_path(x_attn)
+                attn_out.append(x_attn)
+            x = x + torch.stack(attn_out).sum(dim=0)
+        else:
+            x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
         x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
         return x
 
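Note on the fix: in this block each entry of self.attns is an nn.Sequential of named
stages (norm, attn, ls, drop_path), and nn.Sequential.forward() accepts only a single
input, so the old attn(x, attn_mask=attn_mask) call raised a TypeError whenever a mask
was passed. The patch threads attn_mask to the inner attention stage by hand and keeps
the original one-liner as the fast path when no mask is given. Below is a minimal,
self-contained sketch of the failure and the workaround; MaskedAttn is a hypothetical
stand-in for timm's attention module, not the real implementation.

    import torch
    import torch.nn as nn
    from collections import OrderedDict

    class MaskedAttn(nn.Module):
        # Hypothetical stand-in attention layer that accepts an optional
        # attn_mask kwarg; real attention would apply the mask to the score
        # matrix, here we only show that the kwarg reaches this module.
        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x, attn_mask=None):
            return self.proj(x)

    dim = 8
    # One parallel attention branch, structured like the block in the patch.
    branch = nn.Sequential(OrderedDict([
        ('norm', nn.LayerNorm(dim)),
        ('attn', MaskedAttn(dim)),
        ('ls', nn.Identity()),
        ('drop_path', nn.Identity()),
    ]))

    x = torch.randn(2, 4, dim)
    mask = torch.zeros(4, 4)

    # Before the patch: nn.Sequential.forward() takes a single input, so
    # passing attn_mask through the container raises a TypeError.
    try:
        branch(x, attn_mask=mask)
    except TypeError as e:
        print('Sequential rejects the kwarg:', e)

    # After the patch: unpack the named submodules and feed the mask only to
    # the attention stage, replicating Sequential's chaining for the rest.
    out = branch.drop_path(branch.ls(branch.attn(branch.norm(x), attn_mask=mask)))
    print(out.shape)  # torch.Size([2, 4, 8])

Keeping the list-comprehension path for the attn_mask=None case means the behavior
(and any overhead) of the unmasked forward pass is unchanged by this patch.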