Fix static training speed (#1590)
* fix training speed * update config setting methodpull/1594/head
parent
dadf50d047
commit
0f35f706b6
ppcls
configs/ImageNet/ResNet
static
|
@ -16,6 +16,7 @@ Global:
|
|||
save_inference_dir: ./inference
|
||||
# training model under @to_static
|
||||
to_static: False
|
||||
use_dali: True
|
||||
|
||||
# mixed precision training
|
||||
AMP:
|
||||
|
|
|
@ -81,14 +81,13 @@ def main(args):
|
|||
# amp related config
|
||||
if 'AMP' in config:
|
||||
AMP_RELATED_FLAGS_SETTING = {
|
||||
'FLAGS_cudnn_exhaustive_search': "1",
|
||||
'FLAGS_conv_workspace_size_limit': "1500",
|
||||
'FLAGS_cudnn_batchnorm_spatial_persistent': "1",
|
||||
'FLAGS_max_indevice_grad_add': "8",
|
||||
"FLAGS_cudnn_batchnorm_spatial_persistent": "1",
|
||||
'FLAGS_cudnn_exhaustive_search': 1,
|
||||
'FLAGS_conv_workspace_size_limit': 1500,
|
||||
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
|
||||
'FLAGS_max_inplace_grad_add': 8,
|
||||
}
|
||||
for k in AMP_RELATED_FLAGS_SETTING:
|
||||
os.environ[k] = AMP_RELATED_FLAGS_SETTING[k]
|
||||
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'
|
||||
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
|
||||
|
||||
use_xpu = global_config.get("use_xpu", False)
|
||||
use_npu = global_config.get("use_npu", False)
|
||||
|
|
Loading…
Reference in New Issue