diff --git a/timm/models/davit.py b/timm/models/davit.py index d4d6ad69..442ca620 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -126,9 +126,9 @@ class ChannelAttention(nn.Module): q, k, v = qkv.unbind(0) k = k * self.scale - attention = k.transpose(-1, -2) @ v - attention = attention.softmax(dim=-1) - x = (attention @ q.transpose(-1, -2)).transpose(-1, -2) + attn = k.transpose(-1, -2) @ v + attn = attn.softmax(dim=-1) + x = (attn @ q.transpose(-1, -2)).transpose(-1, -2) x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) return x