mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
added tensorboard support
This commit is contained in:
parent
6e555f1f1c
commit
67b8b5b0c1
13
train.py
13
train.py
@ -21,7 +21,7 @@ out_dim = config['out_dim']
|
|||||||
temperature = config['temperature']
|
temperature = config['temperature']
|
||||||
use_cosine_similarity = config['use_cosine_similarity']
|
use_cosine_similarity = config['use_cosine_similarity']
|
||||||
|
|
||||||
data_augment = get_augmentation_transform(s=config['s'])
|
data_augment = get_augmentation_transform(s=config['s'], crop_size=96)
|
||||||
|
|
||||||
train_dataset = datasets.STL10('./data', split='train', download=True, transform=transforms.ToTensor())
|
train_dataset = datasets.STL10('./data', split='train', download=True, transform=transforms.ToTensor())
|
||||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=config['num_workers'], drop_last=True,
|
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=config['num_workers'], drop_last=True,
|
||||||
@ -30,10 +30,10 @@ train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=conf
|
|||||||
# model = Encoder(out_dim=out_dim)
|
# model = Encoder(out_dim=out_dim)
|
||||||
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
|
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
|
||||||
|
|
||||||
train_gpu = torch.cuda.is_available()
|
train_gpu = False # torch.cuda.is_available()
|
||||||
print("Is gpu available:", train_gpu)
|
print("Is gpu available:", train_gpu)
|
||||||
|
|
||||||
# moves the model paramemeters to gpu
|
# moves the model parameters to gpu
|
||||||
if train_gpu:
|
if train_gpu:
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
|
||||||
@ -103,12 +103,11 @@ for e in range(config['epochs']):
|
|||||||
for positives in [zis, zjs]:
|
for positives in [zis, zjs]:
|
||||||
|
|
||||||
if use_cosine_similarity:
|
if use_cosine_similarity:
|
||||||
negatives = negatives.view(1, (2 * batch_size), out_dim)
|
l_neg = cos_similarity_dim2(positives.view(batch_size, 1, out_dim),
|
||||||
l_neg = cos_similarity_dim2(positives.view(batch_size, 1, out_dim), negatives)
|
negatives.view(1, (2 * batch_size), out_dim))
|
||||||
else:
|
else:
|
||||||
l_neg = torch.tensordot(positives.view(batch_size, 1, out_dim),
|
l_neg = torch.tensordot(positives.view(batch_size, 1, out_dim),
|
||||||
negatives.T.view(1, out_dim, (2 * batch_size)),
|
negatives.T.view(1, out_dim, (2 * batch_size)), dims=2)
|
||||||
dims=2)
|
|
||||||
|
|
||||||
labels = torch.zeros(batch_size, dtype=torch.long)
|
labels = torch.zeros(batch_size, dtype=torch.long)
|
||||||
if train_gpu:
|
if train_gpu:
|
||||||
|
40
utils.py
40
utils.py
@ -4,6 +4,8 @@ import torch
|
|||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
cos1d = torch.nn.CosineSimilarity(dim=1)
|
||||||
|
cos2d = torch.nn.CosineSimilarity(dim=2)
|
||||||
|
|
||||||
|
|
||||||
def get_negative_mask(batch_size):
|
def get_negative_mask(batch_size):
|
||||||
@ -37,11 +39,11 @@ class GaussianBlur(object):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def get_augmentation_transform(s=1):
|
def get_augmentation_transform(s, crop_size):
|
||||||
# get a set of data augmentation transformations as described in the SimCLR paper.
|
# get a set of data augmentation transformations as described in the SimCLR paper.
|
||||||
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
|
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
|
||||||
data_aug_ope = transforms.Compose([transforms.ToPILImage(),
|
data_aug_ope = transforms.Compose([transforms.ToPILImage(),
|
||||||
transforms.RandomResizedCrop(96),
|
transforms.RandomResizedCrop(crop_size),
|
||||||
transforms.RandomHorizontalFlip(),
|
transforms.RandomHorizontalFlip(),
|
||||||
transforms.RandomApply([color_jitter], p=0.8),
|
transforms.RandomApply([color_jitter], p=0.8),
|
||||||
transforms.RandomGrayscale(p=0.2),
|
transforms.RandomGrayscale(p=0.2),
|
||||||
@ -49,11 +51,29 @@ def get_augmentation_transform(s=1):
|
|||||||
transforms.ToTensor()])
|
transforms.ToTensor()])
|
||||||
return data_aug_ope
|
return data_aug_ope
|
||||||
|
|
||||||
# if use_cosine_similarity:
|
|
||||||
# cos1d = torch.nn.CosineSimilarity(dim=1)
|
def _dot_simililarity_dim1(x, y):
|
||||||
# cos2d = torch.nn.CosineSimilarity(dim=2)
|
v = torch.bmm(x.unsqueeze(1), y.unsqueeze(2))
|
||||||
# similarity_dim1 = lambda x, y: cos1d(x, y.unsqueeze(0))
|
return v
|
||||||
# similarity_dim2 = lambda x, y: cos2d(x, y.unsqueeze(0))
|
|
||||||
# else:
|
|
||||||
# similarity_dim1 = lambda x, y: torch.bmm(x.unsqueeze(1), y.unsqueeze(2))
|
def _dot_simililarity_dim2(x, y):
|
||||||
# similarity_dim2 = lambda x, y: torch.tensordot(x, y.T.unsqueeze(0), dims=2)
|
v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def _cosine_simililarity_dim1(x, y):
|
||||||
|
v = cos1d(x, y)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def _cosine_simililarity_dim2(x, y):
|
||||||
|
v = cos2d(x.unsqueeze(1), y.unsqueeze(0))
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def get_similarity_function(use_cosine_similarity):
|
||||||
|
if use_cosine_similarity:
|
||||||
|
return _cosine_simililarity_dim1, _cosine_simililarity_dim2
|
||||||
|
else:
|
||||||
|
return _dot_simililarity_dim1, _dot_simililarity_dim2
|
||||||
|
Loading…
x
Reference in New Issue
Block a user