mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
new functionality to change the similarity function
This commit is contained in:
parent
67b8b5b0c1
commit
70bc1cbaaf
@ -3,6 +3,6 @@ out_dim: 64
|
||||
s: 1
|
||||
temperature: 0.5
|
||||
base_convnet: "resnet18" # one of: "resnet18" or "resnet50"
|
||||
use_cosine_similarity: True
|
||||
use_cosine_similarity: False
|
||||
epochs: 40
|
||||
num_workers: 4
|
||||
num_workers: 0
|
||||
|
29
train.py
29
train.py
@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.resnet_simclr import ResNetSimCLR
|
||||
from utils import get_negative_mask, get_augmentation_transform
|
||||
from utils import get_negative_mask, get_augmentation_transform, get_similarity_function
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@ -23,14 +23,17 @@ use_cosine_similarity = config['use_cosine_similarity']
|
||||
|
||||
data_augment = get_augmentation_transform(s=config['s'], crop_size=96)
|
||||
|
||||
train_dataset = datasets.STL10('./data', split='train', download=True, transform=transforms.ToTensor())
|
||||
train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=transforms.ToTensor())
|
||||
# train_dataset = datasets.Caltech101(root='./data', target_type="category", transform=transforms.ToTensor(),
|
||||
# target_transform=None, download=True)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=config['num_workers'], drop_last=True,
|
||||
shuffle=True)
|
||||
|
||||
# model = Encoder(out_dim=out_dim)
|
||||
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
|
||||
|
||||
train_gpu = False # torch.cuda.is_available()
|
||||
train_gpu = torch.cuda.is_available()
|
||||
print("Is gpu available:", train_gpu)
|
||||
|
||||
# moves the model parameters to gpu
|
||||
@ -42,9 +45,7 @@ optimizer = optim.Adam(model.parameters(), 3e-4)
|
||||
|
||||
train_writer = SummaryWriter()
|
||||
|
||||
if use_cosine_similarity:
|
||||
cos_similarity_dim1 = torch.nn.CosineSimilarity(dim=1)
|
||||
cos_similarity_dim2 = torch.nn.CosineSimilarity(dim=2)
|
||||
sim_func_dim1, sim_func_dim2 = get_similarity_function(use_cosine_similarity)
|
||||
|
||||
# Mask to remove positive examples from the batch of negative samples
|
||||
negative_mask = get_negative_mask(batch_size)
|
||||
@ -86,13 +87,7 @@ for e in range(config['epochs']):
|
||||
# assert zis.shape == (batch_size, out_dim), "Shape not expected: " + str(zis.shape)
|
||||
# assert zjs.shape == (batch_size, out_dim), "Shape not expected: " + str(zjs.shape)
|
||||
|
||||
# positive pairs
|
||||
if use_cosine_similarity:
|
||||
l_pos = cos_similarity_dim1(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size,
|
||||
1)
|
||||
else:
|
||||
l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
|
||||
|
||||
l_pos = sim_func_dim1(zis, zjs).view(batch_size, 1)
|
||||
l_pos /= temperature
|
||||
# assert l_pos.shape == (batch_size, 1), "l_pos shape not valid" + str(l_pos.shape) # [N,1]
|
||||
|
||||
@ -101,13 +96,7 @@ for e in range(config['epochs']):
|
||||
loss = 0
|
||||
|
||||
for positives in [zis, zjs]:
|
||||
|
||||
if use_cosine_similarity:
|
||||
l_neg = cos_similarity_dim2(positives.view(batch_size, 1, out_dim),
|
||||
negatives.view(1, (2 * batch_size), out_dim))
|
||||
else:
|
||||
l_neg = torch.tensordot(positives.view(batch_size, 1, out_dim),
|
||||
negatives.T.view(1, out_dim, (2 * batch_size)), dims=2)
|
||||
l_neg = sim_func_dim2(positives, negatives)
|
||||
|
||||
labels = torch.zeros(batch_size, dtype=torch.long)
|
||||
if train_gpu:
|
||||
|
18
utils.py
18
utils.py
@ -21,9 +21,10 @@ def get_negative_mask(batch_size):
|
||||
|
||||
class GaussianBlur(object):
|
||||
# Implements Gaussian blur as described in the SimCLR paper
|
||||
def __init__(self, min=0.1, max=2.0, kernel_size=9):
|
||||
def __init__(self, kernel_size, min=0.1, max=2.0):
|
||||
self.min = min
|
||||
self.max = max
|
||||
# kernel size is set to be 10% of the image height/width
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
def __call__(self, sample):
|
||||
@ -43,22 +44,28 @@ def get_augmentation_transform(s, crop_size):
|
||||
# get a set of data augmentation transformations as described in the SimCLR paper.
|
||||
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
|
||||
data_aug_ope = transforms.Compose([transforms.ToPILImage(),
|
||||
transforms.RandomResizedCrop(crop_size),
|
||||
transforms.RandomResizedCrop(size=crop_size),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomApply([color_jitter], p=0.8),
|
||||
transforms.RandomGrayscale(p=0.2),
|
||||
GaussianBlur(),
|
||||
GaussianBlur(kernel_size=int(0.1 * crop_size)),
|
||||
transforms.ToTensor()])
|
||||
return data_aug_ope
|
||||
|
||||
|
||||
def _dot_simililarity_dim1(x, y):
|
||||
v = torch.bmm(x.unsqueeze(1), y.unsqueeze(2))
|
||||
# x shape: (N, 1, C)
|
||||
# y shape: (N, C, 1)
|
||||
# v shape: (N, 1, 1)
|
||||
v = torch.bmm(x.unsqueeze(1), y.unsqueeze(2)) #
|
||||
return v
|
||||
|
||||
|
||||
def _dot_simililarity_dim2(x, y):
|
||||
v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
|
||||
# x shape: (N, 1, C)
|
||||
# y shape: (1, C, 2N)
|
||||
# v shape: (N, 2N)
|
||||
return v
|
||||
|
||||
|
||||
@ -68,6 +75,9 @@ def _cosine_simililarity_dim1(x, y):
|
||||
|
||||
|
||||
def _cosine_simililarity_dim2(x, y):
|
||||
# x shape: (N, 1, C)
|
||||
# y shape: (1, 2N, C)
|
||||
# v shape: (N, 2N)
|
||||
v = cos2d(x.unsqueeze(1), y.unsqueeze(0))
|
||||
return v
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user