Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)
[Enhance] Enable full precision training on Ascend NPU (#1109)
parent 6cd7a43a7f
commit fed0e3821a
mmengine/device/__init__.py

@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_mlu_available, is_mps_available, is_npu_available)
+                    is_mlu_available, is_mps_available, is_npu_available,
+                    is_npu_support_full_precision)
 
 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
-    'is_mlu_available', 'is_mps_available', 'is_npu_available'
+    'is_mlu_available', 'is_mps_available', 'is_npu_available',
+    'is_npu_support_full_precision'
 ]
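With the export above, the new helper is available from the public mmengine.device namespace. A minimal sketch of how downstream code might branch on it (the print message is illustrative only), mirroring what this commit does later in build_optim_wrapper:

# Minimal sketch: query the new helper from the public namespace.
from mmengine.device import is_npu_available, is_npu_support_full_precision

if is_npu_available() and not is_npu_support_full_precision():
    # Older Ascend SoCs only support mixed precision training,
    # so an AMP-based optimizer wrapper should be used (see builder.py below).
    print('NPU without full-precision support detected; falling back to AMP.')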
mmengine/device/utils.py

@@ -6,6 +6,7 @@ import torch
 
 try:
     import torch_npu  # noqa: F401
+    import torch_npu.npu.utils as npu_utils
 
     # Enable operator support for dynamic shape and
     # binary operator support on the NPU.
@@ -62,6 +63,13 @@ def is_mps_available() -> bool:
     return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
 
 
+def is_npu_support_full_precision() -> bool:
+    """Returns True if npu devices support full precision training."""
+    version_of_support_full_precision = 220
+    return IS_NPU_AVAILABLE and npu_utils.get_soc_version(
+    ) >= version_of_support_full_precision
+
+
 DEVICE = 'cpu'
 if is_npu_available():
     DEVICE = 'npu'
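For debugging, the underlying check can be reproduced directly with torch_npu. The sketch below assumes a working torch_npu installation; the 220 threshold is the value hard-coded in the hunk above:

# Sketch: inspect the Ascend SoC version that the helper compares against 220.
import torch_npu.npu.utils as npu_utils

soc_version = npu_utils.get_soc_version()
print(f'SoC version: {soc_version}, '
      f'full precision supported: {soc_version >= 220}')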
mmengine/optim/optimizer/builder.py

@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 
 from mmengine.config import Config, ConfigDict
-from mmengine.device import is_npu_available
+from mmengine.device import is_npu_available, is_npu_support_full_precision
 from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
 from .optimizer_wrapper import OptimWrapper
 
@@ -128,9 +128,9 @@ def build_optim_wrapper(model: nn.Module,
     paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
 
     # Since the current generation of NPU(Ascend 910) only supports
-    # mixed precision training, here we turn on mixed precision by default
-    # on the NPU to make the training normal
-    if is_npu_available():
+    # mixed precision training, here we turn on mixed precision
+    # to make the training normal
+    if is_npu_available() and not is_npu_support_full_precision():
         optim_wrapper_cfg['type'] = 'AmpOptimWrapper'
 
     optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
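The net effect: a config that requests a plain OptimWrapper is now only rewritten to AmpOptimWrapper on NPUs that lack full-precision support, instead of on every NPU. A sketch under assumed conditions (model and optimizer settings below are placeholders):

# Sketch: on an NPU whose SoC version is below 220, build_optim_wrapper rewrites
# the wrapper type to AmpOptimWrapper; on newer SoCs (or non-NPU devices) the
# configured full-precision OptimWrapper is kept as-is.
import torch.nn as nn
from mmengine.optim import build_optim_wrapper

model = nn.Linear(8, 2)
optim_wrapper_cfg = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
print(type(optim_wrapper).__name__)  # 'AmpOptimWrapper' only on pre-220 NPUs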