mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix bottleneck attn transpose typo, hopefully these train better now..
This commit is contained in:
parent
80075b0b8a
commit
b81e79aae9
@ -122,7 +122,7 @@ class BottleneckAttn(nn.Module):
|
|||||||
attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W
|
attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W
|
||||||
|
|
||||||
attn_out = attn_logits.softmax(dim=-1)
|
attn_out = attn_logits.softmax(dim=-1)
|
||||||
attn_out = (attn_out @ v).transpose(1, 2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W
|
attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W
|
||||||
attn_out = self.pool(attn_out)
|
attn_out = self.pool(attn_out)
|
||||||
return attn_out
|
return attn_out
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user