rm with_data_parallel
parent
6a55aac38c
commit
4a1bdf5857
|
@ -153,12 +153,6 @@ def create_strategy(config):
|
|||
exec_strategy: exec strategy
|
||||
"""
|
||||
build_strategy = paddle.static.BuildStrategy()
|
||||
exec_strategy = paddle.static.ExecutionStrategy()
|
||||
|
||||
exec_strategy.num_threads = 1
|
||||
exec_strategy.num_iteration_per_drop_scope = (
|
||||
10000
|
||||
if 'AMP' in config and config.AMP.get("level", "O1") == "O2" else 10)
|
||||
|
||||
fuse_op = True if 'AMP' in config else False
|
||||
|
||||
|
@ -172,7 +166,7 @@ def create_strategy(config):
|
|||
build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
|
||||
build_strategy.enable_addto = enable_addto
|
||||
|
||||
return build_strategy, exec_strategy
|
||||
return build_strategy
|
||||
|
||||
|
||||
def dist_optimizer(config, optimizer):
|
||||
|
@ -186,10 +180,9 @@ def dist_optimizer(config, optimizer):
|
|||
Returns:
|
||||
optimizer: a distributed optimizer
|
||||
"""
|
||||
build_strategy, exec_strategy = create_strategy(config)
|
||||
build_strategy = create_strategy(config)
|
||||
|
||||
dist_strategy = DistributedStrategy()
|
||||
dist_strategy.execution_strategy = exec_strategy
|
||||
dist_strategy.build_strategy = build_strategy
|
||||
|
||||
dist_strategy.nccl_comm_num = 1
|
||||
|
@ -298,14 +291,10 @@ def compile(config, program, loss_name=None, share_prog=None):
|
|||
Returns:
|
||||
compiled_program(): a compiled program
|
||||
"""
|
||||
build_strategy, exec_strategy = create_strategy(config)
|
||||
build_strategy = create_strategy(config)
|
||||
|
||||
compiled_program = paddle.static.CompiledProgram(
|
||||
program).with_data_parallel(
|
||||
share_vars_from=share_prog,
|
||||
loss_name=loss_name,
|
||||
build_strategy=build_strategy,
|
||||
exec_strategy=exec_strategy)
|
||||
program, build_strategy=build_strategy)
|
||||
|
||||
return compiled_program
|
||||
|
||||
|
|
Loading…
Reference in New Issue