rm codes for compatibility with old version

pull/2698/head
gaotingquan 2023-02-27 12:28:22 +00:00 committed by Wei Shengyu
parent f525cea006
commit 6e77bd6cd5
4 changed files with 11 additions and 42 deletions

View File

@ -221,15 +221,6 @@ class DataIterator(object):
def build_dataloader(config, mode):
if "class_num" in config["Global"]:
global_class_num = config["Global"]["class_num"]
if "class_num" not in config["Arch"]:
config["Arch"]["class_num"] = global_class_num
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
else:
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger.warning(msg)
class_num = config["Arch"].get("class_num", None)
config["DataLoader"].update({"class_num": class_num})
config["DataLoader"].update({"epochs": config["Global"]["epochs"]})

View File

@ -412,33 +412,18 @@ class Engine(object):
self.config["AMP"]["use_fp16_test"] = True
self.amp_eval = True
# TODO(gaotingquan): to compatible with different versions of Paddle
paddle_version = paddle.__version__[:3]
# paddle version < 2.3.0 and not develop
if paddle_version not in ["2.3", "0.0"]:
if self.mode == "train":
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
optimizers=self.optimizer,
level=self.amp_level,
save_dtype='float32')
elif self.amp_eval:
if self.amp_level == "O2":
msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
logger.warning(msg)
self.amp_eval = False
else:
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
# paddle version >= 2.3.0 or develop
else:
if self.mode == "train" or self.amp_eval:
self.model = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
if paddle_version not in ["2.3", "2.4", "0.0"]:
msg = "When using AMP, PaddleClas release/2.6 and later version only support PaddlePaddle version >= 2.3.0."
logger.error(msg)
raise Exception(msg)
if self.mode == "train" or self.amp_eval:
self.model = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
if self.mode == "train" and len(self.train_loss_func.parameters(
)) > 0:

View File

@ -4,7 +4,7 @@ import paddle
import paddle.nn as nn
from ppcls.utils import logger
from .celoss import CELoss, MixCELoss
from .celoss import CELoss
from .googlenetloss import GoogLeNetLoss
from .centerloss import CenterLoss
from .contrasiveloss import ContrastiveLoss

View File

@ -66,10 +66,3 @@ class CELoss(nn.Layer):
soft_label=soft_label,
reduction=self.reduction)
return {"CELoss": loss}
class MixCELoss(object):
def __init__(self, *args, **kwargs):
msg = "\"MixCELos\" is deprecated, please use \"CELoss\" instead."
logger.error(DeprecationWarning(msg))
raise DeprecationWarning(msg)