added l2 normalization to projection feature vectors

This commit is contained in:
Thalles 2020-02-18 13:39:23 -03:00 committed by Thalles Silva
parent 8e5c65d25e
commit 032d547731

View File

@ -7,6 +7,7 @@ import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from model import Encoder
from utils import GaussianBlur
@ -44,6 +45,7 @@ optimizer = optim.Adam(model.parameters(), 3e-4)
train_writer = SummaryWriter()
similarity = torch.nn.CosineSimilarity(dim=1)
similarity_dim2 = torch.nn.CosineSimilarity(dim=2)
n_iter = 0
for e in range(40):
@ -81,6 +83,9 @@ for e in range(40):
train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)
# print(hjs.shape, zjs.shape)
zis = F.normalize(zis, dim=1)
zjs = F.normalize(zjs, dim=1)
# positive pairs
# l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
l_pos = similarity(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size,
@ -90,6 +95,12 @@ for e in range(40):
assert l_pos.shape == (batch_size, 1) # [N,1]
l_neg = []
#############
#############
# negatives = torch.cat([zjs, zis], dim=0)
#############
#############
for i in range(zis.shape[0]):
mask = np.ones(zjs.shape[0], dtype=bool)
mask[i] = False
@ -103,6 +114,19 @@ for e in range(40):
assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape)
# print("l_neg.shape -->", l_neg.shape)
#############
#############
# l_negs = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives.view(1, 6, out_dim))
#
# mask = torch.ones_like(l_negs, dtype=bool)
# for i in range(l_negs.shape[0]):
# mask[i, i] = 0
# mask[i, i + l_negs.shape[0]] = 0
#
# l_negs = l_negs[mask].view(l_negs.shape[0], -1)
#############
#############
logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1]
# print("logits.shape -->",logits.shape)