diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 62d59d936..cd5a22027 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -548,12 +548,15 @@ class Engine(object): "use_multilabel", False) or "ATTRMetric" in self.config["Metric"]["Eval"][0] model = ExportModel(self.config["Arch"], self.model, use_multilabel) - if self.config["Global"]["pretrained_model"].startswith("http"): - load_dygraph_pretrain_from_url( - model.base_model, self.config["Global"]["pretrained_model"]) - else: - load_dygraph_pretrain(model.base_model, - self.config["Global"]["pretrained_model"]) + if self.config["Global"]["pretrained_model"] is not None: + if self.config["Global"]["pretrained_model"].startswith("http"): + load_dygraph_pretrain_from_url( + model.base_model, + self.config["Global"]["pretrained_model"]) + else: + load_dygraph_pretrain( + model.base_model, + self.config["Global"]["pretrained_model"]) model.eval()