Fix tracing by removing float cast, should end up float anyways

This commit is contained in:
Ross Wightman 2024-06-22 08:35:30 -07:00
parent fb58a73033
commit c715c724e7

View File

@ -137,7 +137,7 @@ class ChannelAttentionV2(nn.Module):
q, k, v = qkv.unbind(0)
if self.dynamic_scale:
q = q * float(N) ** -0.5
q = q * N ** -0.5
else:
q = q * self.head_dim ** -0.5
attn = q.transpose(-1, -2) @ k