diff --git a/tools/train.py b/tools/train.py index ff261e85fe..89c1c7a039 100755 --- a/tools/train.py +++ b/tools/train.py @@ -152,9 +152,10 @@ def main(config, device, logger, vdl_writer): AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, } if paddle.is_compiled_with_cuda(): AMP_RELATED_FLAGS_SETTING.update({ - 'FLAGS_cudnn_batchnorm_spatial_persistent': 1 + 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, + 'FLAGS_gemm_use_half_precision_compute_type': 0, }) - paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) + paddle.set_flags(AMP_RELATED_FLAGS_SETTING) scale_loss = config["Global"].get("scale_loss", 1.0) use_dynamic_loss_scaling = config["Global"].get( "use_dynamic_loss_scaling", False)