mirror of https://github.com/PaddlePaddle/PaddleClas.git
fix: compatible with Paddle 2.2, 2.3, and develop.
commit 275945dff9 (parent 59a3dcfc1c)
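In short: the AMP setup that used to run near the top of Engine.__init__, right after the device/logger setup, is moved to after the model, optimizer, and loss function are built; it now records the chosen optimization level in self.amp_level and whether evaluation runs in FP16 in self.amp_eval, and it picks between the two paddle.amp.decorate call shapes depending on whether the installed Paddle is older than 2.3.0 or is 2.3.0+/develop. classification_eval drops its own AMP bookkeeping and reads engine.amp_level and engine.amp_eval instead.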
@@ -98,23 +98,6 @@ class Engine(object):
         logger.info('train with paddle {} and device {}'.format(
             paddle.__version__, self.device))
 
-        # AMP training and evaluating
-        self.amp = "AMP" in self.config and self.config["AMP"] is not None
-        if self.amp:
-            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
-            self.use_dynamic_loss_scaling = self.config["AMP"].get(
-                "use_dynamic_loss_scaling", False)
-        else:
-            self.scale_loss = 1.0
-            self.use_dynamic_loss_scaling = False
-        if self.amp:
-            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
-            if paddle.is_compiled_with_cuda():
-                AMP_RELATED_FLAGS_SETTING.update({
-                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
-                })
-            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
-
         if "class_num" in config["Global"]:
             global_class_num = config["Global"]["class_num"]
             if "class_num" not in config["Arch"]:
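For reference, the AMP initialization removed here is not dropped; it reappears in the next hunk, after the optimizer and loss have been built, so that paddle.amp.decorate can wrap them. The block only ever reads a handful of keys from the AMP section of the config. A minimal sketch of that section as a Python dict, with illustrative values (the key names come from this diff; the values are placeholders, not recommended settings):

# Hypothetical config fragment: only the keys this diff touches.
config = {
    "AMP": {
        "scale_loss": 1.0,                  # initial loss scale handed to paddle.amp.GradScaler
        "use_dynamic_loss_scaling": False,  # let GradScaler adapt the scale during training
        "level": "O1",                      # "O1" or "O2"; anything else falls back to "O1"
        "use_fp16_test": False,             # evaluate in FP16; forced True for O2 + eval_during_train
    },
}

# The enable check used by both the old and the new code:
amp_enabled = "AMP" in config and config["AMP"] is not None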
@@ -228,27 +211,77 @@ class Engine(object):
                 len(self.train_dataloader),
                 [self.model, self.train_loss_func])
 
+        # AMP training and evaluating
+        self.amp = "AMP" in self.config and self.config["AMP"] is not None
+        self.amp_eval = False
         # for amp
         if self.amp:
+            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
+            if paddle.is_compiled_with_cuda():
+                AMP_RELATED_FLAGS_SETTING.update({
+                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
+                })
+            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
+
+            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
+            self.use_dynamic_loss_scaling = self.config["AMP"].get(
+                "use_dynamic_loss_scaling", False)
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
-            amp_level = self.config['AMP'].get("level", "O1")
-            if amp_level not in ["O1", "O2"]:
+
+            self.amp_level = self.config['AMP'].get("level", "O1")
+            if self.amp_level not in ["O1", "O2"]:
                 msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
                 logger.warning(msg)
                 self.config['AMP']["level"] = "O1"
-                amp_level = "O1"
+                self.amp_level = "O1"
+
+            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
+            # TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
+            if self.config["Global"].get(
+                    "eval_during_train",
+                    True) and self.amp_level == "O2" and self.amp_eval == False:
+                msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
+                logger.warning(msg)
+                self.config["AMP"]["use_fp16_test"] = True
+                self.amp_eval = True
+
+            # TODO(gaotingquan): to compatible with Paddle 2.2, 2.3, develop and so on.
+            paddle_version = sum([
+                int(x) * 10**(2 - i)
+                for i, x in enumerate(paddle.__version__.split(".")[:3])
+            ])
+            # paddle version < 2.3.0 and not develop
+            if paddle_version < 230 and paddle_version != 0:
+                if self.mode == "train":
+                    self.model, self.optimizer = paddle.amp.decorate(
+                        models=self.model,
+                        optimizers=self.optimizer,
+                        level=self.amp_level,
+                        save_dtype='float32')
+                elif self.amp_eval:
+                    if self.amp_level == "O2":
+                        msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
+                        logger.warning(msg)
+                        self.amp_eval = False
+                    else:
+                        self.model, self.optimizer = paddle.amp.decorate(
+                            models=self.model,
+                            level=self.amp_level,
+                            save_dtype='float32')
-            self.model = paddle.amp.decorate(
-                models=self.model, level=amp_level, save_dtype='float32')
-            # TODO(gaotingquan): to compatible with Paddle develop and 2.2
-            if isinstance(self.model, tuple):
-                self.model = self.model[0]
+            # paddle version >= 2.3.0 or develop
+            else:
+                self.model = paddle.amp.decorate(
+                    models=self.model,
+                    level=self.amp_level,
+                    save_dtype='float32')
+
             if self.mode == "train" and len(self.train_loss_func.parameters(
             )) > 0:
                 self.train_loss_func = paddle.amp.decorate(
                     models=self.train_loss_func,
-                    level=amp_level,
+                    level=self.amp_level,
                     save_dtype='float32')
 
         # for distributed
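The version gate in this hunk packs the first three components of paddle.__version__ into one integer, so 2.2.x encodes below 230 while 2.3.0 and later reach it; develop builds, which report 0.0.0 and therefore encode to 0, fall through to the new code path via the `paddle_version != 0` test. A small self-contained sketch of the encoding with illustrative inputs:

# Sketch of the encoding used above: "a.b.c" -> a*100 + b*10 + c.
def encode(version: str) -> int:
    return sum(
        int(x) * 10**(2 - i)
        for i, x in enumerate(version.split(".")[:3]))

print(encode("2.2.2"))  # 222 -> < 230 and != 0: decorate(models=..., optimizers=...) path
print(encode("2.3.0"))  # 230 -> models-only decorate() path
print(encode("0.0.0"))  # 0   -> develop build, also the models-only path

Note that the old isinstance(self.model, tuple) workaround disappears; presumably the decorate() call shape chosen per version now returns the model in a consistent form, so the unpacking is no longer needed.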
@@ -32,15 +32,6 @@ def classification_eval(engine, epoch_id=0):
     }
     print_batch_step = engine.config["Global"]["print_batch_step"]
 
-    if engine.amp:
-        amp_level = engine.config['AMP'].get("level", "O1").upper()
-        if amp_level == "O2" and engine.config["AMP"].get("use_fp16_test",
-                                                          False):
-            engine.config["AMP"]["use_fp16_test"] = True
-            msg = "Only support FP16 evaluation when AMP O2 is enabled."
-            logger.warning(msg)
-        amp_eval = engine.config["AMP"].get("use_fp16_test", False)
-
     metric_key = None
     tic = time.time()
     accum_samples = 0
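With this block removed, classification_eval no longer re-derives the AMP level and the FP16-test flag from the raw config on every evaluation; it relies on the engine.amp_level and engine.amp_eval attributes that Engine.__init__ now sets up in the hunk above.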
@@ -67,12 +58,12 @@ def classification_eval(engine, epoch_id=0):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
 
             # image input
-            if engine.amp and amp_eval:
+            if engine.amp and engine.amp_eval:
                 with paddle.amp.auto_cast(
                         custom_black_list={
                             "flatten_contiguous_range", "greater_than"
                         },
-                        level=amp_level):
+                        level=engine.amp_level):
                     out = engine.model(batch[0])
             else:
                 out = engine.model(batch[0])
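For context, the construct being switched over here is Paddle's dynamic-graph autocast region: inside the with block the forward pass runs in mixed precision, while ops named in custom_black_list stay in FP32. A minimal standalone sketch of the same call shape; the Linear layer and random input are placeholders, not PaddleClas code:

import paddle

model = paddle.nn.Linear(8, 2)             # stand-in for engine.model
x = paddle.randn([4, 8], dtype="float32")  # stand-in for batch[0]

with paddle.amp.auto_cast(
        custom_black_list={
            "flatten_contiguous_range", "greater_than"
        },
        level="O1"):                       # engine.amp_level in the real code
    out = model(x)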
@@ -120,12 +111,12 @@ def classification_eval(engine, epoch_id=0):
 
             # calc loss
             if engine.eval_loss_func is not None:
-                if engine.amp and amp_eval:
+                if engine.amp and engine.amp_eval:
                     with paddle.amp.auto_cast(
                             custom_black_list={
                                 "flatten_contiguous_range", "greater_than"
                             },
-                            level=amp_level):
+                            level=engine.amp_level):
                         loss_dict = engine.eval_loss_func(preds, labels)
                 else:
                     loss_dict = engine.eval_loss_func(preds, labels)