Revert "rm codes for compatibility with old version"

This reverts commit 6e77bd6cd5.
pull/2701/head
Tingquan Gao 2023-03-14 16:16:40 +08:00
parent 56e8c5a992
commit 7243f1429b
4 changed files with 42 additions and 11 deletions


@@ -221,6 +221,15 @@ class DataIterator(object):
def build_dataloader(config, mode):
    if "class_num" in config["Global"]:
        global_class_num = config["Global"]["class_num"]
        if "class_num" not in config["Arch"]:
            config["Arch"]["class_num"] = global_class_num
            msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
        else:
            msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
        logger.warning(msg)
    class_num = config["Arch"].get("class_num", None)
    config["DataLoader"].update({"class_num": class_num})
    config["DataLoader"].update({"epochs": config["Global"]["epochs"]})


@@ -412,18 +412,33 @@ class Engine(object):
                self.config["AMP"]["use_fp16_test"] = True
                self.amp_eval = True
            # TODO(gaotingquan): to be compatible with different versions of Paddle
            paddle_version = paddle.__version__[:3]
            # paddle version < 2.3.0 and not develop
            if paddle_version not in ["2.3", "2.4", "0.0"]:
                msg = "When using AMP, PaddleClas release/2.6 and later versions only support PaddlePaddle version >= 2.3.0."
                logger.error(msg)
                raise Exception(msg)
            if self.mode == "train" or self.amp_eval:
                self.model = paddle.amp.decorate(
                    models=self.model,
                    level=self.amp_level,
                    save_dtype='float32')
            if paddle_version not in ["2.3", "0.0"]:
                if self.mode == "train":
                    self.model, self.optimizer = paddle.amp.decorate(
                        models=self.model,
                        optimizers=self.optimizer,
                        level=self.amp_level,
                        save_dtype='float32')
                elif self.amp_eval:
                    if self.amp_level == "O2":
                        msg = "The installed PaddlePaddle does not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. FP32 evaluation will be used instead; please make sure the Eval dataset's output_fp16 is 'False'."
                        logger.warning(msg)
                        self.amp_eval = False
                    else:
                        self.model, self.optimizer = paddle.amp.decorate(
                            models=self.model,
                            level=self.amp_level,
                            save_dtype='float32')
            # paddle version >= 2.3.0 or develop
            else:
                if self.mode == "train" or self.amp_eval:
                    self.model = paddle.amp.decorate(
                        models=self.model,
                        level=self.amp_level,
                        save_dtype='float32')
        if self.mode == "train" and len(self.train_loss_func.parameters(
        )) > 0:
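For reference, the gate restored above keys off the first three characters of paddle.__version__, with "0.0" denoting a develop (nightly) build; a minimal sketch of that check in isolation (pick_amp_decorate_path is a hypothetical helper, not PaddleClas code):

# Sketch of the version gate: a build that is neither 2.3.x nor develop takes
# the legacy path that decorates model and optimizer together; 2.3+/develop
# decorates the model only.
def pick_amp_decorate_path(full_version: str) -> str:
    short = full_version[:3]  # e.g. "2.2.2" -> "2.2"; develop builds report "0.0.0"
    if short not in ["2.3", "0.0"]:
        return "decorate model + optimizer"
    return "decorate model only"

assert pick_amp_decorate_path("2.2.2") == "decorate model + optimizer"
assert pick_amp_decorate_path("2.3.1") == "decorate model only"
assert pick_amp_decorate_path("0.0.0") == "decorate model only"
# 2.4 is not special-cased by the restored check, which is why the removed
# code listed "2.4" explicitly.
assert pick_amp_decorate_path("2.4.0") == "decorate model + optimizer"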


@@ -4,7 +4,7 @@ import paddle
import paddle.nn as nn
from ppcls.utils import logger
from .celoss import CELoss
from .celoss import CELoss, MixCELoss
from .googlenetloss import GoogLeNetLoss
from .centerloss import CenterLoss
from .contrasiveloss import ContrastiveLoss
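The re-export matters because PaddleClas configs refer to a loss by its class name, resolved inside this package; a rough, self-contained sketch of that kind of name-based lookup (build_loss_by_name and the registry dict are hypothetical stand-ins, not the actual PaddleClas builder):

# Hypothetical name-based loss lookup, mimicking resolution inside ppcls.loss.
class CELoss:
    def __init__(self, epsilon=None):
        self.epsilon = epsilon

registry = {"CELoss": CELoss}  # plays the role of the package namespace

def build_loss_by_name(name, **kwargs):
    if name not in registry:
        raise AttributeError(f"loss {name!r} is not exported by the loss package")
    return registry[name](**kwargs)

loss = build_loss_by_name("CELoss", epsilon=0.1)
# Without the restored import, a config still naming "MixCELoss" would fail here
# with an AttributeError instead of hitting the explicit deprecation shim in celoss.py.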


@@ -66,3 +66,10 @@ class CELoss(nn.Layer):
            soft_label=soft_label,
            reduction=self.reduction)
        return {"CELoss": loss}


class MixCELoss(object):
    def __init__(self, *args, **kwargs):
        msg = "\"MixCELoss\" is deprecated, please use \"CELoss\" instead."
        logger.error(DeprecationWarning(msg))
        raise DeprecationWarning(msg)
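A short sketch of what an old config hits once the shim above is back (a standalone copy for illustration; the real class lives in the loss module shown in this diff):

# Standalone copy of the deprecation shim, just to show the observable behavior.
class MixCELoss(object):
    def __init__(self, *args, **kwargs):
        msg = "\"MixCELoss\" is deprecated, please use \"CELoss\" instead."
        raise DeprecationWarning(msg)

try:
    MixCELoss(epsilon=0.1)
except DeprecationWarning as err:
    print(err)  # "MixCELoss" is deprecated, please use "CELoss" instead.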