diff --git a/config.yaml b/config.yaml
index abe94d7..e1c8b31 100644
--- a/config.yaml
+++ b/config.yaml
@@ -3,6 +3,6 @@ out_dim: 64
 s: 1
 temperature: 0.5
 base_convnet: "resnet18" # one of: "resnet18 or resnet50"
-use_cosine_similarity: True
+use_cosine_similarity: False
 epochs: 40
-num_workers: 4
+num_workers: 0
diff --git a/train.py b/train.py
index 0598299..902c0fd 100644
--- a/train.py
+++ b/train.py
@@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter
 import torch.nn.functional as F
 
 from models.resnet_simclr import ResNetSimCLR
-from utils import get_negative_mask, get_augmentation_transform
+from utils import get_negative_mask, get_augmentation_transform, get_similarity_function
 
 torch.manual_seed(0)
 
@@ -23,14 +23,17 @@ use_cosine_similarity = config['use_cosine_similarity']
 
 data_augment = get_augmentation_transform(s=config['s'], crop_size=96)
 
-train_dataset = datasets.STL10('./data', split='train', download=True, transform=transforms.ToTensor())
+train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=transforms.ToTensor())
+# train_dataset = datasets.Caltech101(root='./data', target_type="category", transform=transforms.ToTensor(),
+#                                     target_transform=None, download=True)
+
 train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=config['num_workers'], drop_last=True,
                           shuffle=True)
 
 # model = Encoder(out_dim=out_dim)
 model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
 
-train_gpu = False  # torch.cuda.is_available()
+train_gpu = torch.cuda.is_available()
 print("Is gpu available:", train_gpu)
 
 # moves the model parameters to gpu
@@ -42,9 +45,7 @@ optimizer = optim.Adam(model.parameters(), 3e-4)
 
 train_writer = SummaryWriter()
 
-if use_cosine_similarity:
-    cos_similarity_dim1 = torch.nn.CosineSimilarity(dim=1)
-    cos_similarity_dim2 = torch.nn.CosineSimilarity(dim=2)
+sim_func_dim1, sim_func_dim2 = get_similarity_function(use_cosine_similarity)
 
 # Mask to remove positive examples from the batch of negative samples
 negative_mask = get_negative_mask(batch_size)
@@ -86,13 +87,7 @@ for e in range(config['epochs']):
         # assert zis.shape == (batch_size, out_dim), "Shape not expected: " + str(zis.shape)
         # assert zjs.shape == (batch_size, out_dim), "Shape not expected: " + str(zjs.shape)
 
-        # positive pairs
-        if use_cosine_similarity:
-            l_pos = cos_similarity_dim1(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size,
-                                                                                                            1)
-        else:
-            l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
-
+        l_pos = sim_func_dim1(zis, zjs).view(batch_size, 1)
         l_pos /= temperature
         # assert l_pos.shape == (batch_size, 1), "l_pos shape not valid" + str(l_pos.shape)  # [N,1]
 
@@ -101,13 +96,7 @@ for e in range(config['epochs']):
         loss = 0
 
         for positives in [zis, zjs]:
-
-            if use_cosine_similarity:
-                l_neg = cos_similarity_dim2(positives.view(batch_size, 1, out_dim),
-                                            negatives.view(1, (2 * batch_size), out_dim))
-            else:
-                l_neg = torch.tensordot(positives.view(batch_size, 1, out_dim),
-                                        negatives.T.view(1, out_dim, (2 * batch_size)), dims=2)
+            l_neg = sim_func_dim2(positives, negatives)
 
             labels = torch.zeros(batch_size, dtype=torch.long)
             if train_gpu:
diff --git a/utils.py b/utils.py
index 7b9b3cd..2450d83 100644
--- a/utils.py
+++ b/utils.py
@@ -21,9 +21,10 @@ def get_negative_mask(batch_size):
 
 class GaussianBlur(object):
     # Implements Gaussian blur as described in the SimCLR paper
-    def __init__(self, min=0.1, max=2.0, kernel_size=9):
+    def __init__(self, kernel_size, min=0.1, max=2.0):
         self.min = min
         self.max = max
+        # kernel size is set to be 10% of the image height/width
         self.kernel_size = kernel_size
 
     def __call__(self, sample):
@@ -43,22 +44,28 @@ def get_augmentation_transform(s, crop_size):
     # get a set of data augmentation transformations as described in the SimCLR paper.
     color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
     data_aug_ope = transforms.Compose([transforms.ToPILImage(),
-                                       transforms.RandomResizedCrop(crop_size),
+                                       transforms.RandomResizedCrop(size=crop_size),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.RandomApply([color_jitter], p=0.8),
                                        transforms.RandomGrayscale(p=0.2),
-                                       GaussianBlur(),
+                                       GaussianBlur(kernel_size=int(0.1 * crop_size)),
                                        transforms.ToTensor()])
     return data_aug_ope
 
 
 def _dot_simililarity_dim1(x, y):
-    v = torch.bmm(x.unsqueeze(1), y.unsqueeze(2))
+    # x shape: (N, 1, C)
+    # y shape: (N, C, 1)
+    # v shape: (N, 1, 1)
+    v = torch.bmm(x.unsqueeze(1), y.unsqueeze(2))  #
     return v
 
 
 def _dot_simililarity_dim2(x, y):
     v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
+    # x shape: (N, 1, C)
+    # y shape: (1, C, 2N)
+    # v shape: (N, 2N)
     return v
 
 
@@ -68,6 +75,9 @@ def _cosine_simililarity_dim1(x, y):
 
 
 def _cosine_simililarity_dim2(x, y):
+    # x shape: (N, 1, C)
+    # y shape: (1, 2N, C)
+    # v shape: (N, 2N)
     v = cos2d(x.unsqueeze(1), y.unsqueeze(0))
     return v
 
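
Note: train.py now imports get_similarity_function from utils, but its definition is not part of the hunks above. A minimal sketch of what that dispatcher presumably looks like, assuming it does nothing more than select between the private helpers already defined in utils.py:

def get_similarity_function(use_cosine_similarity):
    # Return a pair of similarity callables for train.py:
    # the dim1 variant scores the N positive pairs (one score per pair),
    # the dim2 variant scores each embedding against all 2N candidates in `negatives`.
    if use_cosine_similarity:
        return _cosine_simililarity_dim1, _cosine_simililarity_dim2
    return _dot_simililarity_dim1, _dot_simililarity_dim2

With this in place, train.py reshapes sim_func_dim1(zis, zjs) to (N, 1) for the positive logits, while sim_func_dim2(positives, negatives) yields the (N, 2N) scores described by the shape comments added to utils.py.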
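
The body of GaussianBlur.__call__ is likewise only context here. A sketch of how the new kernel_size argument might be consumed, assuming an OpenCV-based blur that fires with probability 0.5 and draws sigma uniformly from [min, max] as in the SimCLR paper (the cv2 usage is an assumption, not part of the patch):

import numpy as np
import cv2


class GaussianBlur(object):
    # Implements Gaussian blur as described in the SimCLR paper
    def __init__(self, kernel_size, min=0.1, max=2.0):
        self.min = min
        self.max = max
        # kernel size is set to be 10% of the image height/width
        self.kernel_size = kernel_size

    def __call__(self, sample):
        sample = np.array(sample)
        # apply the blur only half of the time, with a random sigma in [min, max]
        if np.random.random_sample() < 0.5:
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            # cv2.GaussianBlur needs an odd kernel size; int(0.1 * 96) == 9 satisfies this
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
        return sample

Since the transform sits between RandomGrayscale and ToTensor in the Compose pipeline, it receives a PIL image and returns a NumPy array, which transforms.ToTensor() accepts.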