optimizer fleet distributed strategy (#500)
parent
7b15ef1aea
commit
62fd192784
tools/static
|
@ -288,6 +288,60 @@ def create_optimizer(config):
|
|||
opt = OptimizerBuilder(config, **opt_config)
|
||||
return opt(lr), lr
|
||||
|
||||
def create_strategy(config):
|
||||
"""
|
||||
Create build strategy and exec strategy.
|
||||
|
||||
Args:
|
||||
config(dict): config
|
||||
|
||||
Returns:
|
||||
build_strategy: build strategy
|
||||
exec_strategy: exec strategy
|
||||
"""
|
||||
build_strategy = paddle.static.BuildStrategy()
|
||||
exec_strategy = paddle.static.ExecutionStrategy()
|
||||
|
||||
exec_strategy.num_threads = 1
|
||||
exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
|
||||
'use_pure_fp16', False) else 10
|
||||
|
||||
fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16',
|
||||
False)
|
||||
fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
|
||||
fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
|
||||
fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
|
||||
enable_addto = config.get('enable_addto', fuse_op)
|
||||
|
||||
try:
|
||||
build_strategy.fuse_bn_act_ops = fuse_bn_act_ops
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"PaddlePaddle version 1.7.0 or higher is "
|
||||
"required when you want to fuse batch_norm and activation_op.")
|
||||
|
||||
try:
|
||||
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"PaddlePaddle version 1.7.0 or higher is "
|
||||
"required when you want to fuse elewise_add_act and activation_op.")
|
||||
|
||||
try:
|
||||
build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"PaddlePaddle 2.0-rc or higher is "
|
||||
"required when you want to enable fuse_bn_add_act_ops strategy.")
|
||||
|
||||
try:
|
||||
build_strategy.enable_addto = enable_addto
|
||||
except Exception as e:
|
||||
logger.info("PaddlePaddle 2.0-rc or higher is "
|
||||
"required when you want to enable addto strategy.")
|
||||
return build_strategy, exec_strategy
|
||||
|
||||
|
||||
|
||||
def dist_optimizer(config, optimizer):
|
||||
"""
|
||||
|
@ -300,14 +354,15 @@ def dist_optimizer(config, optimizer):
|
|||
Returns:
|
||||
optimizer: a distributed optimizer
|
||||
"""
|
||||
exec_strategy = paddle.static.ExecutionStrategy()
|
||||
exec_strategy.num_threads = 3
|
||||
exec_strategy.num_iteration_per_drop_scope = 10
|
||||
build_strategy, exec_strategy = create_strategy(config)
|
||||
|
||||
dist_strategy = DistributedStrategy()
|
||||
dist_strategy.execution_strategy = exec_strategy
|
||||
dist_strategy.build_strategy = build_strategy
|
||||
|
||||
dist_strategy.nccl_comm_num = 1
|
||||
dist_strategy.fuse_all_reduce_ops = True
|
||||
dist_strategy.execution_strategy = exec_strategy
|
||||
dist_strategy.fuse_grad_size_in_MB = 16
|
||||
optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
|
||||
|
||||
return optimizer
|
||||
|
@ -399,46 +454,7 @@ def compile(config, program, loss_name=None, share_prog=None):
|
|||
Returns:
|
||||
compiled_program(): a compiled program
|
||||
"""
|
||||
build_strategy = paddle.static.BuildStrategy()
|
||||
exec_strategy = paddle.static.ExecutionStrategy()
|
||||
|
||||
exec_strategy.num_threads = 1
|
||||
exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
|
||||
'use_pure_fp16', False) else 10
|
||||
|
||||
fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16',
|
||||
False)
|
||||
fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
|
||||
fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
|
||||
fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
|
||||
enable_addto = config.get('enable_addto', fuse_op)
|
||||
|
||||
try:
|
||||
build_strategy.fuse_bn_act_ops = fuse_bn_act_ops
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"PaddlePaddle version 1.7.0 or higher is "
|
||||
"required when you want to fuse batch_norm and activation_op.")
|
||||
|
||||
try:
|
||||
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"PaddlePaddle version 1.7.0 or higher is "
|
||||
"required when you want to fuse elewise_add_act and activation_op.")
|
||||
|
||||
try:
|
||||
build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"PaddlePaddle 2.0-rc or higher is "
|
||||
"required when you want to enable fuse_bn_add_act_ops strategy.")
|
||||
|
||||
try:
|
||||
build_strategy.enable_addto = enable_addto
|
||||
except Exception as e:
|
||||
logger.info("PaddlePaddle 2.0-rc or higher is "
|
||||
"required when you want to enable addto strategy.")
|
||||
build_strategy, exec_strategy = create_strategy(config)
|
||||
|
||||
compiled_program = paddle.static.CompiledProgram(
|
||||
program).with_data_parallel(
|
||||
|
|
Loading…
Reference in New Issue