[Enhance] Enable full precision training on Ascend NPU (#1109)

Author: Yinlei Sun, 2023-05-06 17:17:32 +08:00 (committed by GitHub)
Parent: 6cd7a43a7f
Commit: fed0e3821a
3 changed files with 16 additions and 6 deletions

mmengine/device/__init__.py

@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_mlu_available, is_mps_available, is_npu_available)
+                    is_mlu_available, is_mps_available, is_npu_available,
+                    is_npu_support_full_precision)
 
 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
-    'is_mlu_available', 'is_mps_available', 'is_npu_available'
+    'is_mlu_available', 'is_mps_available', 'is_npu_available',
+    'is_npu_support_full_precision'
 ]
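
With this re-export, the new check becomes part of the public mmengine.device namespace rather than an internal detail of the utils module. A minimal sketch of the import path this enables (purely illustrative, not part of the commit):

    from mmengine.device import is_npu_available, is_npu_support_full_precision

    # Safe on machines without torch_npu as well: the SoC-version lookup is
    # short-circuited when no NPU is available, so this simply prints
    # "False False" on CPU/GPU hosts.
    print(is_npu_available(), is_npu_support_full_precision())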

mmengine/device/utils.py

@@ -6,6 +6,7 @@ import torch
 try:
     import torch_npu  # noqa: F401
+    import torch_npu.npu.utils as npu_utils
 
     # Enable operator support for dynamic shape and
     # binary operator support on the NPU.
@@ -62,6 +63,13 @@ def is_mps_available() -> bool:
     return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
 
 
+def is_npu_support_full_precision() -> bool:
+    """Returns True if npu devices support full precision training."""
+    version_of_support_full_precision = 220
+    return IS_NPU_AVAILABLE and npu_utils.get_soc_version(
+    ) >= version_of_support_full_precision
+
+
 DEVICE = 'cpu'
 if is_npu_available():
     DEVICE = 'npu'
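
The helper above treats a SoC version of 220 or higher, as reported by torch_npu.npu.utils.get_soc_version(), as capable of full precision training, and short-circuits to False when no NPU is present. A hedged usage sketch; the printed messages are illustrative and not part of the commit:

    from mmengine.device import is_npu_available, is_npu_support_full_precision

    if not is_npu_available():
        print('No Ascend NPU detected; precision handling is unchanged.')
    elif is_npu_support_full_precision():
        print('This NPU generation can train in full (FP32) precision.')
    else:
        print('This NPU generation still requires mixed precision (AMP) training.')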

mmengine/optim/optimizer/builder.py

@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 
 from mmengine.config import Config, ConfigDict
-from mmengine.device import is_npu_available
+from mmengine.device import is_npu_available, is_npu_support_full_precision
 from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
 
 from .optimizer_wrapper import OptimWrapper
@@ -128,9 +128,9 @@ def build_optim_wrapper(model: nn.Module,
     paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
 
     # Since the current generation of NPU(Ascend 910) only supports
-    # mixed precision training, here we turn on mixed precision by default
-    # on the NPU to make the training normal
-    if is_npu_available():
+    # mixed precision training, here we turn on mixed precision
+    # to make the training normal
+    if is_npu_available() and not is_npu_support_full_precision():
         optim_wrapper_cfg['type'] = 'AmpOptimWrapper'
 
     optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
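
The net effect of the builder change: build_optim_wrapper now only rewrites the requested wrapper to AmpOptimWrapper on NPUs that cannot train in full precision, while newer NPUs, GPUs, and CPUs keep whatever the config asks for. A minimal sketch under that assumption, using a toy model and placeholder optimizer settings:

    import torch.nn as nn
    from mmengine.optim import build_optim_wrapper

    model = nn.Linear(8, 2)  # toy model, just to have parameters to optimize
    cfg = dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.01, momentum=0.9))

    # On a full-precision-capable NPU (or on GPU/CPU) the requested
    # OptimWrapper type is kept; on older Ascend 910 devices the type is
    # still overridden to 'AmpOptimWrapper' before construction.
    optim_wrapper = build_optim_wrapper(model, cfg)
    print(type(optim_wrapper).__name__)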