From 6d78afc333aa0e6d67c03fd407b5b34f6002f787 Mon Sep 17 00:00:00 2001 From: SaiYiKi <9478838+akira-l@users.noreply.github.com> Date: Tue, 23 Jul 2019 23:15:57 +0800 Subject: [PATCH] Update train_model.py --- utils/train_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/train_model.py b/utils/train_model.py index 2dfd5c1..8695744 100644 --- a/utils/train_model.py +++ b/utils/train_model.py @@ -132,11 +132,11 @@ def train(Config, print('step: {:d} / {:d} global_step: {:8.2f} train_epoch: {:04d} rec_train_loss: {:6.4f}'.format(step, train_epoch_step, 1.0*step/train_epoch_step, epoch, train_loss_recorder.get_val()), flush=True) print('current lr:%s' % exp_lr_scheduler.get_lr(), flush=True) if eval_train_flag: - trainval_acc1, trainval_acc2, trainval_acc3 = eval_turn(model, data_loader['trainval'], 'trainval', epoch, log_file) + trainval_acc1, trainval_acc2, trainval_acc3 = eval_turn(Config, model, data_loader['trainval'], 'trainval', epoch, log_file) if abs(trainval_acc1 - trainval_acc3) < 0.01: eval_train_flag = False - val_acc1, val_acc2, val_acc3 = eval_turn(model, data_loader['val'], 'val', epoch, log_file) + val_acc1, val_acc2, val_acc3 = eval_turn(Config, model, data_loader['val'], 'val', epoch, log_file) save_path = os.path.join(save_dir, 'weights_%d_%d_%.4f_%.4f.pth'%(epoch, batch_cnt, val_acc1, val_acc3)) torch.cuda.synchronize()