fix rec forward bug

This commit is contained in:
dongshuilong 2021-09-02 07:42:22 +00:00
parent 81fcc9cb72
commit d49657ad08
2 changed files with 6 additions and 1 deletions

View File

@ -51,6 +51,11 @@ class Engine(object):
self.config = config
self.eval_mode = self.config["Global"].get("eval_mode",
"classification")
if "Head" in self.config["Arch"]:
self.is_rec = True
else:
self.is_rec = False
# init logger
self.output_dir = self.config['Global']['output_dir']
log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],

View File

@ -79,7 +79,7 @@ def train_epoch(trainer, epoch_id, print_batch_step):
def forward(trainer, batch):
if trainer.eval_mode == "classification":
if not trainer.is_rec:
return trainer.model(batch[0])
else:
return trainer.model(batch[0], batch[1])