fix: amp eval
parent fea9522a69
commit 59a3dcfc1c
@@ -99,8 +99,8 @@ class Engine(object):
                 paddle.__version__, self.device))

         # AMP training and evaluating
-        self.amp = "AMP" in self.config
-        if self.amp and self.config["AMP"] is not None:
+        self.amp = "AMP" in self.config and self.config["AMP"] is not None
+        if self.amp:
             self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
             self.use_dynamic_loss_scaling = self.config["AMP"].get(
                 "use_dynamic_loss_scaling", False)
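The change above folds the None check into the flag itself: previously an AMP key with an empty body left self.amp set to True while scale_loss and the other AMP attributes were never filled in. A minimal sketch of the difference, using a plain dict in place of the real config object:

    # "AMP" key present but empty, e.g. a commented-out YAML body
    config = {"AMP": None}

    amp_old = "AMP" in config                                # True, misleading
    amp_new = "AMP" in config and config["AMP"] is not None  # False, correct

    assert amp_old and not amp_new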
@@ -228,7 +228,7 @@ class Engine(object):
                 len(self.train_dataloader),
                 [self.model, self.train_loss_func])

-        # for amp training
+        # for amp
         if self.amp:
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
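For context, a runnable sketch of how a GradScaler constructed this way is typically driven in a Paddle AMP training step. The toy model, data, and hyperparameters below are assumptions for illustration, not code from this repo:

    import paddle

    model = paddle.nn.Linear(4, 2)
    optimizer = paddle.optimizer.SGD(
        learning_rate=0.1, parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=128.0)

    x = paddle.randn([8, 4])
    y = paddle.randint(0, 2, [8, 1])

    with paddle.amp.auto_cast():
        loss = paddle.nn.functional.cross_entropy(model(x), y)

    scaled = scaler.scale(loss)         # scale the loss so fp16 grads don't underflow
    scaled.backward()
    scaler.minimize(optimizer, scaled)  # unscale grads, apply the step, update the scale
    optimizer.clear_grad()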
@@ -239,12 +239,13 @@ class Engine(object):
                 logger.warning(msg)
                 self.config['AMP']["level"] = "O1"
                 amp_level = "O1"
-            self.model, self.optimizer = paddle.amp.decorate(
-                models=self.model,
-                optimizers=self.optimizer,
-                level=amp_level,
-                save_dtype='float32')
-            if len(self.train_loss_func.parameters()) > 0:
+            self.model = paddle.amp.decorate(
+                models=self.model, level=amp_level, save_dtype='float32')
+            # TODO(gaotingquan): to compatible with Paddle develop and 2.2
+            if isinstance(self.model, tuple):
+                self.model = self.model[0]
+            if self.mode == "train" and len(self.train_loss_func.parameters(
+            )) > 0:
                 self.train_loss_func = paddle.amp.decorate(
                     models=self.train_loss_func,
                     level=amp_level,
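The isinstance shim above exists because, as the TODO notes, paddle.amp.decorate returns a (models, optimizers) tuple on some Paddle versions even when no optimizer is passed. The same shim in isolation, with an assumed toy model:

    import paddle

    model = paddle.nn.Linear(4, 2)
    decorated = paddle.amp.decorate(
        models=model, level="O2", save_dtype="float32")
    # Compatibility across Paddle develop and 2.2: unwrap a tuple return.
    if isinstance(decorated, tuple):
        decorated = decorated[0]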
@@ -32,6 +32,15 @@ def classification_eval(engine, epoch_id=0):
     }
     print_batch_step = engine.config["Global"]["print_batch_step"]

+    if engine.amp:
+        amp_level = engine.config['AMP'].get("level", "O1").upper()
+        if amp_level == "O2" and not engine.config["AMP"].get("use_fp16_test",
+                                                              False):
+            engine.config["AMP"]["use_fp16_test"] = True
+            msg = "Only support FP16 evaluation when AMP O2 is enabled."
+            logger.warning(msg)
+        amp_eval = engine.config["AMP"].get("use_fp16_test", False)
+
     metric_key = None
     tic = time.time()
     accum_samples = 0
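The added block normalises the AMP config once, before the eval loop: under O2, evaluation only supports FP16, so use_fp16_test is forced on with a warning and the result is cached in amp_eval. The same logic as a standalone function (resolve_amp_eval is a hypothetical name, and print stands in for the project logger):

    def resolve_amp_eval(amp_config):
        amp_level = amp_config.get("level", "O1").upper()
        if amp_level == "O2" and not amp_config.get("use_fp16_test", False):
            amp_config["use_fp16_test"] = True
            print("Only support FP16 evaluation when AMP O2 is enabled.")
        return amp_config.get("use_fp16_test", False)

    assert resolve_amp_eval({"level": "O2"}) is True
    assert resolve_amp_eval({"level": "O1"}) is False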
@@ -58,15 +67,7 @@ def classification_eval(engine, epoch_id=0):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")

         # image input
-        if engine.amp and (
-                engine.config['AMP'].get("level", "O1").upper() == "O2" or
-                engine.config["AMP"].get("use_fp16_test", False)):
-            amp_level = engine.config['AMP'].get("level", "O1").upper()
-
-            if amp_level == "O2":
-                msg = "Only support FP16 evaluation when AMP O2 is enabled."
-                logger.warning(msg)
-
+        if engine.amp and amp_eval:
             with paddle.amp.auto_cast(
                     custom_black_list={
                         "flatten_contiguous_range", "greater_than"
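With amp_eval resolved once up front, the per-batch branch shrinks to a single flag test. A self-contained sketch of the guarded forward pass, with a toy model and hard-coded flags standing in for the engine state:

    import paddle

    model = paddle.nn.Linear(4, 2)
    amp_eval, amp_level = True, "O1"
    images = paddle.randn([8, 4])

    if amp_eval:
        with paddle.amp.auto_cast(
                custom_black_list={
                    "flatten_contiguous_range", "greater_than"
                },
                level=amp_level):
            out = model(images)
    else:
        out = model(images)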
@@ -119,8 +120,7 @@ def classification_eval(engine, epoch_id=0):

         # calc loss
         if engine.eval_loss_func is not None:
-            if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
-                amp_level = engine.config['AMP'].get("level", "O1").upper()
+            if engine.amp and amp_eval:
                 with paddle.amp.auto_cast(
                         custom_black_list={
                             "flatten_contiguous_range", "greater_than"