From 032d547731e11f7573d021446c783b7711905f74 Mon Sep 17 00:00:00 2001 From: Thalles Date: Tue, 18 Feb 2020 13:39:23 -0300 Subject: [PATCH] added l2 normalization to projection feature vectors --- train.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/train.py b/train.py index e2f1079..c19c126 100644 --- a/train.py +++ b/train.py @@ -7,6 +7,7 @@ import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision import datasets from torch.utils.tensorboard import SummaryWriter +import torch.nn.functional as F from model import Encoder from utils import GaussianBlur @@ -44,6 +45,7 @@ optimizer = optim.Adam(model.parameters(), 3e-4) train_writer = SummaryWriter() similarity = torch.nn.CosineSimilarity(dim=1) +similarity_dim2 = torch.nn.CosineSimilarity(dim=2) n_iter = 0 for e in range(40): @@ -81,6 +83,9 @@ for e in range(40): train_writer.add_histogram("xj_latent", zjs, global_step=n_iter) # print(hjs.shape, zjs.shape) + zis = F.normalize(zis, dim=1) + zjs = F.normalize(zjs, dim=1) + # positive pairs # l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1) l_pos = similarity(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size, @@ -90,6 +95,12 @@ for e in range(40): assert l_pos.shape == (batch_size, 1) # [N,1] l_neg = [] + ############# + ############# + # negatives = torch.cat([zjs, zis], dim=0) + ############# + ############# + for i in range(zis.shape[0]): mask = np.ones(zjs.shape[0], dtype=bool) mask[i] = False @@ -103,6 +114,19 @@ for e in range(40): assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape) # print("l_neg.shape -->", l_neg.shape) + ############# + ############# + # l_negs = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives.view(1, 6, out_dim)) + # + # mask = torch.ones_like(l_negs, dtype=bool) + # for i in range(l_negs.shape[0]): + # mask[i, i] = 0 + # mask[i, i + l_negs.shape[0]] = 0 + # + # l_negs = l_negs[mask].view(l_negs.shape[0], -1) + ############# + ############# + logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1] # print("logits.shape -->",logits.shape)