efficientvit (mit) msa attention q/k/v ops need to be in float32 to train w/o NaN
parent e6aeb91ac1
commit dc18cda2e7
@@ -250,9 +250,13 @@ class LiteMSA(nn.Module):
         k = self.kernel_func(k)
         v = F.pad(v, (0, 1), mode="constant", value=1.)
 
-        kv = k.transpose(-1, -2) @ v
-        out = q @ kv
-        out = out[..., :-1] / (out[..., -1:] + self.eps)
+        dtype = v.dtype
+        q, k, v = q.float(), k.float(), v.float()
+        with torch.amp.autocast(device_type=v.device.type, enabled=False):
+            kv = k.transpose(-1, -2) @ v
+            out = q @ kv
+            out = out[..., :-1] / (out[..., -1:] + self.eps)
+        out = out.to(dtype)
 
         # final projection
         out = out.transpose(-1, -2).reshape(B, -1, H, W)
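For context, a minimal standalone sketch of the same pattern: the linear-attention matmuls and the normalizing division are run in float32 with autocast disabled, then the result is cast back to the incoming dtype. The function name, the ReLU kernel choice, and the tensor shapes below are illustrative assumptions for this sketch, not the actual LiteMSA module from timm.

# Sketch only (not part of the commit): force the linear-attention core into
# float32 under a disabled-autocast region to avoid NaNs with fp16/bf16 AMP.
import torch
import torch.nn.functional as F


def relu_linear_attention(q, k, v, eps=1e-15):
    # q, k, v: (B, heads, N, dim); may arrive as float16/bfloat16 under AMP.
    # ReLU stands in for the kernel function used by EfficientViT-style MSA.
    q = F.relu(q)
    k = F.relu(k)
    # Append a ones column to v so the normalizer is accumulated alongside the values.
    v = F.pad(v, (0, 1), mode="constant", value=1.)

    dtype = v.dtype
    q, k, v = q.float(), k.float(), v.float()
    with torch.amp.autocast(device_type=v.device.type, enabled=False):
        kv = k.transpose(-1, -2) @ v                  # (B, heads, dim, dim + 1)
        out = q @ kv                                  # (B, heads, N, dim + 1)
        out = out[..., :-1] / (out[..., -1:] + eps)   # divide by the accumulated normalizer
    return out.to(dtype)


if __name__ == "__main__":
    q = torch.randn(2, 4, 64, 8)
    k = torch.randn(2, 4, 64, 8)
    v = torch.randn(2, 4, 64, 8)
    print(relu_linear_attention(q, k, v).shape)  # torch.Size([2, 4, 64, 8])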