optimizer must be decorated when training with AMP O2
parent dc2c8528ad
commit a7ba6eabd2
@@ -274,33 +274,17 @@ class Engine(object):
                 self.config["AMP"]["use_fp16_test"] = True
                 self.amp_eval = True
 
-            # TODO(gaotingquan): to compatible with different versions of Paddle
-            paddle_version = paddle.__version__[:3]
-            # paddle version < 2.3.0 and not develop
-            if paddle_version not in ["2.3", "0.0"]:
-                if self.mode == "train":
-                    self.model, self.optimizer = paddle.amp.decorate(
-                        models=self.model,
-                        optimizers=self.optimizer,
-                        level=self.amp_level,
-                        save_dtype='float32')
-                elif self.amp_eval:
-                    if self.amp_level == "O2":
-                        msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
-                        logger.warning(msg)
-                        self.amp_eval = False
-                    else:
-                        self.model, self.optimizer = paddle.amp.decorate(
-                            models=self.model,
-                            level=self.amp_level,
-                            save_dtype='float32')
-            # paddle version >= 2.3.0 or develop
-            else:
-                if self.mode == "train" or self.amp_eval:
-                    self.model = paddle.amp.decorate(
-                        models=self.model,
-                        level=self.amp_level,
-                        save_dtype='float32')
+            if self.mode == "train":
+                self.model, self.optimizer = paddle.amp.decorate(
+                    models=self.model,
+                    optimizers=self.optimizer,
+                    level=self.amp_level,
+                    save_dtype='float32')
+            elif self.amp_eval:
+                self.model = paddle.amp.decorate(
+                    models=self.model,
+                    level=self.amp_level,
+                    save_dtype='float32')
 
             if self.mode == "train" and len(self.train_loss_func.parameters(
             )) > 0:
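For readers unfamiliar with the API, below is a minimal sketch of the pattern the new branch enforces. It is not code from this repo: the Linear network, the SGD optimizer, the input data, and the loss-scaling value are illustrative stand-ins, and it assumes Paddle >= 2.3. Under AMP O2, paddle.amp.decorate casts the model's parameters to FP16, so the optimizer must be passed to the same decorate call so that it is rebound to the casted parameters with FP32 master weights for the update; in eval mode there is nothing to update, so decorating the model alone suffices.

    # Minimal sketch, not code from this repo: model, optimizer, data, and
    # scaling values are illustrative. Assumes Paddle >= 2.3.
    import paddle

    model = paddle.nn.Linear(8, 2)
    optimizer = paddle.optimizer.SGD(
        learning_rate=0.01, parameters=model.parameters())

    # Training under O2: decorate model and optimizer together, as the new
    # branch does. decorate() casts the parameters to FP16 and returns an
    # optimizer bound to them, keeping FP32 master weights for the update.
    model, optimizer = paddle.amp.decorate(
        models=model,
        optimizers=optimizer,
        level='O2',
        save_dtype='float32')

    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
    x = paddle.randn([4, 8])

    with paddle.amp.auto_cast(level='O2'):
        loss = model(x).mean()
    scaler.scale(loss).backward()  # scale loss to avoid FP16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, then runs the update
    scaler.update()
    optimizer.clear_grad()

    # Eval-only under O2: no optimizer to update, so decorate the model
    # alone, matching the new "elif self.amp_eval" branch:
    #     model = paddle.amp.decorate(
    #         models=model, level='O2', save_dtype='float32')

Read in that light, the old else branch (Paddle >= 2.3.0 or develop) decorated only the model even in train mode, leaving the optimizer attached to the pre-cast FP32 parameters; that appears to be the bug the commit title describes.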