rm with_data_parallel

pull/2833/head
duanyanhui 2023-05-04 17:22:12 +08:00 committed by cuicheng01
parent 6a55aac38c
commit 4a1bdf5857
1 changed files with 4 additions and 15 deletions

View File

@ -153,12 +153,6 @@ def create_strategy(config):
exec_strategy: exec strategy
"""
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = (
10000
if 'AMP' in config and config.AMP.get("level", "O1") == "O2" else 10)
fuse_op = True if 'AMP' in config else False
@ -172,7 +166,7 @@ def create_strategy(config):
build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
build_strategy.enable_addto = enable_addto
return build_strategy, exec_strategy
return build_strategy
def dist_optimizer(config, optimizer):
@ -186,10 +180,9 @@ def dist_optimizer(config, optimizer):
Returns:
optimizer: a distributed optimizer
"""
build_strategy, exec_strategy = create_strategy(config)
build_strategy = create_strategy(config)
dist_strategy = DistributedStrategy()
dist_strategy.execution_strategy = exec_strategy
dist_strategy.build_strategy = build_strategy
dist_strategy.nccl_comm_num = 1
@ -298,14 +291,10 @@ def compile(config, program, loss_name=None, share_prog=None):
Returns:
compiled_program(): a compiled program
"""
build_strategy, exec_strategy = create_strategy(config)
build_strategy = create_strategy(config)
compiled_program = paddle.static.CompiledProgram(
program).with_data_parallel(
share_vars_from=share_prog,
loss_name=loss_name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
program, build_strategy=build_strategy)
return compiled_program