diff --git a/utils/eval_model.py b/utils/eval_model.py index 811ffe9..ca6bf10 100644 --- a/utils/eval_model.py +++ b/utils/eval_model.py @@ -17,7 +17,7 @@ import pdb def dt(): return datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S") -def eval_turn(model, data_loader, val_version, epoch_num, log_file): +def eval_turn(Config, model, data_loader, val_version, epoch_num, log_file): model.train(False) @@ -50,7 +50,7 @@ def eval_turn(model, data_loader, val_version, epoch_num, log_file): val_loss_recorder.update(loss) val_celoss_recorder.update(ce_loss) - if outputs[1].size(1) != 2: + if Config.use_dcl and Config.cls_2xmul: outputs_pred = outputs[0] + outputs[1][:,0:num_cls] + outputs[1][:,num_cls:2*num_cls] else: outputs_pred = outputs[0]