diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py
index 44e92f71..a6575a42 100644
--- a/mmengine/device/utils.py
+++ b/mmengine/device/utils.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
 from typing import Optional
 
 import torch
@@ -39,7 +40,8 @@ def is_npu_available() -> bool:
 
         # Enable operator support for dynamic shape and
         # binary operator support on the NPU.
-        torch.npu.set_compile_mode(jit_compile=False)
+        npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
+        torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
     except Exception:
         return False
     return hasattr(torch, 'npu') and torch.npu.is_available()
diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py
index 8acf072a..33ab6bd2 100644
--- a/mmengine/runner/amp.py
+++ b/mmengine/runner/amp.py
@@ -126,6 +126,10 @@ def autocast(device_type: Optional[str] = None,
     elif device_type == 'mlu':
         pass
+
+    elif device_type == 'npu':
+        pass
+
     else:
         # Device like MPS does not support fp16 training or testing.
         # If an inappropriate device is set and fp16 is enabled, an error