[MLU]adapt mlu device for running dbnet network
parent
077196f3cb
commit
7851977157
|
@ -1,6 +1,7 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
use_xpu: false
|
||||
use_mlu: false
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
|
|
|
@ -114,7 +114,7 @@ def merge_config(config, opts):
|
|||
return config
|
||||
|
||||
|
||||
def check_device(use_gpu, use_xpu=False, use_npu=False):
|
||||
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
|
||||
"""
|
||||
Log error and exit when set use_gpu=true in paddlepaddle
|
||||
cpu version.
|
||||
|
@ -137,6 +137,9 @@ def check_device(use_gpu, use_xpu=False, use_npu=False):
|
|||
if use_npu and not paddle.device.is_compiled_with_npu():
|
||||
print(err.format("use_npu", "npu", "npu", "use_npu"))
|
||||
sys.exit(1)
|
||||
if use_mlu and not paddle.device.is_compiled_with_mlu():
|
||||
print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
@ -618,6 +621,7 @@ def preprocess(is_train=False):
|
|||
use_gpu = config['Global'].get('use_gpu', False)
|
||||
use_xpu = config['Global'].get('use_xpu', False)
|
||||
use_npu = config['Global'].get('use_npu', False)
|
||||
use_mlu = config['Global'].get('use_mlu', False)
|
||||
|
||||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
|
@ -632,10 +636,12 @@ def preprocess(is_train=False):
|
|||
device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
|
||||
elif use_npu:
|
||||
device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
|
||||
elif use_mlu:
|
||||
device = 'mlu:{0}'.format(os.getenv('FLAGS_selected_mlus', 0))
|
||||
else:
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv()
|
||||
.dev_id) if use_gpu else 'cpu'
|
||||
check_device(use_gpu, use_xpu, use_npu)
|
||||
check_device(use_gpu, use_xpu, use_npu, use_mlu)
|
||||
|
||||
device = paddle.set_device(device)
|
||||
|
||||
|
|
|
@ -149,10 +149,11 @@ def main(config, device, logger, vdl_writer):
|
|||
amp_level = config["Global"].get("amp_level", 'O2')
|
||||
amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
|
||||
if use_amp:
|
||||
AMP_RELATED_FLAGS_SETTING = {
|
||||
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
|
||||
'FLAGS_max_inplace_grad_add': 8,
|
||||
}
|
||||
AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
|
||||
if paddle.is_compiled_with_cuda():
|
||||
AMP_RELATED_FLAGS_SETTING.update({
|
||||
'FLAGS_cudnn_batchnorm_spatial_persistent': 1
|
||||
})
|
||||
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
|
||||
scale_loss = config["Global"].get("scale_loss", 1.0)
|
||||
use_dynamic_loss_scaling = config["Global"].get(
|
||||
|
|
Loading…
Reference in New Issue