commit
681ef5d186
|
@ -51,6 +51,11 @@ class Engine(object):
|
|||
self.config = config
|
||||
self.eval_mode = self.config["Global"].get("eval_mode",
|
||||
"classification")
|
||||
if "Head" in self.config["Arch"]:
|
||||
self.is_rec = True
|
||||
else:
|
||||
self.is_rec = False
|
||||
|
||||
# init logger
|
||||
self.output_dir = self.config['Global']['output_dir']
|
||||
log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
|
||||
|
|
|
@ -79,7 +79,7 @@ def train_epoch(trainer, epoch_id, print_batch_step):
|
|||
|
||||
|
||||
def forward(trainer, batch):
|
||||
if trainer.eval_mode == "classification":
|
||||
if not trainer.is_rec:
|
||||
return trainer.model(batch[0])
|
||||
else:
|
||||
return trainer.model(batch[0], batch[1])
|
||||
|
|
Loading…
Reference in New Issue