[Fix] Fix AMP on Ascend and support the NPUJITCompile environment variable (#994)
* add npu device support

parent 60872c38d4
commit 8177ef2aef
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
 from typing import Optional
 
 import torch
@@ -39,7 +40,8 @@ def is_npu_available() -> bool:
 
         # Enable operator support for dynamic shape and
         # binary operator support on the NPU.
-        torch.npu.set_compile_mode(jit_compile=False)
+        npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
+        torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
     except Exception:
         return False
     return hasattr(torch, 'npu') and torch.npu.is_available()
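In plain terms, the new switch works like this: os.getenv returns a string, so bool() of any non-empty value enables JIT compilation (including the string '0'), while an unset or empty variable keeps the default jit_compile=False. Below is a minimal, self-contained sketch of that parsing; only the NPUJITCompile name and the bool(os.getenv(...)) expression come from the diff above, everything else is illustrative.

    # Illustrative sketch of the environment-variable parsing added above.
    import os

    # e.g. launched as: NPUJITCompile=1 python <your_training_script>.py
    os.environ.setdefault('NPUJITCompile', '1')

    npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
    print(npu_jit_compile)  # True here; unset -> False, but note '0' would also give True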
@@ -126,6 +126,10 @@ def autocast(device_type: Optional[str] = None,
 
     elif device_type == 'mlu':
         pass
 
+    elif device_type == 'npu':
+        pass
+
     else:
         # Device like MPS does not support fp16 training or testing.
         # If an inappropriate device is set and fp16 is enabled, an error
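With the added 'npu' branch, autocast() treats Ascend devices like the existing 'mlu' case: it falls through rather than being handled as an unsupported device. A possible usage sketch follows, assuming the usual mmengine import paths and that torch_npu is installed with an Ascend device visible; the tensor shapes and variable names are made up for illustration.

    # Usage sketch for mixed precision on Ascend after this fix.
    import torch
    from mmengine.device import is_npu_available
    from mmengine.runner import autocast

    if is_npu_available():
        x = torch.ones(2, 3).npu()          # move a toy tensor to the Ascend device
        with autocast(device_type='npu'):   # no longer rejected by autocast
            y = x @ x.T                     # runs inside the AMP context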