From dc18cda2e74962d28df4c16c667745df03ffa0b5 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sun, 20 Aug 2023 11:49:36 -0700
Subject: [PATCH] efficientvit (mit) msa attention q/k/v ops need to be in
 float32 to train w/o NaN

---
 timm/models/efficientvit_mit.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py
index 6d123cd4..e5d4a96b 100644
--- a/timm/models/efficientvit_mit.py
+++ b/timm/models/efficientvit_mit.py
@@ -250,9 +250,13 @@ class LiteMSA(nn.Module):
         k = self.kernel_func(k)
         v = F.pad(v, (0, 1), mode="constant", value=1.)
 
-        kv = k.transpose(-1, -2) @ v
-        out = q @ kv
-        out = out[..., :-1] / (out[..., -1:] + self.eps)
+        dtype = v.dtype
+        q, k, v = q.float(), k.float(), v.float()
+        with torch.amp.autocast(device_type=v.device.type, enabled=False):
+            kv = k.transpose(-1, -2) @ v
+            out = q @ kv
+            out = out[..., :-1] / (out[..., -1:] + self.eps)
+        out = out.to(dtype)
 
         # final projection
         out = out.transpose(-1, -2).reshape(B, -1, H, W)
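
Note: for context, below is a minimal, self-contained sketch of the pattern this patch applies to the
LiteMSA aggregation step: run the q/k/v matmuls and the normalization in float32 with autocast disabled,
then cast the result back to the working dtype so the values accumulated in the padded ones-column do not
overflow under float16/bfloat16 mixed-precision training. The function name lite_msa_aggregate and the
eps default are illustrative, not part of the patch; tensor shapes follow the LiteMSA forward pass.

    import torch
    import torch.nn.functional as F

    def lite_msa_aggregate(q, k, v, eps=1e-5):
        # q, k, v: (..., N, dim) tensors, k already passed through the kernel function (e.g. ReLU).
        dtype = v.dtype
        # Append a ones column to v; its accumulated value becomes the normalizer.
        v = F.pad(v, (0, 1), mode="constant", value=1.)
        # Force float32 and disable autocast for the numerically sensitive part.
        q, k, v = q.float(), k.float(), v.float()
        with torch.amp.autocast(device_type=v.device.type, enabled=False):
            kv = k.transpose(-1, -2) @ v                 # (..., dim, dim + 1)
            out = q @ kv                                 # (..., N, dim + 1)
            out = out[..., :-1] / (out[..., -1:] + eps)  # divide by the accumulated normalizer
        return out.to(dtype)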