mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
Fix static training speed (#1590)
* fix training speed * update config setting method
This commit is contained in:
parent
7eec82b87e
commit
bce0d041a7
@ -16,6 +16,7 @@ Global:
|
|||||||
save_inference_dir: ./inference
|
save_inference_dir: ./inference
|
||||||
# training model under @to_static
|
# training model under @to_static
|
||||||
to_static: False
|
to_static: False
|
||||||
|
use_dali: True
|
||||||
|
|
||||||
# mixed precision training
|
# mixed precision training
|
||||||
AMP:
|
AMP:
|
||||||
|
@ -81,14 +81,13 @@ def main(args):
|
|||||||
# amp related config
|
# amp related config
|
||||||
if 'AMP' in config:
|
if 'AMP' in config:
|
||||||
AMP_RELATED_FLAGS_SETTING = {
|
AMP_RELATED_FLAGS_SETTING = {
|
||||||
'FLAGS_cudnn_exhaustive_search': "1",
|
'FLAGS_cudnn_exhaustive_search': 1,
|
||||||
'FLAGS_conv_workspace_size_limit': "1500",
|
'FLAGS_conv_workspace_size_limit': 1500,
|
||||||
'FLAGS_cudnn_batchnorm_spatial_persistent': "1",
|
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
|
||||||
'FLAGS_max_indevice_grad_add': "8",
|
'FLAGS_max_inplace_grad_add': 8,
|
||||||
"FLAGS_cudnn_batchnorm_spatial_persistent": "1",
|
|
||||||
}
|
}
|
||||||
for k in AMP_RELATED_FLAGS_SETTING:
|
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'
|
||||||
os.environ[k] = AMP_RELATED_FLAGS_SETTING[k]
|
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
|
||||||
|
|
||||||
use_xpu = global_config.get("use_xpu", False)
|
use_xpu = global_config.get("use_xpu", False)
|
||||||
use_npu = global_config.get("use_npu", False)
|
use_npu = global_config.get("use_npu", False)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user