fix: enable amp only in training
parent
7040ce8314
commit
10c93c55d1
|
@ -20,6 +20,7 @@ Arch:
|
|||
name: SE_ResNeXt101_32x4d
|
||||
class_num: 1000
|
||||
input_image_channel: *image_channel
|
||||
data_format: "NHWC"
|
||||
|
||||
# loss function config for traing/eval process
|
||||
Loss:
|
||||
|
|
|
@ -97,7 +97,7 @@ class Engine(object):
|
|||
paddle.__version__, self.device))
|
||||
|
||||
# AMP training
|
||||
self.amp = True if "AMP" in self.config else False
|
||||
self.amp = True if "AMP" in self.config and self.mode == "train" else False
|
||||
if self.amp and self.config["AMP"] is not None:
|
||||
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
|
||||
self.use_dynamic_loss_scaling = self.config["AMP"].get(
|
||||
|
@ -223,8 +223,11 @@ class Engine(object):
|
|||
logger.warning(msg)
|
||||
self.config['AMP']["level"] = "O1"
|
||||
amp_level = "O1"
|
||||
self.model = paddle.amp.decorate(
|
||||
models=self.model, level=amp_level, save_dtype='float32')
|
||||
self.model, self.optimizer = paddle.amp.decorate(
|
||||
models=self.model,
|
||||
optimizers=self.optimizer,
|
||||
level=amp_level,
|
||||
save_dtype='float32')
|
||||
|
||||
# for distributed
|
||||
self.config["Global"][
|
||||
|
|
Loading…
Reference in New Issue