rm codes for compatibility with old version
parent
f525cea006
commit
6e77bd6cd5
|
@ -221,15 +221,6 @@ class DataIterator(object):
|
|||
|
||||
|
||||
def build_dataloader(config, mode):
|
||||
if "class_num" in config["Global"]:
|
||||
global_class_num = config["Global"]["class_num"]
|
||||
if "class_num" not in config["Arch"]:
|
||||
config["Arch"]["class_num"] = global_class_num
|
||||
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
|
||||
else:
|
||||
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
|
||||
logger.warning(msg)
|
||||
|
||||
class_num = config["Arch"].get("class_num", None)
|
||||
config["DataLoader"].update({"class_num": class_num})
|
||||
config["DataLoader"].update({"epochs": config["Global"]["epochs"]})
|
||||
|
|
|
@ -412,33 +412,18 @@ class Engine(object):
|
|||
self.config["AMP"]["use_fp16_test"] = True
|
||||
self.amp_eval = True
|
||||
|
||||
# TODO(gaotingquan): to compatible with different versions of Paddle
|
||||
paddle_version = paddle.__version__[:3]
|
||||
# paddle version < 2.3.0 and not develop
|
||||
if paddle_version not in ["2.3", "0.0"]:
|
||||
if self.mode == "train":
|
||||
self.model, self.optimizer = paddle.amp.decorate(
|
||||
models=self.model,
|
||||
optimizers=self.optimizer,
|
||||
level=self.amp_level,
|
||||
save_dtype='float32')
|
||||
elif self.amp_eval:
|
||||
if self.amp_level == "O2":
|
||||
msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
|
||||
logger.warning(msg)
|
||||
self.amp_eval = False
|
||||
else:
|
||||
self.model, self.optimizer = paddle.amp.decorate(
|
||||
models=self.model,
|
||||
level=self.amp_level,
|
||||
save_dtype='float32')
|
||||
# paddle version >= 2.3.0 or develop
|
||||
else:
|
||||
if self.mode == "train" or self.amp_eval:
|
||||
self.model = paddle.amp.decorate(
|
||||
models=self.model,
|
||||
level=self.amp_level,
|
||||
save_dtype='float32')
|
||||
if paddle_version not in ["2.3", "2.4", "0.0"]:
|
||||
msg = "When using AMP, PaddleClas release/2.6 and later version only support PaddlePaddle version >= 2.3.0."
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
|
||||
if self.mode == "train" or self.amp_eval:
|
||||
self.model = paddle.amp.decorate(
|
||||
models=self.model,
|
||||
level=self.amp_level,
|
||||
save_dtype='float32')
|
||||
|
||||
if self.mode == "train" and len(self.train_loss_func.parameters(
|
||||
)) > 0:
|
||||
|
|
|
@ -4,7 +4,7 @@ import paddle
|
|||
import paddle.nn as nn
|
||||
from ppcls.utils import logger
|
||||
|
||||
from .celoss import CELoss, MixCELoss
|
||||
from .celoss import CELoss
|
||||
from .googlenetloss import GoogLeNetLoss
|
||||
from .centerloss import CenterLoss
|
||||
from .contrasiveloss import ContrastiveLoss
|
||||
|
|
|
@ -66,10 +66,3 @@ class CELoss(nn.Layer):
|
|||
soft_label=soft_label,
|
||||
reduction=self.reduction)
|
||||
return {"CELoss": loss}
|
||||
|
||||
|
||||
class MixCELoss(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
msg = "\"MixCELos\" is deprecated, please use \"CELoss\" instead."
|
||||
logger.error(DeprecationWarning(msg))
|
||||
raise DeprecationWarning(msg)
|
||||
|
|
Loading…
Reference in New Issue