diff --git a/train.py b/train.py
index 690fc7e..38b61d2 100644
--- a/train.py
+++ b/train.py
@@ -93,6 +93,7 @@ if not os.path.exists(model_checkpoints_folder):
     os.makedirs(model_checkpoints_folder)
 
 n_iter = 0
+valid_n_iter = 0
 best_valid_loss = np.inf
 
 for epoch_counter in range(config['epochs']):
@@ -115,18 +116,25 @@ for epoch_counter in range(config['epochs']):
         # validation steps
         with torch.no_grad():
             model.eval()
-            for (xis, xjs), _ in valid_loader:
+
+            valid_loss = 0.0
+            for counter, ((xis, xjs), _) in enumerate(valid_loader):
                 if train_gpu:
                     xis = xis.cuda()
                     xjs = xjs.cuda()
+                loss = (step(xis, xjs))
+                valid_loss += loss.item()
 
-                loss = step(xis, xjs)
+            # enumerate() counts from 0, so the number of batches is counter + 1
+            valid_loss /= (counter + 1)
 
-                if loss < best_valid_loss:
-                    # save the model weights
-                    best_valid_loss = loss
-                    torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
-                train_writer.add_scalar('validation_loss', loss, global_step=epoch_counter)
+            if valid_loss < best_valid_loss:
+                # save the model weights
+                best_valid_loss = valid_loss
+                torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
+
+            train_writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
+            valid_n_iter += 1
 
         model.train()
 
diff --git a/utils.py b/utils.py
index 4253296..36d4f89 100644
--- a/utils.py
+++ b/utils.py
@@ -4,8 +4,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.sampler import SubsetRandomSampler
 
 np.random.seed(0)
-cos1d = torch.nn.CosineSimilarity(dim=1)
-cos2d = torch.nn.CosineSimilarity(dim=2)
+cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
 
 
 def get_train_validation_data_loaders(train_dataset, config):
@@ -25,7 +24,7 @@ def get_train_validation_data_loaders(train_dataset, config):
 
     valid_loader = DataLoader(train_dataset, batch_size=config['batch_size'], sampler=valid_sampler,
                               num_workers=config['num_workers'],
-                              drop_last=False)
+                              drop_last=True)
     return train_loader, valid_loader
 
 
@@ -57,7 +56,7 @@ def _dot_simililarity_dim2(x, y):
 
 
 def _cosine_simililarity_dim1(x, y):
-    v = cos1d(x, y)
+    v = cosine_similarity(x, y)
     return v
 
 
@@ -65,7 +64,7 @@ def _cosine_simililarity_dim2(x, y):
     # x shape: (N, 1, C)
     # y shape: (1, 2N, C)
     # v shape: (N, 2N)
-    v = cos2d(x.unsqueeze(1), y.unsqueeze(0))
+    v = cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
     return v
 
 