[Fix] Fix AMP in Ascend and support using NPUJITCompile environment (#994)

* add npu device support

* add npu device support
pull/997/head
luomaoling 2023-03-13 19:08:50 +08:00 committed by GitHub
parent 60872c38d4
commit 8177ef2aef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 1 deletion

View File

@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Optional
import torch
@@ -39,7 +40,8 @@ def is_npu_available() -> bool:
# Enable operator support for dynamic shape and
# binary operator support on the NPU.
torch.npu.set_compile_mode(jit_compile=False)
npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
except Exception:
return False
return hasattr(torch, 'npu') and torch.npu.is_available()

View File

@@ -126,6 +126,10 @@ def autocast(device_type: Optional[str] = None,
elif device_type == 'mlu':
pass
elif device_type == 'npu':
pass
else:
# Device like MPS does not support fp16 training or testing.
# If an inappropriate device is set and fp16 is enabled, an error