refactor amp

parent b2cb417842
commit f884f28853
@@ -242,61 +242,8 @@ class Engine(object):
                 self.config["Optimizer"], self.config["Global"]["epochs"],
                 self.iter_per_epoch // self.update_freq,
                 [self.model, self.train_loss_func])
 
-        # AMP training and evaluating
-        self.amp = "AMP" in self.config and self.config["AMP"] is not None
-        self.amp_eval = False
-        # for amp
-        if self.amp:
-            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
-            if paddle.is_compiled_with_cuda():
-                AMP_RELATED_FLAGS_SETTING.update({
-                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
-                })
-            paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
-
-            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
-            self.use_dynamic_loss_scaling = self.config["AMP"].get(
-                "use_dynamic_loss_scaling", False)
-            self.scaler = paddle.amp.GradScaler(
-                init_loss_scaling=self.scale_loss,
-                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
-
-            self.amp_level = self.config['AMP'].get("level", "O1")
-            if self.amp_level not in ["O1", "O2"]:
-                msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
-                logger.warning(msg)
-                self.config['AMP']["level"] = "O1"
-                self.amp_level = "O1"
-
-            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
-            # TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMP O2
-            if self.mode == "train" and self.config["Global"].get(
-                    "eval_during_train",
-                    True) and self.amp_level == "O2" and self.amp_eval == False:
-                msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
-                logger.warning(msg)
-                self.config["AMP"]["use_fp16_test"] = True
-                self.amp_eval = True
-
-            if self.mode == "train":
-                self.model, self.optimizer = paddle.amp.decorate(
-                    models=self.model,
-                    optimizers=self.optimizer,
-                    level=self.amp_level,
-                    save_dtype='float32')
-            elif self.amp_eval:
-                self.model = paddle.amp.decorate(
-                    models=self.model,
-                    level=self.amp_level,
-                    save_dtype='float32')
-
-            if self.mode == "train" and len(self.train_loss_func.parameters(
-            )) > 0:
-                self.train_loss_func = paddle.amp.decorate(
-                    models=self.train_loss_func,
-                    level=self.amp_level,
-                    save_dtype='float32')
+        # amp
+        self._init_amp()
 
         # build EMA model
         self.ema = "EMA" in self.config and self.mode == "train"
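Note (not part of the diff): the block removed above is not dropped; it moves into the new _init_amp() helper added in the last hunk, so the Engine still exposes the same attributes (self.amp, self.scaler, self.amp_level, ...). For reference, a minimal sketch of the AMP config section that _init_amp() reads; the key names come from the code in this diff, while the concrete values are illustrative assumptions:

# Hypothetical config fragment; keys mirror what _init_amp() reads from self.config["AMP"].
config = {
    "AMP": {
        "scale_loss": 128.0,               # init_loss_scaling for paddle.amp.GradScaler
        "use_dynamic_loss_scaling": True,  # enable dynamic loss scaling in the scaler
        "level": "O1",                     # AMP level; only "O1" and "O2" are accepted
        "use_fp16_test": False,            # also run evaluation under FP16
        "use_promote": False,              # forwarded to paddle.amp.auto_cast during inference
    }
}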
@@ -513,7 +460,9 @@ class Engine(object):
             batch_tensor = paddle.to_tensor(batch_data)
 
             if self.amp and self.amp_eval:
-                with paddle.amp.auto_cast(level=self.amp_level):
+                with paddle.amp.auto_cast(
+                        level=self.amp_level,
+                        use_promote=self.use_promote):
                     out = self.model(batch_tensor)
             else:
                 out = self.model(batch_tensor)
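The only behavioral change in this hunk is that use_promote is now forwarded to paddle.amp.auto_cast. A standalone sketch of the same pattern, assuming a Paddle release whose auto_cast accepts the use_promote argument (older releases do not have it):

# Illustrative sketch, not PaddleClas code: forward pass under auto_cast.
import paddle

model = paddle.nn.Linear(8, 2)
x = paddle.randn([4, 8])

with paddle.amp.auto_cast(level="O1", use_promote=True):
    out = model(x)  # white-list ops such as matmul run in FP16 on GPU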
@@ -578,6 +527,62 @@ class Engine(object):
             f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
         )
 
+    def _init_amp(self):
+        self.amp = "AMP" in self.config and self.config["AMP"] is not None
+        self.amp_eval = False
+
+        if self.amp:
+            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
+            if paddle.is_compiled_with_cuda():
+                AMP_RELATED_FLAGS_SETTING.update({
+                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
+                })
+            paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
+
+            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
+            self.use_dynamic_loss_scaling = self.config["AMP"].get(
+                "use_dynamic_loss_scaling", False)
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=self.scale_loss,
+                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
+
+            self.use_promote = self.config['AMP'].get("use_promote", False)
+            self.amp_level = self.config['AMP'].get("level", "O1")
+            if self.amp_level not in ["O1", "O2"]:
+                msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
+                logger.warning(msg)
+                self.config['AMP']["level"] = "O1"
+                self.amp_level = "O1"
+
+            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
+            # TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMP O2
+            if self.mode == "train" and self.config["Global"].get(
+                    "eval_during_train",
+                    True) and self.amp_level == "O2" and self.amp_eval == False:
+                msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
+                logger.warning(msg)
+                self.config["AMP"]["use_fp16_test"] = True
+                self.amp_eval = True
+
+            if self.mode == "train":
+                self.model, self.optimizer = paddle.amp.decorate(
+                    models=self.model,
+                    optimizers=self.optimizer,
+                    level=self.amp_level,
+                    save_dtype='float32')
+            elif self.amp_eval:
+                self.model = paddle.amp.decorate(
+                    models=self.model,
+                    level=self.amp_level,
+                    save_dtype='float32')
+
+            if self.mode == "train" and len(self.train_loss_func.parameters(
+            )) > 0:
+                self.train_loss_func = paddle.amp.decorate(
+                    models=self.train_loss_func,
+                    level=self.amp_level,
+                    save_dtype='float32')
+
 
 class ExportModel(TheseusLayer):
     """
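The GradScaler created in _init_amp() is consumed by the training loop, which this diff does not touch. A generic Paddle AMP training step using the same pieces looks roughly like the sketch below; it is illustrative only and not the Engine's actual loop:

# Generic Paddle AMP step with a GradScaler like the one _init_amp() builds.
import paddle

model = paddle.nn.Linear(8, 2)
optimizer = paddle.optimizer.Momentum(parameters=model.parameters())
scaler = paddle.amp.GradScaler(
    init_loss_scaling=1.0, use_dynamic_loss_scaling=False)

x = paddle.randn([4, 8])
label = paddle.randint(0, 2, [4, 1])

with paddle.amp.auto_cast(level="O1"):
    loss = paddle.nn.functional.cross_entropy(model(x), label)

scaled = scaler.scale(loss)          # scale the loss to avoid FP16 gradient underflow
scaled.backward()
scaler.minimize(optimizer, scaled)   # unscale gradients, step the optimizer, update the scale
optimizer.clear_grad()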