diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 9ff807f7b..043b3eceb 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -116,7 +116,8 @@ class Engine(object): if self.mode == 'train': self.train_dataloader = build_dataloader( self.config["DataLoader"], "Train", self.device, self.use_dali) - if self.mode in ["train", "eval"]: + if self.mode == "eval" or (self.mode == "train" and + self.config["Global"]["eval_during_train"]): if self.eval_mode == "classification": self.eval_dataloader = build_dataloader( self.config["DataLoader"], "Eval", self.device, @@ -140,7 +141,8 @@ class Engine(object): if self.mode == "train": loss_info = self.config["Loss"]["Train"] self.train_loss_func = build_loss(loss_info) - if self.mode in ["train", "eval"]: + if self.mode == "eval" or (self.mode == "train" and + self.config["Global"]["eval_during_train"]): loss_config = self.config.get("Loss", None) if loss_config is not None: loss_config = loss_config.get("Eval") @@ -163,7 +165,8 @@ class Engine(object): else: self.train_metric_func = None - if self.mode in ["train", "eval"]: + if self.mode == "eval" or (self.mode == "train" and + self.config["Global"]["eval_during_train"]): metric_config = self.config.get("Metric") if self.eval_mode == "classification": if metric_config is not None: