fix: enable amp only in training

pull/1646/head
gaotingquan 2022-01-11 14:07:09 +00:00 committed by Tingquan Gao
parent 7040ce8314
commit 10c93c55d1
2 changed files with 7 additions and 3 deletions

View File

@ -20,6 +20,7 @@ Arch:
name: SE_ResNeXt101_32x4d
class_num: 1000
input_image_channel: *image_channel
data_format: "NHWC"
# loss function config for traing/eval process
Loss:

View File

@ -97,7 +97,7 @@ class Engine(object):
paddle.__version__, self.device))
# AMP training
self.amp = True if "AMP" in self.config else False
self.amp = True if "AMP" in self.config and self.mode == "train" else False
if self.amp and self.config["AMP"] is not None:
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
self.use_dynamic_loss_scaling = self.config["AMP"].get(
@ -223,8 +223,11 @@ class Engine(object):
logger.warning(msg)
self.config['AMP']["level"] = "O1"
amp_level = "O1"
self.model = paddle.amp.decorate(
models=self.model, level=amp_level, save_dtype='float32')
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
optimizers=self.optimizer,
level=amp_level,
save_dtype='float32')
# for distributed
self.config["Global"][