diff --git a/config.yaml b/config.yaml index df64b41..9ec3ca6 100644 --- a/config.yaml +++ b/config.yaml @@ -1,12 +1,19 @@ batch_size: 512 -out_dim: 256 -s: 1 -temperature: 0.5 -base_convnet: "resnet18" # one of: "resnet18 or resnet50" -use_cosine_similarity: True -epochs: 50 -num_workers: 0 -valid_size: 0.05 -eval_every_n_epochs: 2 -continue_training: Mar10_21-50-05_thallessilva +epochs: 33 +eval_every_n_epochs: 1 +fine_tune_from: 'Mar13_20-12-09_thallessilva' log_every_n_steps: 50 + +model: + out_dim: 256 + base_model: "resnet18" + +dataset: + s: 1 + input_shape: (96,96,3) + num_workers: 0 + valid_size: 0.05 + +loss: + temperature: 0.5 + use_cosine_similarity: True diff --git a/data_aug/data_transform.py b/data_aug/data_transform.py deleted file mode 100644 index aded7ef..0000000 --- a/data_aug/data_transform.py +++ /dev/null @@ -1,46 +0,0 @@ -import torchvision.transforms as transforms -import cv2 -import numpy as np - - -class DataTransform(object): - def __init__(self, transform): - self.transform = transform - - def __call__(self, sample): - xi = self.transform(sample) - xj = self.transform(sample) - return xi, xj - - -class GaussianBlur(object): - # Implements Gaussian blur as described in the SimCLR paper - def __init__(self, kernel_size, min=0.1, max=2.0): - self.min = min - self.max = max - # kernel size is set to be 10% of the image height/width - self.kernel_size = kernel_size - - def __call__(self, sample): - sample = np.array(sample) - - # blur the image with a 50% chance - prob = np.random.random_sample() - - if prob < 0.5: - sigma = (self.max - self.min) * np.random.random_sample() + self.min - sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma) - - return sample - - -def get_simclr_data_transform(s, crop_size): - # get a set of data augmentation transformations as described in the SimCLR paper. - color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) - data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=crop_size), - transforms.RandomHorizontalFlip(), - transforms.RandomApply([color_jitter], p=0.8), - transforms.RandomGrayscale(p=0.2), - GaussianBlur(kernel_size=int(0.1 * crop_size)), - transforms.ToTensor()]) - return data_transforms diff --git a/data_aug/dataset_wrapper.py b/data_aug/dataset_wrapper.py new file mode 100644 index 0000000..ba4600f --- /dev/null +++ b/data_aug/dataset_wrapper.py @@ -0,0 +1,68 @@ +import numpy as np +from torch.utils.data import DataLoader +from torch.utils.data.sampler import SubsetRandomSampler +import torchvision.transforms as transforms +from data_aug.gaussian_blur import GaussianBlur +from torchvision import datasets + +np.random.seed(0) + + +class DataSetWrapper(object): + + def __init__(self, batch_size, num_workers, valid_size, input_shape, s): + self.batch_size = batch_size + self.num_workers = num_workers + self.valid_size = valid_size + self.s = s + self.input_shape = eval(input_shape) + + def get_data_loaders(self): + data_augment = self._get_simclr_pipeline_transform() + + train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, + transform=SimCLRDataTransform(data_augment)) + + train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset) + return train_loader, valid_loader + + def _get_simclr_pipeline_transform(self): + # get a set of data augmentation transformations as described in the SimCLR paper. + color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s) + data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=self.input_shape[0]), + transforms.RandomHorizontalFlip(), + transforms.RandomApply([color_jitter], p=0.8), + transforms.RandomGrayscale(p=0.2), + GaussianBlur(kernel_size=int(0.1 * self.input_shape[0])), + transforms.ToTensor()]) + return data_transforms + + def get_train_validation_data_loaders(self, train_dataset): + # obtain training indices that will be used for validation + num_train = len(train_dataset) + indices = list(range(num_train)) + np.random.shuffle(indices) + + split = int(np.floor(self.valid_size * num_train)) + train_idx, valid_idx = indices[split:], indices[:split] + + # define samplers for obtaining training and validation batches + train_sampler = SubsetRandomSampler(train_idx) + valid_sampler = SubsetRandomSampler(valid_idx) + + train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler, + num_workers=self.num_workers, drop_last=True, shuffle=False) + + valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler, + num_workers=self.num_workers, drop_last=True) + return train_loader, valid_loader + + +class SimCLRDataTransform(object): + def __init__(self, transform): + self.transform = transform + + def __call__(self, sample): + xi = self.transform(sample) + xj = self.transform(sample) + return xi, xj diff --git a/data_aug/gaussian_blur.py b/data_aug/gaussian_blur.py new file mode 100644 index 0000000..502e200 --- /dev/null +++ b/data_aug/gaussian_blur.py @@ -0,0 +1,25 @@ +import cv2 +import numpy as np + +np.random.seed(0) + + +class GaussianBlur(object): + # Implements Gaussian blur as described in the SimCLR paper + def __init__(self, kernel_size, min=0.1, max=2.0): + self.min = min + self.max = max + # kernel size is set to be 10% of the image height/width + self.kernel_size = kernel_size + + def __call__(self, sample): + sample = np.array(sample) + + # blur the image with a 50% chance + prob = np.random.random_sample() + + if prob < 0.5: + sigma = (self.max - self.min) * np.random.random_sample() + self.min + sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma) + + return sample diff --git a/feature_eval/linear_feature_eval.ipynb b/feature_eval/linear_feature_eval.ipynb index 38681fa..7a62c1c 100644 --- a/feature_eval/linear_feature_eval.ipynb +++ b/feature_eval/linear_feature_eval.ipynb @@ -33,7 +33,7 @@ "metadata": {}, "outputs": [], "source": [ - "folder_name = 'Mar10_21-50-05_thallessilva'\n", + "folder_name = 'Mar13_20-12-09_thallessilva'\n", "checkpoints_folder = os.path.join('../runs', folder_name, 'checkpoints')\n", "config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"), Loader=yaml.FullLoader)" ] @@ -52,10 +52,13 @@ " 'temperature': 0.5,\n", " 'base_convnet': 'resnet18',\n", " 'use_cosine_similarity': True,\n", - " 'epochs': 50,\n", - " 'num_workers': 4,\n", + " 'epochs': 40,\n", + " 'num_workers': 0,\n", " 'valid_size': 0.05,\n", - " 'eval_every_n_epochs': 2}" + " 'eval_every_n_epochs': 2,\n", + " 'continue_training': 'Mar10_21-50-05_thallessilva',\n", + " 'log_every_n_steps': 50,\n", + " 'input_shape': '(96,96,3)'}" ] }, "execution_count": 4, @@ -204,17 +207,19 @@ "output_type": "stream", "text": [ "Logistic Regression feature eval\n", - "Train score: 0.4966\n", - "Test score: 0.35\n", + "Train score: 0.495\n", + "Test score: 0.34725\n", "-------------------------------\n", "KNN feature eval\n", - "Train score: 0.4036\n", - "Test score: 0.300125\n" + "Train score: 0.406\n", + "Test score: 0.297875\n" ] } ], "source": [ - "linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)" + "linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)\n", + "del X_train_pca\n", + "del X_test_pca" ] }, { @@ -226,16 +231,23 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Feature extractor: resnet18\n" + ] + }, { "data": { "text/plain": [ "" ] }, - "execution_count": 12, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -250,7 +262,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -263,7 +275,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -290,7 +302,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -317,34 +329,20 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/thalles/anaconda3/envs/pytorch/lib/python3.6/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n", - "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", - "\n", - "Increase the number of iterations (max_iter) or scale the data as shown in:\n", - " https://scikit-learn.org/stable/modules/preprocessing.html\n", - "Please also refer to the documentation for alternative solver options:\n", - " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", - " extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ "Logistic Regression feature eval\n", - "Train score: 0.9628\n", - "Test score: 0.75\n", + "Train score: 0.829\n", + "Test score: 0.5445\n", "-------------------------------\n", "KNN feature eval\n", - "Train score: 0.7764\n", - "Test score: 0.709125\n" + "Train score: 0.5934\n", + "Test score: 0.47675\n" ] } ], @@ -352,12 +350,15 @@ "scaler = preprocessing.StandardScaler()\n", "scaler.fit(X_train_feature)\n", "\n", - "linear_model_eval(scaler.transform(X_train_feature), y_train, scaler.transform(X_test_feature), y_test)" + "linear_model_eval(scaler.transform(X_train_feature), y_train, scaler.transform(X_test_feature), y_test)\n", + "\n", + "del X_train_feature\n", + "del X_test_feature" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ diff --git a/loss/nt_xent.py b/loss/nt_xent.py new file mode 100644 index 0000000..bb846a7 --- /dev/null +++ b/loss/nt_xent.py @@ -0,0 +1,62 @@ +import torch +import numpy as np + + +class NTXentLoss(torch.nn.Module): + + def __init__(self, device, batch_size, temperature, use_cosine_similarity): + super(NTXentLoss, self).__init__() + self.batch_size = batch_size + self.temperature = temperature + self.device = device + self.softmax = torch.nn.Softmax(dim=-1) + self.mask_samples_from_same_repr = self._get_correlated_mask() + self.similarity_function = self._get_similarity_function(use_cosine_similarity) + self.labels = self._get_labels() + + def _get_similarity_function(self, use_cosine_similarity): + if use_cosine_similarity: + self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) + return self._cosine_simililarity + else: + return self._dot_simililarity + + def _get_labels(self): + labels = (np.eye((2 * self.batch_size), 2 * self.batch_size - 1, k=-self.batch_size) + np.eye( + (2 * self.batch_size), + 2 * self.batch_size - 1, + k=self.batch_size - 1)).astype(np.int) + labels = torch.from_numpy(labels) + labels = labels.to(self.device) + return labels + + def _get_correlated_mask(self): + mask_samples_from_same_repr = (1 - torch.eye(2 * self.batch_size)).type(torch.bool) + return mask_samples_from_same_repr + + def _dot_simililarity(self, x, y): + v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) + # x shape: (N, 1, C) + # y shape: (1, C, 2N) + # v shape: (N, 2N) + return v + + def _cosine_simililarity(self, x, y): + # x shape: (N, 1, C) + # y shape: (1, 2N, C) + # v shape: (N, 2N) + v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) + return v + + def forward(self, zis, zjs): + negatives = torch.cat([zjs, zis], dim=0) + + logits = self.similarity_function(negatives, negatives) + logits = logits[self.mask_samples_from_same_repr.type(torch.bool)].view(2 * self.batch_size, -1) + logits /= self.temperature + assert logits.shape == (2 * self.batch_size, 2 * self.batch_size - 1), "Shape of negatives not expected." + str( + logits.shape) + + probs = self.softmax(logits) + loss = torch.mean(-torch.sum(self.labels * torch.log(probs), dim=-1)) + return loss diff --git a/models/resnet_simclr.py b/models/resnet_simclr.py index f2b073c..20a8ec0 100644 --- a/models/resnet_simclr.py +++ b/models/resnet_simclr.py @@ -1,4 +1,3 @@ -import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as models @@ -6,7 +5,7 @@ import torchvision.models as models class ResNetSimCLR(nn.Module): - def __init__(self, base_model="resnet18", out_dim=64): + def __init__(self, base_model, out_dim): super(ResNetSimCLR, self).__init__() self.resnet_dict = {"resnet18": models.resnet18(pretrained=False), "resnet50": models.resnet50(pretrained=False)} @@ -22,7 +21,9 @@ class ResNetSimCLR(nn.Module): def _get_basemodel(self, model_name): try: - return self.resnet_dict[model_name] + model = self.resnet_dict[model_name] + print("Feature extractor:", model_name) + return model except: raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50") diff --git a/run.py b/run.py new file mode 100644 index 0000000..098b42a --- /dev/null +++ b/run.py @@ -0,0 +1,13 @@ +from simclr import SimCLR +import yaml + + +def main(): + config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader) + + simclr = SimCLR(config) + simclr.train() + + +if __name__ == "__main__": + main() diff --git a/simclr.py b/simclr.py new file mode 100644 index 0000000..e48e13c --- /dev/null +++ b/simclr.py @@ -0,0 +1,120 @@ +import torch +from models.resnet_simclr import ResNetSimCLR +from torch.utils.tensorboard import SummaryWriter +import torch.nn.functional as F +from loss.nt_xent import NTXentLoss +import os +import shutil +from data_aug.dataset_wrapper import DataSetWrapper +import numpy as np + +torch.manual_seed(0) + + +class SimCLR(object): + + def __init__(self, config): + self.config = config + self.device = self._get_device() + self.writer = SummaryWriter() + self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss']) + + def _get_device(self): + device = 'cuda' if torch.cuda.is_available() else 'gpu' + print("Running on:", device) + return device + + def _step(self, model, xis, xjs, n_iter): + # get the representations and the projections + ris, zis = model(xis) # [N,C] + + # get the representations and the projections + rjs, zjs = model(xjs) # [N,C] + + # normalize projection feature vectors + zis = F.normalize(zis, dim=1) + zjs = F.normalize(zjs, dim=1) + + loss = self.nt_xent_criterion(zis, zjs) + + if n_iter % self.config['log_every_n_steps'] == 0: + self.writer.add_histogram("xi_repr", ris, global_step=n_iter) + self.writer.add_histogram("xi_latent", zis, global_step=n_iter) + self.writer.add_histogram("xj_repr", rjs, global_step=n_iter) + self.writer.add_histogram("xj_latent", zjs, global_step=n_iter) + self.writer.add_scalar('train_loss', loss, global_step=n_iter) + + return loss + + def train(self): + dataset = DataSetWrapper(self.config['batch_size'], **self.config['dataset']) + train_loader, valid_loader = dataset.get_data_loaders() + + model = ResNetSimCLR(**self.config["model"]).to(self.device) + model = self._load_pre_trained_weights(model) + + optimizer = torch.optim.Adam(model.parameters(), 3e-4) + + model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints') + + # save config file + self._save_config_file(model_checkpoints_folder) + + n_iter = 0 + valid_n_iter = 0 + best_valid_loss = np.inf + + for epoch_counter in range(self.config['epochs']): + for (xis, xjs), _ in train_loader: + optimizer.zero_grad() + + xis = xis.to(self.device) + xjs = xjs.to(self.device) + + loss = self._step(model, xis, xjs, n_iter) + + loss.backward() + optimizer.step() + n_iter += 1 + + if epoch_counter % self.config['eval_every_n_epochs'] == 0: + + # validation steps + with torch.no_grad(): + model.eval() + + valid_loss = 0.0 + for counter, ((xis, xjs), _) in enumerate(valid_loader): + xis = xis.to(self.device) + xjs = xjs.to(self.device) + + loss = self._step(model, xis, xjs, n_iter) + valid_loss += loss.item() + + valid_loss /= counter + + if valid_loss < best_valid_loss: + # save the model weights + best_valid_loss = valid_loss + torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth')) + + self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter) + valid_n_iter += 1 + + model.train() + + def _save_config_file(self, model_checkpoints_folder): + if not os.path.exists(model_checkpoints_folder): + os.makedirs(model_checkpoints_folder) + shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml')) + + def _load_pre_trained_weights(self, model): + try: + checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints') + state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth')) + model.load_state_dict(state_dict) + print("Loaded pre-trained model with success.") + except FileNotFoundError: + print("Pre-trained weights not found. Training from scratch.") + + return model diff --git a/train.py b/train.py deleted file mode 100644 index c9d0894..0000000 --- a/train.py +++ /dev/null @@ -1,147 +0,0 @@ -import shutil - -import torch -import yaml - -print(torch.__version__) -import torch.optim as optim -import os - -from torchvision import datasets -from torch.utils.tensorboard import SummaryWriter -import torch.nn.functional as F -import numpy as np -from models.resnet_simclr import ResNetSimCLR -from utils import get_similarity_function, get_train_validation_data_loaders -from data_aug.data_transform import DataTransform, get_simclr_data_transform - -torch.manual_seed(0) -np.random.seed(0) - -config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader) - -batch_size = config['batch_size'] -out_dim = config['out_dim'] -temperature = config['temperature'] -use_cosine_similarity = config['use_cosine_similarity'] - -data_augment = get_simclr_data_transform(s=config['s'], crop_size=96) - -train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=DataTransform(data_augment)) - -train_loader, valid_loader = get_train_validation_data_loaders(train_dataset, **config) - -# model = Encoder(out_dim=out_dim) -model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim) - -if config['continue_training']: - checkpoints_folder = os.path.join('./runs', config['continue_training'], 'checkpoints') - state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth')) - model.load_state_dict(state_dict) - print("Loaded pre-trained model with success.") - -train_gpu = torch.cuda.is_available() -print("Is gpu available:", train_gpu) - -# moves the model parameters to gpu -if train_gpu: - model = model.cuda() - -criterion = torch.nn.CrossEntropyLoss(reduction='sum') -optimizer = optim.Adam(model.parameters(), 3e-4) - -train_writer = SummaryWriter() - -similarity_func = get_similarity_function(use_cosine_similarity) - -megative_mask = (1 - torch.eye(2 * batch_size)).type(torch.bool) -labels = (np.eye((2 * batch_size), 2 * batch_size - 1, k=-batch_size) + np.eye((2 * batch_size), 2 * batch_size - 1, - k=batch_size - 1)).astype(np.int) -labels = torch.from_numpy(labels) -softmax = torch.nn.Softmax(dim=-1) - -if train_gpu: - labels = labels.cuda() - - -def step(xis, xjs): - # get the representations and the projections - ris, zis = model(xis) # [N,C] - - # get the representations and the projections - rjs, zjs = model(xjs) # [N,C] - - if n_iter % config['log_every_n_steps'] == 0: - train_writer.add_histogram("xi_repr", ris, global_step=n_iter) - train_writer.add_histogram("xi_latent", zis, global_step=n_iter) - train_writer.add_histogram("xj_repr", rjs, global_step=n_iter) - train_writer.add_histogram("xj_latent", zjs, global_step=n_iter) - - # normalize projection feature vectors - zis = F.normalize(zis, dim=1) - zjs = F.normalize(zjs, dim=1) - - negatives = torch.cat([zjs, zis], dim=0) - - logits = similarity_func(negatives, negatives) - logits = logits[megative_mask.type(torch.bool)].view(2 * batch_size, -1) - logits /= temperature - # assert logits.shape == (2 * batch_size, 2 * batch_size - 1), "Shape of negatives not expected." + str( - # logits.shape) - - probs = softmax(logits) - loss = torch.mean(-torch.sum(labels * torch.log(probs), dim=-1)) - - return loss - - -model_checkpoints_folder = os.path.join(train_writer.log_dir, 'checkpoints') -if not os.path.exists(model_checkpoints_folder): - os.makedirs(model_checkpoints_folder) - shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml')) - -n_iter = 0 -valid_n_iter = 0 -best_valid_loss = np.inf - -for epoch_counter in range(config['epochs']): - for (xis, xjs), _ in train_loader: - optimizer.zero_grad() - - if train_gpu: - xis = xis.cuda() - xjs = xjs.cuda() - - loss = step(xis, xjs) - - train_writer.add_scalar('train_loss', loss, global_step=n_iter) - loss.backward() - optimizer.step() - n_iter += 1 - - if epoch_counter % config['eval_every_n_epochs'] == 0: - - # validation steps - with torch.no_grad(): - model.eval() - - valid_loss = 0.0 - for counter, ((xis, xjs), _) in enumerate(valid_loader): - - if train_gpu: - xis = xis.cuda() - xjs = xjs.cuda() - loss = (step(xis, xjs)) - valid_loss += loss.item() - - valid_loss /= counter - - if valid_loss < best_valid_loss: - # save the model weights - best_valid_loss = valid_loss - torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth')) - - train_writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter) - valid_n_iter += 1 - - model.train() diff --git a/utils.py b/utils.py deleted file mode 100644 index 7759e1d..0000000 --- a/utils.py +++ /dev/null @@ -1,62 +0,0 @@ -import numpy as np -import torch -from torch.utils.data import DataLoader -from torch.utils.data.sampler import SubsetRandomSampler - -np.random.seed(0) -cosine_similarity = torch.nn.CosineSimilarity(dim=-1) - - -def get_train_validation_data_loaders(train_dataset, batch_size, num_workers, valid_size, **ignored): - # obtain training indices that will be used for validation - num_train = len(train_dataset) - indices = list(range(num_train)) - np.random.shuffle(indices) - - split = int(np.floor(valid_size * num_train)) - train_idx, valid_idx = indices[split:], indices[:split] - - # define samplers for obtaining training and validation batches - train_sampler = SubsetRandomSampler(train_idx) - valid_sampler = SubsetRandomSampler(valid_idx) - - train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, - num_workers=num_workers, drop_last=True, shuffle=False) - - valid_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=valid_sampler, - num_workers=num_workers, drop_last=True) - return train_loader, valid_loader - - -def get_negative_mask(batch_size): - # return a mask that removes the similarity score of equal/similar images. - # this function ensures that only distinct pair of images get their similarity scores - # passed as negative examples - negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool) - for i in range(batch_size): - negative_mask[i, i] = 0 - negative_mask[i, i + batch_size] = 0 - return negative_mask - - -def _dot_simililarity_dim2(x, y): - v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) - # x shape: (N, 1, C) - # y shape: (1, C, 2N) - # v shape: (N, 2N) - return v - - -def _cosine_simililarity_dim2(x, y): - # x shape: (N, 1, C) - # y shape: (1, 2N, C) - # v shape: (N, 2N) - v = cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) - return v - - -def get_similarity_function(use_cosine_similarity): - if use_cosine_similarity: - return _cosine_simililarity_dim2 - else: - return _dot_simililarity_dim2