diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index ad5c584f0..bdcd4c4cc 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -51,6 +51,11 @@ class Engine(object): self.config = config self.eval_mode = self.config["Global"].get("eval_mode", "classification") + if "Head" in self.config["Arch"]: + self.is_rec = True + else: + self.is_rec = False + # init logger self.output_dir = self.config['Global']['output_dir'] log_file = os.path.join(self.output_dir, self.config["Arch"]["name"], diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py index 9e36a063e..73f225087 100644 --- a/ppcls/engine/train/train.py +++ b/ppcls/engine/train/train.py @@ -79,7 +79,7 @@ def train_epoch(trainer, epoch_id, print_batch_step): def forward(trainer, batch): - if trainer.eval_mode == "classification": + if not trainer.is_rec: return trainer.model(batch[0]) else: return trainer.model(batch[0], batch[1])