new functionality to change the similarity function

This commit is contained in:
Thalles 2020-02-29 15:21:04 -03:00
parent 67b8b5b0c1
commit 70bc1cbaaf
3 changed files with 25 additions and 26 deletions

View File

@ -3,6 +3,6 @@ out_dim: 64
s: 1
temperature: 0.5
base_convnet: "resnet18" # one of: "resnet18 or resnet50"
use_cosine_similarity: True
use_cosine_similarity: False
epochs: 40
num_workers: 4
num_workers: 0

View File

@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from models.resnet_simclr import ResNetSimCLR
from utils import get_negative_mask, get_augmentation_transform
from utils import get_negative_mask, get_augmentation_transform, get_similarity_function
torch.manual_seed(0)
@ -23,14 +23,17 @@ use_cosine_similarity = config['use_cosine_similarity']
data_augment = get_augmentation_transform(s=config['s'], crop_size=96)
train_dataset = datasets.STL10('./data', split='train', download=True, transform=transforms.ToTensor())
train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=transforms.ToTensor())
# train_dataset = datasets.Caltech101(root='./data', target_type="category", transform=transforms.ToTensor(),
# target_transform=None, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=config['num_workers'], drop_last=True,
shuffle=True)
# model = Encoder(out_dim=out_dim)
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
train_gpu = False # torch.cuda.is_available()
train_gpu = torch.cuda.is_available()
print("Is gpu available:", train_gpu)
# moves the model parameters to gpu
@ -42,9 +45,7 @@ optimizer = optim.Adam(model.parameters(), 3e-4)
train_writer = SummaryWriter()
if use_cosine_similarity:
cos_similarity_dim1 = torch.nn.CosineSimilarity(dim=1)
cos_similarity_dim2 = torch.nn.CosineSimilarity(dim=2)
sim_func_dim1, sim_func_dim2 = get_similarity_function(use_cosine_similarity)
# Mask to remove positive examples from the batch of negative samples
negative_mask = get_negative_mask(batch_size)
@ -86,13 +87,7 @@ for e in range(config['epochs']):
# assert zis.shape == (batch_size, out_dim), "Shape not expected: " + str(zis.shape)
# assert zjs.shape == (batch_size, out_dim), "Shape not expected: " + str(zjs.shape)
# positive pairs
if use_cosine_similarity:
l_pos = cos_similarity_dim1(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size,
1)
else:
l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
l_pos = sim_func_dim1(zis, zjs).view(batch_size, 1)
l_pos /= temperature
# assert l_pos.shape == (batch_size, 1), "l_pos shape not valid" + str(l_pos.shape) # [N,1]
@ -101,13 +96,7 @@ for e in range(config['epochs']):
loss = 0
for positives in [zis, zjs]:
if use_cosine_similarity:
l_neg = cos_similarity_dim2(positives.view(batch_size, 1, out_dim),
negatives.view(1, (2 * batch_size), out_dim))
else:
l_neg = torch.tensordot(positives.view(batch_size, 1, out_dim),
negatives.T.view(1, out_dim, (2 * batch_size)), dims=2)
l_neg = sim_func_dim2(positives, negatives)
labels = torch.zeros(batch_size, dtype=torch.long)
if train_gpu:

View File

@ -21,9 +21,10 @@ def get_negative_mask(batch_size):
class GaussianBlur(object):
# Implements Gaussian blur as described in the SimCLR paper
def __init__(self, min=0.1, max=2.0, kernel_size=9):
def __init__(self, kernel_size, min=0.1, max=2.0):
self.min = min
self.max = max
# kernel size is set to be 10% of the image height/width
self.kernel_size = kernel_size
def __call__(self, sample):
@ -43,22 +44,28 @@ def get_augmentation_transform(s, crop_size):
# get a set of data augmentation transformations as described in the SimCLR paper.
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
data_aug_ope = transforms.Compose([transforms.ToPILImage(),
transforms.RandomResizedCrop(crop_size),
transforms.RandomResizedCrop(size=crop_size),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([color_jitter], p=0.8),
transforms.RandomGrayscale(p=0.2),
GaussianBlur(),
GaussianBlur(kernel_size=int(0.1 * crop_size)),
transforms.ToTensor()])
return data_aug_ope
def _dot_simililarity_dim1(x, y):
v = torch.bmm(x.unsqueeze(1), y.unsqueeze(2))
# x shape: (N, 1, C)
# y shape: (N, C, 1)
# v shape: (N, 1, 1)
v = torch.bmm(x.unsqueeze(1), y.unsqueeze(2)) #
return v
def _dot_simililarity_dim2(x, y):
v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
# x shape: (N, 1, C)
# y shape: (1, C, 2N)
# v shape: (N, 2N)
return v
@ -68,6 +75,9 @@ def _cosine_simililarity_dim1(x, y):
def _cosine_simililarity_dim2(x, y):
# x shape: (N, 1, C)
# y shape: (1, 2N, C)
# v shape: (N, 2N)
v = cos2d(x.unsqueeze(1), y.unsqueeze(0))
return v