added l2 normalization to projection feature vectors

This commit is contained in:
Thalles 2020-02-18 13:39:23 -03:00 committed by Thalles Silva
parent 8e5c65d25e
commit 032d547731

View File

@ -7,6 +7,7 @@ import torchvision.transforms as transforms
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import datasets from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from model import Encoder from model import Encoder
from utils import GaussianBlur from utils import GaussianBlur
@ -44,6 +45,7 @@ optimizer = optim.Adam(model.parameters(), 3e-4)
train_writer = SummaryWriter() train_writer = SummaryWriter()
similarity = torch.nn.CosineSimilarity(dim=1) similarity = torch.nn.CosineSimilarity(dim=1)
similarity_dim2 = torch.nn.CosineSimilarity(dim=2)
n_iter = 0 n_iter = 0
for e in range(40): for e in range(40):
@ -81,6 +83,9 @@ for e in range(40):
train_writer.add_histogram("xj_latent", zjs, global_step=n_iter) train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)
# print(hjs.shape, zjs.shape) # print(hjs.shape, zjs.shape)
zis = F.normalize(zis, dim=1)
zjs = F.normalize(zjs, dim=1)
# positive pairs # positive pairs
# l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1) # l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
l_pos = similarity(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size, l_pos = similarity(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size,
@ -90,6 +95,12 @@ for e in range(40):
assert l_pos.shape == (batch_size, 1) # [N,1] assert l_pos.shape == (batch_size, 1) # [N,1]
l_neg = [] l_neg = []
#############
#############
# negatives = torch.cat([zjs, zis], dim=0)
#############
#############
for i in range(zis.shape[0]): for i in range(zis.shape[0]):
mask = np.ones(zjs.shape[0], dtype=bool) mask = np.ones(zjs.shape[0], dtype=bool)
mask[i] = False mask[i] = False
@ -103,6 +114,19 @@ for e in range(40):
assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape) assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape)
# print("l_neg.shape -->", l_neg.shape) # print("l_neg.shape -->", l_neg.shape)
#############
#############
# l_negs = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives.view(1, 6, out_dim))
#
# mask = torch.ones_like(l_negs, dtype=bool)
# for i in range(l_negs.shape[0]):
# mask[i, i] = 0
# mask[i, i + l_negs.shape[0]] = 0
#
# l_negs = l_negs[mask].view(l_negs.shape[0], -1)
#############
#############
logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1] logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1]
# print("logits.shape -->",logits.shape) # print("logits.shape -->",logits.shape)