mirror of https://github.com/sthalles/SimCLR.git

commit 45e3b3b7ef ("new pythonic implementation"), parent 18b070f3ef

config.yaml (27 lines changed)
@@ -1,12 +1,19 @@
batch_size: 512
out_dim: 256
s: 1
temperature: 0.5
base_convnet: "resnet18"  # one of: resnet18 or resnet50
use_cosine_similarity: True
epochs: 50
num_workers: 0
valid_size: 0.05
eval_every_n_epochs: 2
continue_training: Mar10_21-50-05_thallessilva
epochs: 33
eval_every_n_epochs: 1
fine_tune_from: 'Mar13_20-12-09_thallessilva'
log_every_n_steps: 50

model:
  out_dim: 256
  base_model: "resnet18"

dataset:
  s: 1
  input_shape: (96,96,3)
  num_workers: 0
  valid_size: 0.05

loss:
  temperature: 0.5
  use_cosine_similarity: True
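
The nested model/dataset/loss groups in the new config map directly onto constructor keyword arguments via dict unpacking, which is how simclr.py consumes them further below. A minimal sketch, assuming the config.yaml above sits in the working directory:

import yaml

config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
print(config['model'])   # {'out_dim': 256, 'base_model': 'resnet18'}
print(config['loss'])    # {'temperature': 0.5, 'use_cosine_similarity': True}

# each group unpacks into the matching constructor, e.g.
# ResNetSimCLR(**config['model']) == ResNetSimCLR(out_dim=256, base_model='resnet18')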

data_aug/data_transform.py (46 lines deleted)
@@ -1,46 +0,0 @@
import torchvision.transforms as transforms
import cv2
import numpy as np


class DataTransform(object):
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj


class GaussianBlur(object):
    # Implements Gaussian blur as described in the SimCLR paper
    def __init__(self, kernel_size, min=0.1, max=2.0):
        self.min = min
        self.max = max
        # kernel size is set to be 10% of the image height/width
        self.kernel_size = kernel_size

    def __call__(self, sample):
        sample = np.array(sample)

        # blur the image with a 50% chance
        prob = np.random.random_sample()

        if prob < 0.5:
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

        return sample


def get_simclr_data_transform(s, crop_size):
    # get a set of data augmentation transformations as described in the SimCLR paper.
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=crop_size),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomApply([color_jitter], p=0.8),
                                          transforms.RandomGrayscale(p=0.2),
                                          GaussianBlur(kernel_size=int(0.1 * crop_size)),
                                          transforms.ToTensor()])
    return data_transforms

data_aug/dataset_wrapper.py (68 lines added)
@@ -0,0 +1,68 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
from data_aug.gaussian_blur import GaussianBlur
from torchvision import datasets

np.random.seed(0)


class DataSetWrapper(object):

    def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.s = s
        self.input_shape = eval(input_shape)

    def get_data_loaders(self):
        data_augment = self._get_simclr_pipeline_transform()

        train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True,
                                       transform=SimCLRDataTransform(data_augment))

        train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset)
        return train_loader, valid_loader

    def _get_simclr_pipeline_transform(self):
        # get a set of data augmentation transformations as described in the SimCLR paper.
        color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=self.input_shape[0]),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([color_jitter], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              GaussianBlur(kernel_size=int(0.1 * self.input_shape[0])),
                                              transforms.ToTensor()])
        return data_transforms

    def get_train_validation_data_loaders(self, train_dataset):
        # obtain training indices that will be used for validation
        num_train = len(train_dataset)
        indices = list(range(num_train))
        np.random.shuffle(indices)

        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]

        # define samplers for obtaining training and validation batches
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
                                  num_workers=self.num_workers, drop_last=True, shuffle=False)

        valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
                                  num_workers=self.num_workers, drop_last=True)
        return train_loader, valid_loader


class SimCLRDataTransform(object):
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj
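
Each STL-10 item now yields two independently augmented views of the same image, which the DataLoader collates into a pair of batches. A minimal usage sketch, with the values from the dataset section of config.yaml:

# minimal usage sketch; values taken from the dataset group of config.yaml
dataset = DataSetWrapper(batch_size=512, num_workers=0, valid_size=0.05,
                         input_shape='(96,96,3)', s=1)
train_loader, valid_loader = dataset.get_data_loaders()  # downloads STL-10 on first use

(xis, xjs), _ = next(iter(train_loader))
print(xis.shape, xjs.shape)  # two views of one batch: torch.Size([512, 3, 96, 96]) each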

data_aug/gaussian_blur.py (25 lines added)
@@ -0,0 +1,25 @@
import cv2
import numpy as np

np.random.seed(0)


class GaussianBlur(object):
    # Implements Gaussian blur as described in the SimCLR paper
    def __init__(self, kernel_size, min=0.1, max=2.0):
        self.min = min
        self.max = max
        # kernel size is set to be 10% of the image height/width
        self.kernel_size = kernel_size

    def __call__(self, sample):
        sample = np.array(sample)

        # blur the image with a 50% chance
        prob = np.random.random_sample()

        if prob < 0.5:
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

        return sample
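
One practical detail: cv2.GaussianBlur requires an odd kernel size. The 10% rule gives int(0.1 * 96) = 9 for STL-10 crops, which happens to be odd; other crop sizes may need rounding to the nearest odd number. A minimal sketch on a stand-in image:

from PIL import Image

blur = GaussianBlur(kernel_size=9)  # int(0.1 * 96) for 96x96 crops
img = Image.new('RGB', (96, 96))    # stand-in PIL image
out = blur(img)                     # numpy array; blurred with probability 0.5
print(out.shape)                    # (96, 96, 3)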

feature evaluation notebook (.ipynb)
@@ -33,7 +33,7 @@
"metadata": {},
"outputs": [],
"source": [
"folder_name = 'Mar10_21-50-05_thallessilva'\n",
"folder_name = 'Mar13_20-12-09_thallessilva'\n",
"checkpoints_folder = os.path.join('../runs', folder_name, 'checkpoints')\n",
"config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"), Loader=yaml.FullLoader)"
]

@@ -52,10 +52,13 @@
" 'temperature': 0.5,\n",
" 'base_convnet': 'resnet18',\n",
" 'use_cosine_similarity': True,\n",
" 'epochs': 50,\n",
" 'num_workers': 4,\n",
" 'epochs': 40,\n",
" 'num_workers': 0,\n",
" 'valid_size': 0.05,\n",
" 'eval_every_n_epochs': 2}"
" 'eval_every_n_epochs': 2,\n",
" 'continue_training': 'Mar10_21-50-05_thallessilva',\n",
" 'log_every_n_steps': 50,\n",
" 'input_shape': '(96,96,3)'}"
]
},
"execution_count": 4,

@@ -204,17 +207,19 @@
"output_type": "stream",
"text": [
"Logistic Regression feature eval\n",
"Train score: 0.4966\n",
"Test score: 0.35\n",
"Train score: 0.495\n",
"Test score: 0.34725\n",
"-------------------------------\n",
"KNN feature eval\n",
"Train score: 0.4036\n",
"Test score: 0.300125\n"
"Train score: 0.406\n",
"Test score: 0.297875\n"
]
}
],
"source": [
"linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)"
"linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)\n",
"del X_train_pca\n",
"del X_test_pca"
]
},
{

@@ -226,16 +231,23 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Feature extractor: resnet18\n"
]
},
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}

@@ -250,7 +262,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [

@@ -263,7 +275,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"metadata": {},
"outputs": [
{

@@ -290,7 +302,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"metadata": {},
"outputs": [
{

@@ -317,34 +329,20 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/thalles/anaconda3/envs/pytorch/lib/python3.6/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Logistic Regression feature eval\n",
"Train score: 0.9628\n",
"Test score: 0.75\n",
"Train score: 0.829\n",
"Test score: 0.5445\n",
"-------------------------------\n",
"KNN feature eval\n",
"Train score: 0.7764\n",
"Test score: 0.709125\n"
"Train score: 0.5934\n",
"Test score: 0.47675\n"
]
}
],

@@ -352,12 +350,15 @@
"scaler = preprocessing.StandardScaler()\n",
"scaler.fit(X_train_feature)\n",
"\n",
"linear_model_eval(scaler.transform(X_train_feature), y_train, scaler.transform(X_test_feature), y_test)"
"linear_model_eval(scaler.transform(X_train_feature), y_train, scaler.transform(X_test_feature), y_test)\n",
"\n",
"del X_train_feature\n",
"del X_test_feature"
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
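
The linear_model_eval helper the notebook calls is not part of this diff; judging only from the printed output above (logistic regression and KNN train/test scores), a hypothetical sketch might look like the following, where the classifier settings are assumptions:

from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier

def linear_model_eval(X_train, y_train, X_test, y_test):
    # hypothetical reconstruction; hyperparameters are assumptions
    clf = LogisticRegression()
    clf.fit(X_train, y_train)
    print("Logistic Regression feature eval")
    print("Train score:", clf.score(X_train, y_train))
    print("Test score:", clf.score(X_test, y_test))
    print("-------------------------------")
    knn = KNeighborsClassifier()
    knn.fit(X_train, y_train)
    print("KNN feature eval")
    print("Train score:", knn.score(X_train, y_train))
    print("Test score:", knn.score(X_test, y_test))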

loss/nt_xent.py (62 lines added)
@@ -0,0 +1,62 @@
import torch
import numpy as np


class NTXentLoss(torch.nn.Module):

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask()
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.labels = self._get_labels()

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity_module = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_similarity
        else:
            return self._dot_similarity

    def _get_labels(self):
        # each row marks the single positive for that sample: the other view of
        # the same image, offset by batch_size in the concatenated 2N batch
        labels = (np.eye((2 * self.batch_size), 2 * self.batch_size - 1, k=-self.batch_size) + np.eye(
            (2 * self.batch_size),
            2 * self.batch_size - 1,
            k=self.batch_size - 1)).astype(int)
        labels = torch.from_numpy(labels)
        labels = labels.to(self.device)
        return labels

    def _get_correlated_mask(self):
        # drop the diagonal (self-similarity) from the (2N, 2N) similarity matrix
        mask_samples_from_same_repr = (1 - torch.eye(2 * self.batch_size)).type(torch.bool)
        return mask_samples_from_same_repr

    def _dot_similarity(self, x, y):
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        # x shape: (N, 1, C)
        # y shape: (1, C, 2N)
        # v shape: (N, 2N)
        return v

    def _cosine_similarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity_module(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        negatives = torch.cat([zjs, zis], dim=0)

        logits = self.similarity_function(negatives, negatives)
        logits = logits[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
        logits /= self.temperature
        assert logits.shape == (2 * self.batch_size, 2 * self.batch_size - 1), "Shape of negatives not expected." + str(
            logits.shape)

        probs = self.softmax(logits)
        loss = torch.mean(-torch.sum(self.labels * torch.log(probs), dim=-1))
        return loss
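
For a batch of N images the two sets of projections are concatenated into 2N rows; masking the diagonal of the (2N, 2N) similarity matrix leaves 2N - 1 logits per row, and the label matrix built in _get_labels places each row's single 1 at the column of its positive (the other view of the same image, offset by N). A minimal sanity-check sketch on CPU:

import torch
import torch.nn.functional as F

N = 4
criterion = NTXentLoss(device='cpu', batch_size=N, temperature=0.5,
                       use_cosine_similarity=True)

zis = F.normalize(torch.randn(N, 256), dim=1)  # projections of view i
zjs = F.normalize(torch.randn(N, 256), dim=1)  # projections of view j

print(criterion.labels.shape)      # torch.Size([8, 7]) == (2N, 2N - 1)
print(criterion(zis, zjs).item())  # scalar NT-Xent loss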

models/resnet_simclr.py
@@ -1,4 +1,3 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

@@ -6,7 +5,7 @@ import torchvision.models as models

class ResNetSimCLR(nn.Module):

    def __init__(self, base_model="resnet18", out_dim=64):
    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
                            "resnet50": models.resnet50(pretrained=False)}

@@ -22,7 +21,9 @@ class ResNetSimCLR(nn.Module):

    def _get_basemodel(self, model_name):
        try:
            return self.resnet_dict[model_name]
            model = self.resnet_dict[model_name]
            print("Feature extractor:", model_name)
            return model
        except KeyError:
            raise ValueError("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
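
As the tuple unpacking in simclr.py below shows, the model's forward pass (not part of this diff) returns both the backbone representation h and the projection z. A minimal sketch; the 512-d h assumes the standard ResNet-18 feature size:

import torch

model = ResNetSimCLR(base_model="resnet18", out_dim=256)
x = torch.randn(8, 3, 96, 96)  # dummy STL-10-sized batch
h, z = model(x)                # representation and projection
print(h.shape, z.shape)        # e.g. torch.Size([8, 512]), torch.Size([8, 256])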

run.py (13 lines added)
@@ -0,0 +1,13 @@
from simclr import SimCLR
import yaml


def main():
    config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)

    simclr = SimCLR(config)
    simclr.train()


if __name__ == "__main__":
    main()

simclr.py (120 lines added)
@@ -0,0 +1,120 @@
import torch
from models.resnet_simclr import ResNetSimCLR
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from loss.nt_xent import NTXentLoss
import os
import shutil
from data_aug.dataset_wrapper import DataSetWrapper
import numpy as np

torch.manual_seed(0)


class SimCLR(object):

    def __init__(self, config):
        self.config = config
        self.device = self._get_device()
        self.writer = SummaryWriter()
        self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss'])

    def _get_device(self):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Running on:", device)
        return device

    def _step(self, model, xis, xjs, n_iter):
        # get the representations and the projections
        ris, zis = model(xis)  # [N,C]

        # get the representations and the projections
        rjs, zjs = model(xjs)  # [N,C]

        # normalize projection feature vectors
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)

        loss = self.nt_xent_criterion(zis, zjs)

        if n_iter % self.config['log_every_n_steps'] == 0:
            self.writer.add_histogram("xi_repr", ris, global_step=n_iter)
            self.writer.add_histogram("xi_latent", zis, global_step=n_iter)
            self.writer.add_histogram("xj_repr", rjs, global_step=n_iter)
            self.writer.add_histogram("xj_latent", zjs, global_step=n_iter)
            self.writer.add_scalar('train_loss', loss, global_step=n_iter)

        return loss

    def train(self):
        dataset = DataSetWrapper(self.config['batch_size'], **self.config['dataset'])
        train_loader, valid_loader = dataset.get_data_loaders()

        model = ResNetSimCLR(**self.config["model"]).to(self.device)
        model = self._load_pre_trained_weights(model)

        optimizer = torch.optim.Adam(model.parameters(), 3e-4)

        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

        # save config file
        self._save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_loss = np.inf

        for epoch_counter in range(self.config['epochs']):
            for (xis, xjs), _ in train_loader:
                optimizer.zero_grad()

                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                loss = self._step(model, xis, xjs, n_iter)

                loss.backward()
                optimizer.step()
                n_iter += 1

            if epoch_counter % self.config['eval_every_n_epochs'] == 0:

                # validation steps
                with torch.no_grad():
                    model.eval()

                    valid_loss = 0.0
                    for counter, ((xis, xjs), _) in enumerate(valid_loader):
                        xis = xis.to(self.device)
                        xjs = xjs.to(self.device)

                        loss = self._step(model, xis, xjs, n_iter)
                        valid_loss += loss.item()

                    # enumerate starts at 0, so the batch count is counter + 1
                    valid_loss /= (counter + 1)

                    if valid_loss < best_valid_loss:
                        # save the model weights
                        best_valid_loss = valid_loss
                        torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))

                    self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
                    valid_n_iter += 1

                model.train()

    def _save_config_file(self, model_checkpoints_folder):
        if not os.path.exists(model_checkpoints_folder):
            os.makedirs(model_checkpoints_folder)
        shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))

    def _load_pre_trained_weights(self, model):
        try:
            checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints')
            state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
            model.load_state_dict(state_dict)
            print("Loaded pre-trained model with success.")
        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

        return model
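
Note the directory convention that ties train() and _load_pre_trained_weights together: SummaryWriter logs each run under ./runs/<run_name>/, the config copy and best weights land in its checkpoints/ subfolder, and fine_tune_from in config.yaml simply names one of those run folders, e.g. (run name illustrative):

./runs/Mar13_20-12-09_thallessilva/checkpoints/config.yaml
./runs/Mar13_20-12-09_thallessilva/checkpoints/model.pth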

train.py (147 lines deleted)
@@ -1,147 +0,0 @@
import shutil

import torch
import yaml

print(torch.__version__)
import torch.optim as optim
import os

from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import numpy as np
from models.resnet_simclr import ResNetSimCLR
from utils import get_similarity_function, get_train_validation_data_loaders
from data_aug.data_transform import DataTransform, get_simclr_data_transform

torch.manual_seed(0)
np.random.seed(0)

config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)

batch_size = config['batch_size']
out_dim = config['out_dim']
temperature = config['temperature']
use_cosine_similarity = config['use_cosine_similarity']

data_augment = get_simclr_data_transform(s=config['s'], crop_size=96)

train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=DataTransform(data_augment))

train_loader, valid_loader = get_train_validation_data_loaders(train_dataset, **config)

# model = Encoder(out_dim=out_dim)
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)

if config['continue_training']:
    checkpoints_folder = os.path.join('./runs', config['continue_training'], 'checkpoints')
    state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
    model.load_state_dict(state_dict)
    print("Loaded pre-trained model with success.")

train_gpu = torch.cuda.is_available()
print("Is gpu available:", train_gpu)

# moves the model parameters to gpu
if train_gpu:
    model = model.cuda()

criterion = torch.nn.CrossEntropyLoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), 3e-4)

train_writer = SummaryWriter()

similarity_func = get_similarity_function(use_cosine_similarity)

megative_mask = (1 - torch.eye(2 * batch_size)).type(torch.bool)
labels = (np.eye((2 * batch_size), 2 * batch_size - 1, k=-batch_size) + np.eye((2 * batch_size), 2 * batch_size - 1,
                                                                               k=batch_size - 1)).astype(np.int)
labels = torch.from_numpy(labels)
softmax = torch.nn.Softmax(dim=-1)

if train_gpu:
    labels = labels.cuda()


def step(xis, xjs):
    # get the representations and the projections
    ris, zis = model(xis)  # [N,C]

    # get the representations and the projections
    rjs, zjs = model(xjs)  # [N,C]

    if n_iter % config['log_every_n_steps'] == 0:
        train_writer.add_histogram("xi_repr", ris, global_step=n_iter)
        train_writer.add_histogram("xi_latent", zis, global_step=n_iter)
        train_writer.add_histogram("xj_repr", rjs, global_step=n_iter)
        train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)

    # normalize projection feature vectors
    zis = F.normalize(zis, dim=1)
    zjs = F.normalize(zjs, dim=1)

    negatives = torch.cat([zjs, zis], dim=0)

    logits = similarity_func(negatives, negatives)
    logits = logits[megative_mask.type(torch.bool)].view(2 * batch_size, -1)
    logits /= temperature
    # assert logits.shape == (2 * batch_size, 2 * batch_size - 1), "Shape of negatives not expected." + str(
    #     logits.shape)

    probs = softmax(logits)
    loss = torch.mean(-torch.sum(labels * torch.log(probs), dim=-1))

    return loss


model_checkpoints_folder = os.path.join(train_writer.log_dir, 'checkpoints')
if not os.path.exists(model_checkpoints_folder):
    os.makedirs(model_checkpoints_folder)
shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))

n_iter = 0
valid_n_iter = 0
best_valid_loss = np.inf

for epoch_counter in range(config['epochs']):
    for (xis, xjs), _ in train_loader:
        optimizer.zero_grad()

        if train_gpu:
            xis = xis.cuda()
            xjs = xjs.cuda()

        loss = step(xis, xjs)

        train_writer.add_scalar('train_loss', loss, global_step=n_iter)
        loss.backward()
        optimizer.step()
        n_iter += 1

    if epoch_counter % config['eval_every_n_epochs'] == 0:

        # validation steps
        with torch.no_grad():
            model.eval()

            valid_loss = 0.0
            for counter, ((xis, xjs), _) in enumerate(valid_loader):

                if train_gpu:
                    xis = xis.cuda()
                    xjs = xjs.cuda()
                loss = step(xis, xjs)
                valid_loss += loss.item()

            valid_loss /= counter

            if valid_loss < best_valid_loss:
                # save the model weights
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))

            train_writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
            valid_n_iter += 1

        model.train()

utils.py (62 lines deleted)
@@ -1,62 +0,0 @@
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

np.random.seed(0)
cosine_similarity = torch.nn.CosineSimilarity(dim=-1)


def get_train_validation_data_loaders(train_dataset, batch_size, num_workers, valid_size, **ignored):
    # obtain training indices that will be used for validation
    num_train = len(train_dataset)
    indices = list(range(num_train))
    np.random.shuffle(indices)

    split = int(np.floor(valid_size * num_train))
    train_idx, valid_idx = indices[split:], indices[:split]

    # define samplers for obtaining training and validation batches
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                              num_workers=num_workers, drop_last=True, shuffle=False)

    valid_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=valid_sampler,
                              num_workers=num_workers, drop_last=True)
    return train_loader, valid_loader


def get_negative_mask(batch_size):
    # return a mask that removes the similarity score of equal/similar images.
    # this function ensures that only distinct pair of images get their similarity scores
    # passed as negative examples
    negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool)
    for i in range(batch_size):
        negative_mask[i, i] = 0
        negative_mask[i, i + batch_size] = 0
    return negative_mask


def _dot_simililarity_dim2(x, y):
    v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
    # x shape: (N, 1, C)
    # y shape: (1, C, 2N)
    # v shape: (N, 2N)
    return v


def _cosine_simililarity_dim2(x, y):
    # x shape: (N, 1, C)
    # y shape: (1, 2N, C)
    # v shape: (N, 2N)
    v = cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
    return v


def get_similarity_function(use_cosine_similarity):
    if use_cosine_similarity:
        return _cosine_simililarity_dim2
    else:
        return _dot_simililarity_dim2