[Enhance] Enable full precision training on Ascend NPU (#1109)
commit fed0e3821a
parent 6cd7a43a7f
--- a/mmengine/device/__init__.py
+++ b/mmengine/device/__init__.py
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_mlu_available, is_mps_available, is_npu_available)
+                    is_mlu_available, is_mps_available, is_npu_available,
+                    is_npu_support_full_precision)
 
 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
-    'is_mlu_available', 'is_mps_available', 'is_npu_available'
+    'is_mlu_available', 'is_mps_available', 'is_npu_available',
+    'is_npu_support_full_precision'
 ]
--- a/mmengine/device/utils.py
+++ b/mmengine/device/utils.py
@@ -6,6 +6,7 @@ import torch
 
 try:
     import torch_npu  # noqa: F401
+    import torch_npu.npu.utils as npu_utils
 
     # Enable operator support for dynamic shape and
     # binary operator support on the NPU.
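The new import is placed inside the existing try block so that installs without torch_npu continue to import cleanly. A minimal sketch of that guarded-import pattern follows; the IS_NPU_AVAILABLE fallback shown here is an assumption about the surrounding file, not a quote of it. The next hunk, in the same mmengine/device/utils.py, adds the helper that consumes npu_utils.

    import torch

    # Sketch of the guarded-import pattern (assumed surrounding logic):
    # torch_npu is optional, so probing it must not break CPU/GPU-only setups.
    try:
        import torch_npu  # noqa: F401
        import torch_npu.npu.utils as npu_utils  # consumed by the helper below
        IS_NPU_AVAILABLE = hasattr(torch, 'npu') and torch.npu.is_available()
    except Exception:
        IS_NPU_AVAILABLE = False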
@@ -62,6 +63,13 @@ def is_mps_available() -> bool:
     return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
 
 
+def is_npu_support_full_precision() -> bool:
+    """Returns True if npu devices support full precision training."""
+    version_of_support_full_precision = 220
+    return IS_NPU_AVAILABLE and npu_utils.get_soc_version(
+    ) >= version_of_support_full_precision
+
+
 DEVICE = 'cpu'
 if is_npu_available():
     DEVICE = 'npu'
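For illustration, a hedged sketch of how the new helper might be used by downstream code. Only the 220 threshold comes from the commit; which concrete Ascend chips report which SoC numbers is not stated here and is left as an assumption.

    from mmengine.device import is_npu_available, is_npu_support_full_precision

    # npu_utils.get_soc_version() reports the Ascend SoC as an integer, and
    # the commit treats values >= 220 as full-precision capable.
    if is_npu_available() and not is_npu_support_full_precision():
        print('Older NPU SoC: mixed precision (AMP) is required.')
    else:
        print('Full precision training is supported on this device.')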
--- a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 
 from mmengine.config import Config, ConfigDict
-from mmengine.device import is_npu_available
+from mmengine.device import is_npu_available, is_npu_support_full_precision
 from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
 from .optimizer_wrapper import OptimWrapper
 
@@ -128,9 +128,9 @@ def build_optim_wrapper(model: nn.Module,
     paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
 
     # Since the current generation of NPU(Ascend 910) only supports
-    # mixed precision training, here we turn on mixed precision by default
-    # on the NPU to make the training normal
-    if is_npu_available():
+    # mixed precision training, here we turn on mixed precision
+    # to make the training normal
+    if is_npu_available() and not is_npu_support_full_precision():
         optim_wrapper_cfg['type'] = 'AmpOptimWrapper'
 
     optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
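The net effect: an optimizer wrapper config is now coerced to AmpOptimWrapper only on NPUs that cannot train in full precision. A short usage sketch follows; the toy model and SGD settings are illustrative and not part of the commit.

    import torch.nn as nn
    from mmengine.optim import build_optim_wrapper

    model = nn.Linear(8, 4)  # toy model, illustrative only
    cfg = dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.01, momentum=0.9))

    # On an NPU whose SoC supports full precision (or on CPU/GPU), the
    # requested OptimWrapper is kept; on older NPUs the builder overrides
    # the type to 'AmpOptimWrapper'.
    wrapper = build_optim_wrapper(model, cfg)
    print(type(wrapper).__name__)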