mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
added l2 normalization to projection feature vectors
This commit is contained in:
parent
8e5c65d25e
commit
032d547731
24
train.py
24
train.py
@ -7,6 +7,7 @@ import torchvision.transforms as transforms
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import datasets
|
from torchvision import datasets
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from model import Encoder
|
from model import Encoder
|
||||||
from utils import GaussianBlur
|
from utils import GaussianBlur
|
||||||
@ -44,6 +45,7 @@ optimizer = optim.Adam(model.parameters(), 3e-4)
|
|||||||
train_writer = SummaryWriter()
|
train_writer = SummaryWriter()
|
||||||
|
|
||||||
similarity = torch.nn.CosineSimilarity(dim=1)
|
similarity = torch.nn.CosineSimilarity(dim=1)
|
||||||
|
similarity_dim2 = torch.nn.CosineSimilarity(dim=2)
|
||||||
|
|
||||||
n_iter = 0
|
n_iter = 0
|
||||||
for e in range(40):
|
for e in range(40):
|
||||||
@ -81,6 +83,9 @@ for e in range(40):
|
|||||||
train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)
|
train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)
|
||||||
# print(hjs.shape, zjs.shape)
|
# print(hjs.shape, zjs.shape)
|
||||||
|
|
||||||
|
zis = F.normalize(zis, dim=1)
|
||||||
|
zjs = F.normalize(zjs, dim=1)
|
||||||
|
|
||||||
# positive pairs
|
# positive pairs
|
||||||
# l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
|
# l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
|
||||||
l_pos = similarity(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size,
|
l_pos = similarity(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size,
|
||||||
@ -90,6 +95,12 @@ for e in range(40):
|
|||||||
assert l_pos.shape == (batch_size, 1) # [N,1]
|
assert l_pos.shape == (batch_size, 1) # [N,1]
|
||||||
l_neg = []
|
l_neg = []
|
||||||
|
|
||||||
|
#############
|
||||||
|
#############
|
||||||
|
# negatives = torch.cat([zjs, zis], dim=0)
|
||||||
|
#############
|
||||||
|
#############
|
||||||
|
|
||||||
for i in range(zis.shape[0]):
|
for i in range(zis.shape[0]):
|
||||||
mask = np.ones(zjs.shape[0], dtype=bool)
|
mask = np.ones(zjs.shape[0], dtype=bool)
|
||||||
mask[i] = False
|
mask[i] = False
|
||||||
@ -103,6 +114,19 @@ for e in range(40):
|
|||||||
assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape)
|
assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape)
|
||||||
# print("l_neg.shape -->", l_neg.shape)
|
# print("l_neg.shape -->", l_neg.shape)
|
||||||
|
|
||||||
|
#############
|
||||||
|
#############
|
||||||
|
# l_negs = similarity_dim2(zis.view(batch_size, 1, out_dim), negatives.view(1, 6, out_dim))
|
||||||
|
#
|
||||||
|
# mask = torch.ones_like(l_negs, dtype=bool)
|
||||||
|
# for i in range(l_negs.shape[0]):
|
||||||
|
# mask[i, i] = 0
|
||||||
|
# mask[i, i + l_negs.shape[0]] = 0
|
||||||
|
#
|
||||||
|
# l_negs = l_negs[mask].view(l_negs.shape[0], -1)
|
||||||
|
#############
|
||||||
|
#############
|
||||||
|
|
||||||
logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1]
|
logits = torch.cat([l_pos, l_neg], dim=1) # [N,K+1]
|
||||||
# print("logits.shape -->",logits.shape)
|
# print("logits.shape -->",logits.shape)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user