mirror of https://github.com/JDAI-CV/DCL.git
Update train_model.py
parent
c7fb82cc1e
commit
6d78afc333
|
@ -132,11 +132,11 @@ def train(Config,
|
|||
print('step: {:d} / {:d} global_step: {:8.2f} train_epoch: {:04d} rec_train_loss: {:6.4f}'.format(step, train_epoch_step, 1.0*step/train_epoch_step, epoch, train_loss_recorder.get_val()), flush=True)
|
||||
print('current lr:%s' % exp_lr_scheduler.get_lr(), flush=True)
|
||||
if eval_train_flag:
|
||||
trainval_acc1, trainval_acc2, trainval_acc3 = eval_turn(model, data_loader['trainval'], 'trainval', epoch, log_file)
|
||||
trainval_acc1, trainval_acc2, trainval_acc3 = eval_turn(Config, model, data_loader['trainval'], 'trainval', epoch, log_file)
|
||||
if abs(trainval_acc1 - trainval_acc3) < 0.01:
|
||||
eval_train_flag = False
|
||||
|
||||
val_acc1, val_acc2, val_acc3 = eval_turn(model, data_loader['val'], 'val', epoch, log_file)
|
||||
val_acc1, val_acc2, val_acc3 = eval_turn(Config, model, data_loader['val'], 'val', epoch, log_file)
|
||||
|
||||
save_path = os.path.join(save_dir, 'weights_%d_%d_%.4f_%.4f.pth'%(epoch, batch_cnt, val_acc1, val_acc3))
|
||||
torch.cuda.synchronize()
|
||||
|
|
Loading…
Reference in New Issue