diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index fc01de94c..7f04221c8 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -113,6 +113,14 @@ class Engine(object): } paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) + if "class_num" in config["Global"]: + global_class_num = config["Global"]["class_num"] + if "class_num" not in config["Arch"]: + config["Arch"]["class_num"] = global_class_num + msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}." + else: + msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored." + logger.warning(msg) #TODO(gaotingquan): support rec class_num = config["Arch"].get("class_num", None) self.config["DataLoader"].update({"class_num": class_num})