fix export bug

pull/3262/head
zhangyubo0722 2024-09-25 17:22:56 +00:00 committed by Tingquan Gao
parent 7f35c77027
commit 542b320008
2 changed files with 6 additions and 3 deletions

View File

@ -548,7 +548,7 @@ class Engine(object):
"use_multilabel",
False) or "ATTRMetric" in self.config["Metric"]["Eval"][0]
model = self.model_ema.module if self.ema else self.model
if isinstance(self.model, paddle.DataParallel):
if hasattr(model, '_layers'):
model = copy.deepcopy(model._layers)
else:
model = copy.deepcopy(model)

View File

@ -63,8 +63,11 @@ def update_train_results(config,
train_results = {}
train_results["model_name"] = config["Global"].get("pdx_model_name",
None)
train_results["label_dict"] = config["Infer"]["PostProcess"][
"class_id_map_file"]
if config.get("infer", None):
train_results["label_dict"] = config["Infer"]["PostProcess"].get(
"class_id_map_file", "")
else:
train_results["label_dict"] = ""
train_results["train_log"] = "train.log"
train_results["visualdl_log"] = ""
train_results["config"] = "config.yaml"