mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix tracing by removing float cast, should end up float anyways
This commit is contained in:
parent
fb58a73033
commit
c715c724e7
@ -137,7 +137,7 @@ class ChannelAttentionV2(nn.Module):
|
|||||||
q, k, v = qkv.unbind(0)
|
q, k, v = qkv.unbind(0)
|
||||||
|
|
||||||
if self.dynamic_scale:
|
if self.dynamic_scale:
|
||||||
q = q * float(N) ** -0.5
|
q = q * N ** -0.5
|
||||||
else:
|
else:
|
||||||
q = q * self.head_dim ** -0.5
|
q = q * self.head_dim ** -0.5
|
||||||
attn = q.transpose(-1, -2) @ k
|
attn = q.transpose(-1, -2) @ k
|
||||||
|
Loading…
x
Reference in New Issue
Block a user