Fix ParallelThingsBlock w/ attn_mask

Ross Wightman 2025-04-08 09:35:34 -07:00
parent 9b23d6dea2
commit 6675590264


@@ -332,7 +332,17 @@ class ParallelThingsBlock(nn.Module):
             ])))
 
     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + torch.stack([attn(x, attn_mask=attn_mask) for attn in self.attns]).sum(dim=0)
+        if attn_mask is not None:
+            attn_out = []
+            for attn in self.attns:
+                x_attn = attn.norm(x)
+                x_attn = attn.attn(x_attn, attn_mask=attn_mask)
+                x_attn = attn.ls(x_attn)
+                x_attn = attn.drop_path(x_attn)
+                attn_out.append(x_attn)
+            x = x + torch.stack(attn_out).sum(dim=0)
+        else:
+            x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
         x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
         return x
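
The change is needed because the `attn.norm` / `attn.attn` / `attn.ls` / `attn.drop_path` accesses in the new path imply that each entry in `self.attns` is an `nn.Sequential` of named sub-modules, and `nn.Sequential.forward` accepts only a single input, so the old `attn(x, attn_mask=attn_mask)` call raises a `TypeError` whenever a mask is passed. The fix unrolls the branch by hand in the masked case so that `attn_mask` is routed only to the attention layer, while the unmasked case keeps the original one-call Sequential path. Below is a minimal sketch of that difference; `TinyAttention` and the tensor shapes are hypothetical stand-ins, not part of the commit.

import torch
import torch.nn as nn
from collections import OrderedDict

# Hypothetical stand-in for the real attention module; only the attn_mask
# keyword pass-through matters for this illustration.
class TinyAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, attn_mask=None):
        # A real attention layer would apply attn_mask to its attention logits.
        return self.proj(x)

dim = 8
branch = nn.Sequential(OrderedDict([
    ('norm', nn.LayerNorm(dim)),
    ('attn', TinyAttention(dim)),
    ('ls', nn.Identity()),
    ('drop_path', nn.Identity()),
]))

x = torch.randn(2, 4, dim)
mask = torch.ones(2, 4, 4, dtype=torch.bool)

# Pre-fix behaviour: nn.Sequential.forward() takes a single input, so the
# extra keyword argument cannot be threaded through and raises a TypeError.
try:
    branch(x, attn_mask=mask)
except TypeError as err:
    print('Sequential rejects attn_mask:', err)

# Post-fix behaviour: call the named sub-modules explicitly so attn_mask
# reaches only the attention layer, mirroring the new masked path above.
out = branch.drop_path(branch.ls(branch.attn(branch.norm(x), attn_mask=mask)))
print(out.shape)  # torch.Size([2, 4, 8])

Keeping the `else` branch on the plain `attn(x)` Sequential call means behaviour without a mask is unchanged by this commit.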