Merge pull request #1200 from RainFrost1/develop

fix rec forward bug
pull/1209/head
Walter 2021-09-03 10:32:32 +08:00 committed by GitHub
commit 681ef5d186
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 1 deletions

View File

@ -51,6 +51,11 @@ class Engine(object):
self.config = config
self.eval_mode = self.config["Global"].get("eval_mode",
"classification")
if "Head" in self.config["Arch"]:
self.is_rec = True
else:
self.is_rec = False
# init logger
self.output_dir = self.config['Global']['output_dir']
log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],

View File

@ -79,7 +79,7 @@ def train_epoch(trainer, epoch_id, print_batch_step):
def forward(trainer, batch):
if trainer.eval_mode == "classification":
if not trainer.is_rec:
return trainer.model(batch[0])
else:
return trainer.model(batch[0], batch[1])