From 4f735dbde167684cce5dcd14c035e9f67b4d187d Mon Sep 17 00:00:00 2001
From: Zhang Ting
Date: Thu, 29 Dec 2022 13:42:12 +0800
Subject: [PATCH] matmul use fp32 compute_type (#8733)

---
 tools/train.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tools/train.py b/tools/train.py
index ff261e85fe..89c1c7a039 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -152,9 +152,10 @@ def main(config, device, logger, vdl_writer):
         AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
         if paddle.is_compiled_with_cuda():
             AMP_RELATED_FLAGS_SETTING.update({
-                'FLAGS_cudnn_batchnorm_spatial_persistent': 1
+                'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
+                'FLAGS_gemm_use_half_precision_compute_type': 0,
             })
-        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
+        paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
         scale_loss = config["Global"].get("scale_loss", 1.0)
         use_dynamic_loss_scaling = config["Global"].get(
             "use_dynamic_loss_scaling", False)
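
Note: the hunk above adds FLAGS_gemm_use_half_precision_compute_type=0 so that
fp16 matmul/GEMM kernels accumulate in fp32 under AMP (per the subject line),
and switches the call from the deprecated paddle.fluid.set_flags alias to
paddle.set_flags. A minimal standalone sketch of the resulting flag setup is
below; the guard and flag values mirror the hunk, the rest is illustrative and
not part of the patch:

    import paddle

    # Flags applied before AMP training, as in tools/train.py.
    AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
    if paddle.is_compiled_with_cuda():
        AMP_RELATED_FLAGS_SETTING.update({
            # Keep cuDNN persistent batch-norm kernels enabled.
            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
            # Use fp32 compute type for half-precision GEMM/matmul.
            'FLAGS_gemm_use_half_precision_compute_type': 0,
        })
    paddle.set_flags(AMP_RELATED_FLAGS_SETTING)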