mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix ParallelThingsBlock w/ attn_mask
This commit is contained in:
parent
9b23d6dea2
commit
6675590264
@ -332,7 +332,17 @@ class ParallelThingsBlock(nn.Module):
|
||||
])))
|
||||
|
||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply the parallel attention branches, then the parallel FFN branches.

    Each stage is residual: the outputs of all branches in ``self.attns``
    (resp. ``self.ffns``) are stacked, summed over the branch axis and
    added to the running activation.

    Args:
        x: Input tensor; it is passed unchanged to every attention branch.
        attn_mask: Optional mask. When given, it is passed as the
            ``attn_mask`` keyword to each branch's ``attn`` sub-module only.

    Returns:
        The tensor after the attention residual and the FFN residual.
    """
    if attn_mask is None:
        # No mask: each branch can be invoked as a single callable.
        attn_sum = torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
    else:
        # With a mask, step through the branch's stages by hand
        # (norm -> attn -> ls -> drop_path) so the keyword reaches the
        # attention stage only.
        # NOTE(review): presumably the branch container (e.g. nn.Sequential)
        # cannot forward keyword arguments - confirm against __init__.
        attn_out = []
        for attn in self.attns:
            x_attn = attn.norm(x)
            x_attn = attn.attn(x_attn, attn_mask=attn_mask)
            x_attn = attn.ls(x_attn)
            x_attn = attn.drop_path(x_attn)
            attn_out.append(x_attn)
        attn_sum = torch.stack(attn_out).sum(dim=0)
    x = x + attn_sum

    # FFN branches take no mask; sum them the same way.
    x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
    return x
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user