From fed0e3821a92777b7106f8b6773afc99e4883e9c Mon Sep 17 00:00:00 2001
From: Yinlei Sun
Date: Sat, 6 May 2023 17:17:32 +0800
Subject: [PATCH] [Enhance] Enable full precision training on Ascend NPU
 (#1109)

---
 mmengine/device/__init__.py         | 6 ++++--
 mmengine/device/utils.py            | 8 ++++++++
 mmengine/optim/optimizer/builder.py | 8 ++++----
 3 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/mmengine/device/__init__.py b/mmengine/device/__init__.py
index c6b9d0af..623aa0b8 100644
--- a/mmengine/device/__init__.py
+++ b/mmengine/device/__init__.py
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_mlu_available, is_mps_available, is_npu_available)
+                    is_mlu_available, is_mps_available, is_npu_available,
+                    is_npu_support_full_precision)

 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
-    'is_mlu_available', 'is_mps_available', 'is_npu_available'
+    'is_mlu_available', 'is_mps_available', 'is_npu_available',
+    'is_npu_support_full_precision'
 ]
diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py
index 5858857b..63c90633 100644
--- a/mmengine/device/utils.py
+++ b/mmengine/device/utils.py
@@ -6,6 +6,7 @@ import torch

 try:
     import torch_npu  # noqa: F401
+    import torch_npu.npu.utils as npu_utils

     # Enable operator support for dynamic shape and
     # binary operator support on the NPU.
@@ -62,6 +63,13 @@ def is_mps_available() -> bool:
     return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()


+def is_npu_support_full_precision() -> bool:
+    """Returns True if npu devices support full precision training."""
+    version_of_support_full_precision = 220
+    return IS_NPU_AVAILABLE and npu_utils.get_soc_version(
+    ) >= version_of_support_full_precision
+
+
 DEVICE = 'cpu'
 if is_npu_available():
     DEVICE = 'npu'
diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py
index 44c7a713..65782ff1 100644
--- a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn

 from mmengine.config import Config, ConfigDict
-from mmengine.device import is_npu_available
+from mmengine.device import is_npu_available, is_npu_support_full_precision
 from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS

 from .optimizer_wrapper import OptimWrapper
@@ -128,9 +128,9 @@ def build_optim_wrapper(model: nn.Module,
     paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)

     # Since the current generation of NPU(Ascend 910) only supports
-    # mixed precision training, here we turn on mixed precision by default
-    # on the NPU to make the training normal
-    if is_npu_available():
+    # mixed precision training, here we turn on mixed precision
+    # to make the training normal
+    if is_npu_available() and not is_npu_support_full_precision():
         optim_wrapper_cfg['type'] = 'AmpOptimWrapper'

     optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
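
Note for reviewers: the heart of the change is the version gate in
is_npu_support_full_precision(). torch_npu.npu.utils.get_soc_version()
reports the Ascend SoC as an integer, and the patch treats 220 as the first
generation that can train in FP32; the IS_NPU_AVAILABLE short-circuit keeps
the call safe on hosts without an NPU. A minimal sketch of that comparison,
with the SoC query stubbed out so it runs without Ascend hardware (the sample
version numbers below are assumptions, not values from the patch):

# Sketch: the version gate from is_npu_support_full_precision(), with
# get_soc_version() replaced by a plain argument so it runs anywhere.
FULL_PRECISION_MIN_SOC_VERSION = 220  # threshold used by the patch

def supports_full_precision(soc_version: int,
                            npu_available: bool = True) -> bool:
    # Mirrors: IS_NPU_AVAILABLE and get_soc_version() >= 220
    return npu_available and soc_version >= FULL_PRECISION_MIN_SOC_VERSION

assert not supports_full_precision(100)  # hypothetical older SoC id
assert supports_full_precision(220)      # at or above the threshold
assert not supports_full_precision(220, npu_available=False)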
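
Since is_npu_support_full_precision is now exported from mmengine.device,
downstream code can apply the same policy as the builder. A hedged usage
sketch (pick_precision is a hypothetical helper, not part of mmengine; the
import guard lets the snippet run where mmengine is absent):

# Sketch: choosing a precision policy with the newly exported helper.
try:
    from mmengine.device import (is_npu_available,
                                 is_npu_support_full_precision)
    HAS_MMENGINE = True
except ImportError:  # mmengine not installed; fall back to FP32
    HAS_MMENGINE = False

def pick_precision() -> str:
    """Same rule as build_optim_wrapper: AMP only where FP32 is unsupported."""
    if (HAS_MMENGINE and is_npu_available()
            and not is_npu_support_full_precision()):
        return 'amp'  # e.g. Ascend 910: mixed precision is mandatory
    return 'fp32'

print(pick_precision())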
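
The user-visible effect sits in build_optim_wrapper: the forced switch to
AmpOptimWrapper now only fires on NPUs that cannot train in FP32. A sketch of
what a caller sees after this patch (the model and SGD settings are
placeholders for illustration):

# Sketch: on CPU/GPU, or on an NPU where is_npu_support_full_precision()
# is True, the requested OptimWrapper is kept; on an Ascend 910 the
# builder still rewrites 'type' to 'AmpOptimWrapper'.
import torch.nn as nn
from mmengine.optim import build_optim_wrapper

model = nn.Linear(4, 2)
cfg = dict(
    type='OptimWrapper',  # full precision requested
    optimizer=dict(type='SGD', lr=0.01))

wrapper = build_optim_wrapper(model, cfg)
print(type(wrapper).__name__)  # 'OptimWrapper' except on pre-220 NPUs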