Mirror of https://github.com/PaddlePaddle/PaddleClas.git, synced 2025-06-03 21:55:06 +08:00
Merge pull request #1755 from qipengh/hqp_support_mlu
feat: support mlu device and amp of mlu
Commit b51e58bfe6
@@ -92,7 +92,7 @@ class Engine(object):
             self.vdl_writer = LogWriter(logdir=vdl_writer_path)
 
         # set device
-        assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu", "npu"]
+        assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu", "npu", "mlu"]
         self.device = paddle.set_device(self.config["Global"]["device"])
         logger.info('train with paddle {} and device {}'.format(
             paddle.__version__, self.device))
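The hunk above only widens the set of accepted device strings so that "mlu" passes the check. A minimal sketch of the resulting device-selection step, assuming a PaddleClas-style config dict; the helper name select_device and the print call are illustrative, not part of the Engine class:

import paddle

# Hedged sketch of the widened device check; `config` is assumed to
# follow the PaddleClas Global config layout.
SUPPORTED_DEVICES = ["cpu", "gpu", "xpu", "npu", "mlu"]

def select_device(config):
    device_name = config["Global"]["device"]
    assert device_name in SUPPORTED_DEVICES, (
        "unsupported device: {}".format(device_name))
    # paddle.set_device switches the default place used for newly
    # created tensors and layers; 'mlu' targets a Cambricon MLU card.
    device = paddle.set_device(device_name)
    print("train with paddle {} and device {}".format(
        paddle.__version__, device))
    return device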
@@ -108,9 +108,12 @@ class Engine(object):
         self.use_dynamic_loss_scaling = False
         if self.amp:
             AMP_RELATED_FLAGS_SETTING = {
-                'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
                 'FLAGS_max_inplace_grad_add': 8,
             }
+            if paddle.is_compiled_with_cuda():
+                AMP_RELATED_FLAGS_SETTING.update({
+                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
+                })
             paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
 
         if "class_num" in config["Global"]:
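The second hunk moves the cuDNN-specific flag behind a CUDA check: FLAGS_cudnn_batchnorm_spatial_persistent may not be registered in non-CUDA builds, so on MLU (and other non-CUDA devices) only the generic flag is applied. An illustrative sketch of that guard; amp_flags is an assumed name, and paddle.set_flags is the newer spelling of the paddle.fluid.set_flags call shown in the diff:

import paddle

amp_flags = {
    'FLAGS_max_inplace_grad_add': 8,
}
if paddle.is_compiled_with_cuda():
    # The cuDNN batchnorm flag is only meaningful on CUDA builds,
    # so it is added conditionally.
    amp_flags['FLAGS_cudnn_batchnorm_spatial_persistent'] = 1
paddle.set_flags(amp_flags)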
@@ -91,9 +91,10 @@ def main(args):
 
     use_xpu = global_config.get("use_xpu", False)
     use_npu = global_config.get("use_npu", False)
+    use_mlu = global_config.get("use_mlu", False)
     assert (
-        use_gpu and use_xpu and use_npu
-    ) is not True, "gpu, xpu and npu can not be true in the same time in static mode!"
+        use_gpu and use_xpu and use_npu and use_mlu
+    ) is not True, "gpu, xpu, npu and mlu can not be true in the same time in static mode!"
 
     if use_gpu:
        device = paddle.set_device('gpu')
@@ -101,6 +102,8 @@ def main(args):
         device = paddle.set_device('xpu')
     elif use_npu:
         device = paddle.set_device('npu')
+    elif use_mlu:
+        device = paddle.set_device('mlu')
     else:
         device = paddle.set_device('cpu')
 
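Taken together, the static-mode changes read the new use_mlu switch from the Global config and add an mlu branch to the device dispatch. A minimal sketch under those assumptions; pick_static_device is a hypothetical helper, not a function in the repository, and the assertion mirrors the original check from the diff:

import paddle

def pick_static_device(global_config):
    use_gpu = global_config.get("use_gpu", True)
    use_xpu = global_config.get("use_xpu", False)
    use_npu = global_config.get("use_npu", False)
    use_mlu = global_config.get("use_mlu", False)
    # Same mutual-exclusion check as in the diff.
    assert (use_gpu and use_xpu and use_npu and use_mlu) is not True, \
        "gpu, xpu, npu and mlu can not be true in the same time in static mode!"
    if use_gpu:
        return paddle.set_device('gpu')
    elif use_xpu:
        return paddle.set_device('xpu')
    elif use_npu:
        return paddle.set_device('npu')
    elif use_mlu:
        return paddle.set_device('mlu')
    return paddle.set_device('cpu')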