mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
parent
cf03afb10a
commit
b18b656633
@ -66,7 +66,6 @@ def amp_scaler(config):
|
|||||||
if "AMP" in config and config["AMP"]["use_amp"] is True:
|
if "AMP" in config and config["AMP"]["use_amp"] is True:
|
||||||
AMP_RELATED_FLAGS_SETTING = {
|
AMP_RELATED_FLAGS_SETTING = {
|
||||||
"FLAGS_cudnn_batchnorm_spatial_persistent": 1,
|
"FLAGS_cudnn_batchnorm_spatial_persistent": 1,
|
||||||
"FLAGS_max_inplace_grad_add": 8,
|
|
||||||
}
|
}
|
||||||
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
|
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
|
||||||
scale_loss = config["AMP"].get("scale_loss", 1.0)
|
scale_loss = config["AMP"].get("scale_loss", 1.0)
|
||||||
|
@ -131,7 +131,6 @@ def main():
|
|||||||
if use_amp:
|
if use_amp:
|
||||||
AMP_RELATED_FLAGS_SETTING = {
|
AMP_RELATED_FLAGS_SETTING = {
|
||||||
"FLAGS_cudnn_batchnorm_spatial_persistent": 1,
|
"FLAGS_cudnn_batchnorm_spatial_persistent": 1,
|
||||||
"FLAGS_max_inplace_grad_add": 8,
|
|
||||||
}
|
}
|
||||||
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
|
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
|
||||||
scale_loss = config["Global"].get("scale_loss", 1.0)
|
scale_loss = config["Global"].get("scale_loss", 1.0)
|
||||||
|
@ -181,9 +181,7 @@ def main(config, device, logger, vdl_writer, seed):
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
if use_amp:
|
if use_amp:
|
||||||
AMP_RELATED_FLAGS_SETTING = {
|
AMP_RELATED_FLAGS_SETTING = {}
|
||||||
"FLAGS_max_inplace_grad_add": 8,
|
|
||||||
}
|
|
||||||
if paddle.is_compiled_with_cuda():
|
if paddle.is_compiled_with_cuda():
|
||||||
AMP_RELATED_FLAGS_SETTING.update(
|
AMP_RELATED_FLAGS_SETTING.update(
|
||||||
{
|
{
|
||||||
|
Loading…
x
Reference in New Issue
Block a user