diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index 6d123cd4..e5d4a96b 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -250,9 +250,13 @@ class LiteMSA(nn.Module): k = self.kernel_func(k) v = F.pad(v, (0, 1), mode="constant", value=1.) - kv = k.transpose(-1, -2) @ v - out = q @ kv - out = out[..., :-1] / (out[..., -1:] + self.eps) + dtype = v.dtype + q, k, v = q.float(), k.float(), v.float() + with torch.amp.autocast(device_type=v.device.type, enabled=False): + kv = k.transpose(-1, -2) @ v + out = q @ kv + out = out[..., :-1] / (out[..., -1:] + self.eps) + out = out.to(dtype) # final projection out = out.transpose(-1, -2).reshape(B, -1, H, W)