From 583c88b5dca7c3eaa8d283968b97f2cccc4e2a2c Mon Sep 17 00:00:00 2001 From: dongshuilong Date: Thu, 21 Oct 2021 02:39:27 +0000 Subject: [PATCH] fix train without eval bug --- ppcls/engine/engine.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 9ff807f7b..043b3eceb 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -116,7 +116,8 @@ class Engine(object): if self.mode == 'train': self.train_dataloader = build_dataloader( self.config["DataLoader"], "Train", self.device, self.use_dali) - if self.mode in ["train", "eval"]: + if self.mode == "eval" or (self.mode == "train" and + self.config["Global"]["eval_during_train"]): if self.eval_mode == "classification": self.eval_dataloader = build_dataloader( self.config["DataLoader"], "Eval", self.device, @@ -140,7 +141,8 @@ class Engine(object): if self.mode == "train": loss_info = self.config["Loss"]["Train"] self.train_loss_func = build_loss(loss_info) - if self.mode in ["train", "eval"]: + if self.mode == "eval" or (self.mode == "train" and + self.config["Global"]["eval_during_train"]): loss_config = self.config.get("Loss", None) if loss_config is not None: loss_config = loss_config.get("Eval") @@ -163,7 +165,8 @@ class Engine(object): else: self.train_metric_func = None - if self.mode in ["train", "eval"]: + if self.mode == "eval" or (self.mode == "train" and + self.config["Global"]["eval_during_train"]): metric_config = self.config.get("Metric") if self.eval_mode == "classification": if metric_config is not None: