mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
added l2 normalization to projection feature vectors
This commit is contained in:
parent
8e5c65d25e
commit
032d547731
24
train.py
24
train.py
@ -7,6 +7,7 @@ import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import datasets
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import Encoder
|
||||
from utils import GaussianBlur
|
||||
@ -44,6 +45,7 @@ optimizer = optim.Adam(model.parameters(), 3e-4)
|
||||
train_writer = SummaryWriter()
|
||||
|
||||
similarity = torch.nn.CosineSimilarity(dim=1)
|
||||
similarity_dim2 = torch.nn.CosineSimilarity(dim=2)
|
||||
|
||||
n_iter = 0
|
||||
for e in range(40):
|
||||
@ -81,6 +83,9 @@ for e in range(40):
|
||||
train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)
|
||||
# print(hjs.shape, zjs.shape)
|
||||
|
||||
zis = F.normalize(zis, dim=1)
|
||||
zjs = F.normalize(zjs, dim=1)
|
||||
|
||||
# positive pairs
|
||||
# l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
|
||||
l_pos = similarity(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size,
|
||||
@ -90,6 +95,12 @@ for e in range(40):
|
||||
assert l_pos.shape == (batch_size, 1) # [N,1]
|
||||
l_neg = []
|
||||
|
||||
#############
|
||||
#############
|
||||
# negatives = torch.cat([zjs, zis], dim=0)
|
||||
#############
|
||||
#############
|
||||
|
||||
for i in range(zis.shape[0]):
|
||||
mask = np.ones(zjs.shape[0], dtype=bool)
|
||||
mask[i] = False
|
||||
@ -103,6 +114,19 @@ for e in range(40):
|
||||
assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape)
|
||||
# print("l_neg.shape -->", l_neg.shape)
|
||||
|
||||
#############
|
||||
#############
|
||||
# l_negs = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives.view(1, 6, out_dim))
|
||||
#
|
||||
# mask = torch.ones_like(l_negs, dtype=bool)
|
||||
# for i in range(l_negs.shape[0]):
|
||||
# mask[i, i] = 0
|
||||
# mask[i, i + l_negs.shape[0]] = 0
|
||||
#
|
||||
# l_negs = l_negs[mask].view(l_negs.shape[0], -1)
|
||||
#############
|
||||
#############
|
||||
|
||||
logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1]
|
||||
# print("logits.shape -->",logits.shape)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user