mirror of https://github.com/JDAI-CV/DCL.git
Update eval_model.py
parent
6d78afc333
commit
6203db874e
|
@ -17,7 +17,7 @@ import pdb
|
||||||
def dt():
|
def dt():
|
||||||
return datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S")
|
return datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S")
|
||||||
|
|
||||||
def eval_turn(model, data_loader, val_version, epoch_num, log_file):
|
def eval_turn(Config, model, data_loader, val_version, epoch_num, log_file):
|
||||||
|
|
||||||
model.train(False)
|
model.train(False)
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ def eval_turn(model, data_loader, val_version, epoch_num, log_file):
|
||||||
val_loss_recorder.update(loss)
|
val_loss_recorder.update(loss)
|
||||||
val_celoss_recorder.update(ce_loss)
|
val_celoss_recorder.update(ce_loss)
|
||||||
|
|
||||||
if outputs[1].size(1) != 2:
|
if Config.use_dcl and Config.cls_2xmul:
|
||||||
outputs_pred = outputs[0] + outputs[1][:,0:num_cls] + outputs[1][:,num_cls:2*num_cls]
|
outputs_pred = outputs[0] + outputs[1][:,0:num_cls] + outputs[1][:,num_cls:2*num_cls]
|
||||||
else:
|
else:
|
||||||
outputs_pred = outputs[0]
|
outputs_pred = outputs[0]
|
||||||
|
|
Loading…
Reference in New Issue