fix: enable amp only in training
parent
7040ce8314
commit
10c93c55d1
|
@ -20,6 +20,7 @@ Arch:
|
||||||
name: SE_ResNeXt101_32x4d
|
name: SE_ResNeXt101_32x4d
|
||||||
class_num: 1000
|
class_num: 1000
|
||||||
input_image_channel: *image_channel
|
input_image_channel: *image_channel
|
||||||
|
data_format: "NHWC"
|
||||||
|
|
||||||
# loss function config for traing/eval process
|
# loss function config for traing/eval process
|
||||||
Loss:
|
Loss:
|
||||||
|
|
|
@ -97,7 +97,7 @@ class Engine(object):
|
||||||
paddle.__version__, self.device))
|
paddle.__version__, self.device))
|
||||||
|
|
||||||
# AMP training
|
# AMP training
|
||||||
self.amp = True if "AMP" in self.config else False
|
self.amp = True if "AMP" in self.config and self.mode == "train" else False
|
||||||
if self.amp and self.config["AMP"] is not None:
|
if self.amp and self.config["AMP"] is not None:
|
||||||
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
|
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
|
||||||
self.use_dynamic_loss_scaling = self.config["AMP"].get(
|
self.use_dynamic_loss_scaling = self.config["AMP"].get(
|
||||||
|
@ -223,8 +223,11 @@ class Engine(object):
|
||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
self.config['AMP']["level"] = "O1"
|
self.config['AMP']["level"] = "O1"
|
||||||
amp_level = "O1"
|
amp_level = "O1"
|
||||||
self.model = paddle.amp.decorate(
|
self.model, self.optimizer = paddle.amp.decorate(
|
||||||
models=self.model, level=amp_level, save_dtype='float32')
|
models=self.model,
|
||||||
|
optimizers=self.optimizer,
|
||||||
|
level=amp_level,
|
||||||
|
save_dtype='float32')
|
||||||
|
|
||||||
# for distributed
|
# for distributed
|
||||||
self.config["Global"][
|
self.config["Global"][
|
||||||
|
|
Loading…
Reference in New Issue