efficientvit (mit) msa attention q/k/v ops need to be in float32 to train w/o NaN

samvit_fix_and_rope^2
Ross Wightman 2023-08-20 11:49:36 -07:00
parent e6aeb91ac1
commit dc18cda2e7
1 changed file with 7 additions and 3 deletions

@@ -250,9 +250,13 @@ class LiteMSA(nn.Module):
         k = self.kernel_func(k)
 
         v = F.pad(v, (0, 1), mode="constant", value=1.)
-        kv = k.transpose(-1, -2) @ v
-        out = q @ kv
-        out = out[..., :-1] / (out[..., -1:] + self.eps)
+        dtype = v.dtype
+        q, k, v = q.float(), k.float(), v.float()
+        with torch.amp.autocast(device_type=v.device.type, enabled=False):
+            kv = k.transpose(-1, -2) @ v
+            out = q @ kv
+            out = out[..., :-1] / (out[..., -1:] + self.eps)
+        out = out.to(dtype)
 
         # final projection
         out = out.transpose(-1, -2).reshape(B, -1, H, W)
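
The pattern above can be exercised outside the model. Below is a minimal, self-contained sketch of the same technique, not the timm LiteMSA module itself: the function name lite_msa_core, the toy tensor shapes, the ReLU stand-in for self.kernel_func, and the eps default are illustrative assumptions.

import torch
import torch.nn.functional as F


def lite_msa_core(q, k, v, eps=1e-15):
    # kernel feature map; ReLU here stands in for self.kernel_func
    q = F.relu(q)
    k = F.relu(k)
    # append a ones column to v so the normalizer falls out of the same matmuls
    v = F.pad(v, (0, 1), mode="constant", value=1.)

    # force the accumulation into float32: under AMP these matmuls would run in
    # float16/bfloat16 and the division below can overflow to inf/NaN
    dtype = v.dtype
    q, k, v = q.float(), k.float(), v.float()
    with torch.amp.autocast(device_type=v.device.type, enabled=False):
        kv = k.transpose(-1, -2) @ v                  # (B, dim, dim + 1)
        out = q @ kv                                  # (B, N, dim + 1)
        out = out[..., :-1] / (out[..., -1:] + eps)   # normalize by the ones column
    return out.to(dtype)


# toy check: batch of 2, 16 tokens, head dim 8
q = torch.randn(2, 16, 8)
k = torch.randn(2, 16, 8)
v = torch.randn(2, 16, 8)
print(lite_msa_core(q, k, v).shape)  # torch.Size([2, 16, 8])

With plain float32 inputs as in this toy check, the .float() calls are no-ops; the cast and the disabled autocast region only matter when the surrounding forward runs under torch.autocast with float16/bfloat16 enabled.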