mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix ParallelThingsBlock w/ attn_mask
This commit is contained in:
parent
9b23d6dea2
commit
6675590264
@ -332,7 +332,17 @@ class ParallelThingsBlock(nn.Module):
|
|||||||
])))
|
])))
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the parallel attention branches, then the parallel FFN branches.

    Each group of parallel branches is summed and added residually to ``x``.

    Args:
        x: Input tensor (presumably (batch, seq, dim) — TODO confirm against caller).
        attn_mask: Optional attention mask forwarded to each attention branch's
            inner attention module. When ``None``, the fast composed-call path
            is used.

    Returns:
        Tensor of the same shape as ``x``.
    """
    if attn_mask is not None:
        # The composed branch call `attn(x)` does not accept/forward a mask,
        # so unroll the branch sub-modules explicitly: norm -> attn -> ls -> drop_path.
        attn_out = []
        for attn in self.attns:
            x_attn = attn.norm(x)
            x_attn = attn.attn(x_attn, attn_mask=attn_mask)
            x_attn = attn.ls(x_attn)
            x_attn = attn.drop_path(x_attn)
            attn_out.append(x_attn)
        x = x + torch.stack(attn_out).sum(dim=0)
    else:
        # No mask: each branch is callable end-to-end; sum the parallel outputs.
        x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
    # Parallel feed-forward branches, summed and added residually.
    x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
    return x
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user