mirror of https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00

Commit 45e3b3b7ef: "new pythonic implementation"
Parent: 18b070f3ef
config.yaml (21 changed lines)
@@ -1,12 +1,19 @@
 batch_size: 512
-out_dim: 256
-s: 1
-temperature: 0.5
-base_convnet: "resnet18" # one of: "resnet18 or resnet50"
-use_cosine_similarity: True
-epochs: 50
-num_workers: 0
-valid_size: 0.05
-eval_every_n_epochs: 2
-continue_training: Mar10_21-50-05_thallessilva
-log_every_n_steps: 50
+epochs: 33
+eval_every_n_epochs: 1
+fine_tune_from: 'Mar13_20-12-09_thallessilva'
+log_every_n_steps: 50
+
+model:
+  out_dim: 256
+  base_model: "resnet18"
+
+dataset:
+  s: 1
+  input_shape: (96,96,3)
+  num_workers: 0
+  valid_size: 0.05
+
+loss:
+  temperature: 0.5
+  use_cosine_similarity: True
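The flat key list is now grouped into model, dataset, and loss sections, so each component can be constructed by unpacking its own sub-dictionary. A minimal sketch of how the new layout is consumed (the unpacking pattern matches simclr.py further down):

import yaml

config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)

# top-level keys stay flat ...
print(config['batch_size'])  # 512
# ... while grouped keys unpack straight into constructors, e.g.
# NTXentLoss(device, config['batch_size'], **config['loss'])
print(config['loss'])        # {'temperature': 0.5, 'use_cosine_similarity': True}
print(config['dataset'])     # {'s': 1, 'input_shape': '(96,96,3)', ...}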
data_aug/data_transform.py (deleted, 46 lines; its contents reappear split across the new dataset_wrapper.py and gaussian_blur.py)
@@ -1,46 +0,0 @@
import torchvision.transforms as transforms
import cv2
import numpy as np


class DataTransform(object):
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj


class GaussianBlur(object):
    # Implements Gaussian blur as described in the SimCLR paper
    def __init__(self, kernel_size, min=0.1, max=2.0):
        self.min = min
        self.max = max
        # kernel size is set to be 10% of the image height/width
        self.kernel_size = kernel_size

    def __call__(self, sample):
        sample = np.array(sample)

        # blur the image with a 50% chance
        prob = np.random.random_sample()

        if prob < 0.5:
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

        return sample


def get_simclr_data_transform(s, crop_size):
    # get a set of data augmentation transformations as described in the SimCLR paper.
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=crop_size),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomApply([color_jitter], p=0.8),
                                          transforms.RandomGrayscale(p=0.2),
                                          GaussianBlur(kernel_size=int(0.1 * crop_size)),
                                          transforms.ToTensor()])
    return data_transforms
data_aug/dataset_wrapper.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
from data_aug.gaussian_blur import GaussianBlur
from torchvision import datasets

np.random.seed(0)


class DataSetWrapper(object):

    def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.s = s
        self.input_shape = eval(input_shape)

    def get_data_loaders(self):
        data_augment = self._get_simclr_pipeline_transform()

        train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True,
                                       transform=SimCLRDataTransform(data_augment))

        train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset)
        return train_loader, valid_loader

    def _get_simclr_pipeline_transform(self):
        # get a set of data augmentation transformations as described in the SimCLR paper.
        color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=self.input_shape[0]),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([color_jitter], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              GaussianBlur(kernel_size=int(0.1 * self.input_shape[0])),
                                              transforms.ToTensor()])
        return data_transforms

    def get_train_validation_data_loaders(self, train_dataset):
        # obtain training indices that will be used for validation
        num_train = len(train_dataset)
        indices = list(range(num_train))
        np.random.shuffle(indices)

        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]

        # define samplers for obtaining training and validation batches
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
                                  num_workers=self.num_workers, drop_last=True, shuffle=False)

        valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
                                  num_workers=self.num_workers, drop_last=True)
        return train_loader, valid_loader


class SimCLRDataTransform(object):
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj
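A quick usage sketch, assuming the config.yaml above: every batch yields two augmented views of the same images, so iteration unpacks as ((xis, xjs), labels). drop_last=True matters because the NT-Xent loss below precomputes its mask for a fixed batch size.

import yaml
from data_aug.dataset_wrapper import DataSetWrapper

config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
dataset = DataSetWrapper(config['batch_size'], **config['dataset'])
train_loader, valid_loader = dataset.get_data_loaders()

(xis, xjs), _ = next(iter(train_loader))  # STL10 'unlabeled' images carry label -1
print(xis.shape, xjs.shape)               # torch.Size([512, 3, 96, 96]) twice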
data_aug/gaussian_blur.py (new file, 25 lines)
@@ -0,0 +1,25 @@
import cv2
import numpy as np

np.random.seed(0)


class GaussianBlur(object):
    # Implements Gaussian blur as described in the SimCLR paper
    def __init__(self, kernel_size, min=0.1, max=2.0):
        self.min = min
        self.max = max
        # kernel size is set to be 10% of the image height/width
        self.kernel_size = kernel_size

    def __call__(self, sample):
        sample = np.array(sample)

        # blur the image with a 50% chance
        prob = np.random.random_sample()

        if prob < 0.5:
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

        return sample
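The transform takes a PIL image and returns a numpy array, which is why it sits before ToTensor() in the Compose pipeline. A small check of the behavior; note that cv2.GaussianBlur needs an odd kernel size, which int(0.1 * 96) = 9 satisfies for the 96-pixel STL10 crops:

import numpy as np
from PIL import Image
from data_aug.gaussian_blur import GaussianBlur

blur = GaussianBlur(kernel_size=9)  # int(0.1 * crop_size) for 96x96 inputs
img = Image.fromarray(np.random.randint(0, 255, (96, 96, 3), dtype=np.uint8))
out = blur(img)   # blurred with 50% probability, sigma drawn from U(0.1, 2.0)
print(out.shape)  # (96, 96, 3), still a numpy array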
Feature evaluation notebook (file name not shown in this view)
@@ -33,7 +33,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"folder_name = 'Mar10_21-50-05_thallessilva'\n",
+"folder_name = 'Mar13_20-12-09_thallessilva'\n",
 "checkpoints_folder = os.path.join('../runs', folder_name, 'checkpoints')\n",
 "config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"), Loader=yaml.FullLoader)"
 ]
@@ -52,10 +52,13 @@
 " 'temperature': 0.5,\n",
 " 'base_convnet': 'resnet18',\n",
 " 'use_cosine_similarity': True,\n",
-" 'epochs': 50,\n",
-" 'num_workers': 4,\n",
+" 'epochs': 40,\n",
+" 'num_workers': 0,\n",
 " 'valid_size': 0.05,\n",
-" 'eval_every_n_epochs': 2}"
+" 'eval_every_n_epochs': 2,\n",
+" 'continue_training': 'Mar10_21-50-05_thallessilva',\n",
+" 'log_every_n_steps': 50,\n",
+" 'input_shape': '(96,96,3)'}"
 ]
 },
 "execution_count": 4,
@@ -204,17 +207,19 @@
 "output_type": "stream",
 "text": [
 "Logistic Regression feature eval\n",
-"Train score: 0.4966\n",
-"Test score: 0.35\n",
+"Train score: 0.495\n",
+"Test score: 0.34725\n",
 "-------------------------------\n",
 "KNN feature eval\n",
-"Train score: 0.4036\n",
-"Test score: 0.300125\n"
+"Train score: 0.406\n",
+"Test score: 0.297875\n"
 ]
 }
 ],
 "source": [
-"linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)"
+"linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)\n",
+"del X_train_pca\n",
+"del X_test_pca"
 ]
 },
 {
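linear_model_eval is defined earlier in the notebook and not shown in this diff; judging from the printed output, a plausible reconstruction (names and hyperparameters are assumptions, not the notebook's actual code):

from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier

def linear_model_eval(X_train, y_train, X_test, y_test):
    # assumed sketch: fit a linear probe and a KNN on frozen features
    clf = LogisticRegression(random_state=0)
    clf.fit(X_train, y_train)
    print("Logistic Regression feature eval")
    print("Train score:", clf.score(X_train, y_train))
    print("Test score:", clf.score(X_test, y_test))
    print("-------------------------------")
    knn = KNeighborsClassifier()
    knn.fit(X_train, y_train)
    print("KNN feature eval")
    print("Train score:", knn.score(X_train, y_train))
    print("Test score:", knn.score(X_test, y_test))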
@@ -226,16 +231,23 @@
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 13,
 "metadata": {},
 "outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Feature extractor: resnet18\n"
+]
+},
 {
 "data": {
 "text/plain": [
 "<All keys matched successfully>"
 ]
 },
-"execution_count": 12,
+"execution_count": 13,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -250,7 +262,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 13,
+"execution_count": 14,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -263,7 +275,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 14,
+"execution_count": 15,
 "metadata": {},
 "outputs": [
 {
@@ -290,7 +302,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 15,
+"execution_count": 16,
 "metadata": {},
 "outputs": [
 {
@@ -317,34 +329,20 @@
 },
 {
 "cell_type": "code",
-"execution_count": 16,
+"execution_count": 17,
 "metadata": {},
 "outputs": [
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"/home/thalles/anaconda3/envs/pytorch/lib/python3.6/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
-"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
-"\n",
-"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
-" https://scikit-learn.org/stable/modules/preprocessing.html\n",
-"Please also refer to the documentation for alternative solver options:\n",
-" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
-" extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"
-]
-},
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "Logistic Regression feature eval\n",
-"Train score: 0.9628\n",
-"Test score: 0.75\n",
+"Train score: 0.829\n",
+"Test score: 0.5445\n",
 "-------------------------------\n",
 "KNN feature eval\n",
-"Train score: 0.7764\n",
-"Test score: 0.709125\n"
+"Train score: 0.5934\n",
+"Test score: 0.47675\n"
 ]
 }
 ],
@@ -352,12 +350,15 @@
 "scaler = preprocessing.StandardScaler()\n",
 "scaler.fit(X_train_feature)\n",
 "\n",
-"linear_model_eval(scaler.transform(X_train_feature), y_train, scaler.transform(X_test_feature), y_test)"
+"linear_model_eval(scaler.transform(X_train_feature), y_train, scaler.transform(X_test_feature), y_test)\n",
+"\n",
+"del X_train_feature\n",
+"del X_test_feature"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 17,
+"execution_count": 18,
 "metadata": {},
 "outputs": [],
 "source": [
loss/nt_xent.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import torch
import numpy as np


class NTXentLoss(torch.nn.Module):

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask()
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.labels = self._get_labels()

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_labels(self):
        labels = (np.eye((2 * self.batch_size), 2 * self.batch_size - 1, k=-self.batch_size) + np.eye(
            (2 * self.batch_size),
            2 * self.batch_size - 1,
            k=self.batch_size - 1)).astype(int)
        labels = torch.from_numpy(labels)
        labels = labels.to(self.device)
        return labels

    def _get_correlated_mask(self):
        mask_samples_from_same_repr = (1 - torch.eye(2 * self.batch_size)).type(torch.bool)
        return mask_samples_from_same_repr

    def _dot_simililarity(self, x, y):
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        # x shape: (N, 1, C)
        # y shape: (1, C, 2N)
        # v shape: (N, 2N)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        negatives = torch.cat([zjs, zis], dim=0)

        logits = self.similarity_function(negatives, negatives)
        logits = logits[self.mask_samples_from_same_repr.type(torch.bool)].view(2 * self.batch_size, -1)
        logits /= self.temperature

        assert logits.shape == (2 * self.batch_size, 2 * self.batch_size - 1), "Shape of negatives not expected." + str(
            logits.shape)

        probs = self.softmax(logits)
        loss = torch.mean(-torch.sum(self.labels * torch.log(probs), dim=-1))
        return loss
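The mask and labels are built once for a fixed batch size (hence drop_last=True in the data loaders). A quick sanity check on random, unaligned embeddings: with near-zero cosine similarities the softmax over the 2N-1 candidates is roughly uniform, so the loss should sit near log(2N-1).

import torch
import torch.nn.functional as F
from loss.nt_xent import NTXentLoss

batch_size = 4
criterion = NTXentLoss(device='cpu', batch_size=batch_size,
                       temperature=0.5, use_cosine_similarity=True)

zis = F.normalize(torch.randn(batch_size, 256), dim=1)
zjs = F.normalize(torch.randn(batch_size, 256), dim=1)
print(criterion(zis, zjs).item())  # close to log(7) ~ 1.95 for random inputs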
models/resnet_simclr.py
@@ -1,4 +1,3 @@
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.models as models
@@ -6,7 +5,7 @@ import torchvision.models as models

 class ResNetSimCLR(nn.Module):

-    def __init__(self, base_model="resnet18", out_dim=64):
+    def __init__(self, base_model, out_dim):
         super(ResNetSimCLR, self).__init__()
         self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
                             "resnet50": models.resnet50(pretrained=False)}
@@ -22,7 +21,9 @@ class ResNetSimCLR(nn.Module):

     def _get_basemodel(self, model_name):
         try:
-            return self.resnet_dict[model_name]
+            model = self.resnet_dict[model_name]
+            print("Feature extractor:", model_name)
+            return model
         except:
             raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
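The constructor defaults are gone, so both arguments now come explicitly from config['model']. A small shape check, relying on the forward pass returning the backbone representation and the projection-head output as the pair used in simclr.py (ris, zis = model(xis)); the 512 below assumes the resnet18 feature width:

import torch
from models.resnet_simclr import ResNetSimCLR

model = ResNetSimCLR(base_model="resnet18", out_dim=256)  # prints "Feature extractor: resnet18"
h, z = model(torch.randn(8, 3, 96, 96))
print(h.shape, z.shape)  # expected: torch.Size([8, 512]) torch.Size([8, 256])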
run.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from simclr import SimCLR
import yaml


def main():
    config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)

    simclr = SimCLR(config)
    simclr.train()


if __name__ == "__main__":
    main()
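With this entry point in place, a pre-training run is launched with a plain python run.py from the repository root; all hyperparameters come from config.yaml, and each run gets its own timestamped folder under ./runs/ via the SummaryWriter default.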
simclr.py (new file, 120 lines)
@@ -0,0 +1,120 @@
import torch
from models.resnet_simclr import ResNetSimCLR
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from loss.nt_xent import NTXentLoss
import os
import shutil
from data_aug.dataset_wrapper import DataSetWrapper
import numpy as np

torch.manual_seed(0)


class SimCLR(object):

    def __init__(self, config):
        self.config = config
        self.device = self._get_device()
        self.writer = SummaryWriter()
        self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss'])

    def _get_device(self):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Running on:", device)
        return device

    def _step(self, model, xis, xjs, n_iter):
        # get the representations and the projections
        ris, zis = model(xis)  # [N,C]

        # get the representations and the projections
        rjs, zjs = model(xjs)  # [N,C]

        # normalize projection feature vectors
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)

        loss = self.nt_xent_criterion(zis, zjs)

        if n_iter % self.config['log_every_n_steps'] == 0:
            self.writer.add_histogram("xi_repr", ris, global_step=n_iter)
            self.writer.add_histogram("xi_latent", zis, global_step=n_iter)
            self.writer.add_histogram("xj_repr", rjs, global_step=n_iter)
            self.writer.add_histogram("xj_latent", zjs, global_step=n_iter)
            self.writer.add_scalar('train_loss', loss, global_step=n_iter)

        return loss

    def train(self):
        dataset = DataSetWrapper(self.config['batch_size'], **self.config['dataset'])
        train_loader, valid_loader = dataset.get_data_loaders()

        model = ResNetSimCLR(**self.config["model"]).to(self.device)
        model = self._load_pre_trained_weights(model)

        optimizer = torch.optim.Adam(model.parameters(), 3e-4)

        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

        # save config file
        self._save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_loss = np.inf

        for epoch_counter in range(self.config['epochs']):
            for (xis, xjs), _ in train_loader:
                optimizer.zero_grad()

                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                loss = self._step(model, xis, xjs, n_iter)

                loss.backward()
                optimizer.step()
                n_iter += 1

            if epoch_counter % self.config['eval_every_n_epochs'] == 0:

                # validation steps
                with torch.no_grad():
                    model.eval()

                    valid_loss = 0.0
                    for counter, ((xis, xjs), _) in enumerate(valid_loader):
                        xis = xis.to(self.device)
                        xjs = xjs.to(self.device)

                        loss = self._step(model, xis, xjs, n_iter)
                        valid_loss += loss.item()

                    valid_loss /= counter

                    if valid_loss < best_valid_loss:
                        # save the model weights
                        best_valid_loss = valid_loss
                        torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))

                    self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
                    valid_n_iter += 1

                model.train()

    def _save_config_file(self, model_checkpoints_folder):
        if not os.path.exists(model_checkpoints_folder):
            os.makedirs(model_checkpoints_folder)
            shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))

    def _load_pre_trained_weights(self, model):
        try:
            checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints')
            state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
            model.load_state_dict(state_dict)
            print("Loaded pre-trained model with success.")
        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

        return model
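Checkpoints are written under the SummaryWriter run directory (./runs/<timestamp_host>/checkpoints), which is the folder name that fine_tune_from and the notebook's folder_name refer to. A minimal sketch of reloading the weights for feature evaluation, reusing names from the diff:

import os
import torch
from models.resnet_simclr import ResNetSimCLR

folder_name = 'Mar13_20-12-09_thallessilva'  # run folder from the notebook above
checkpoints_folder = os.path.join('./runs', folder_name, 'checkpoints')

model = ResNetSimCLR(base_model="resnet18", out_dim=256)
state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location='cpu')
model.load_state_dict(state_dict)  # "<All keys matched successfully>"
model.eval()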
train.py (deleted, 147 lines)
@@ -1,147 +0,0 @@
|
|||||||
import shutil
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
print(torch.__version__)
|
|
||||||
import torch.optim as optim
|
|
||||||
import os
|
|
||||||
|
|
||||||
from torchvision import datasets
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
|
||||||
from models.resnet_simclr import ResNetSimCLR
|
|
||||||
from utils import get_similarity_function, get_train_validation_data_loaders
|
|
||||||
from data_aug.data_transform import DataTransform, get_simclr_data_transform
|
|
||||||
|
|
||||||
torch.manual_seed(0)
|
|
||||||
np.random.seed(0)
|
|
||||||
|
|
||||||
config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
|
|
||||||
|
|
||||||
batch_size = config['batch_size']
|
|
||||||
out_dim = config['out_dim']
|
|
||||||
temperature = config['temperature']
|
|
||||||
use_cosine_similarity = config['use_cosine_similarity']
|
|
||||||
|
|
||||||
data_augment = get_simclr_data_transform(s=config['s'], crop_size=96)
|
|
||||||
|
|
||||||
train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=DataTransform(data_augment))
|
|
||||||
|
|
||||||
train_loader, valid_loader = get_train_validation_data_loaders(train_dataset, **config)
|
|
||||||
|
|
||||||
# model = Encoder(out_dim=out_dim)
|
|
||||||
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
|
|
||||||
|
|
||||||
if config['continue_training']:
|
|
||||||
checkpoints_folder = os.path.join('./runs', config['continue_training'], 'checkpoints')
|
|
||||||
state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
print("Loaded pre-trained model with success.")
|
|
||||||
|
|
||||||
train_gpu = torch.cuda.is_available()
|
|
||||||
print("Is gpu available:", train_gpu)
|
|
||||||
|
|
||||||
# moves the model parameters to gpu
|
|
||||||
if train_gpu:
|
|
||||||
model = model.cuda()
|
|
||||||
|
|
||||||
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
|
|
||||||
optimizer = optim.Adam(model.parameters(), 3e-4)
|
|
||||||
|
|
||||||
train_writer = SummaryWriter()
|
|
||||||
|
|
||||||
similarity_func = get_similarity_function(use_cosine_similarity)
|
|
||||||
|
|
||||||
megative_mask = (1 - torch.eye(2 * batch_size)).type(torch.bool)
|
|
||||||
labels = (np.eye((2 * batch_size), 2 * batch_size - 1, k=-batch_size) + np.eye((2 * batch_size), 2 * batch_size - 1,
|
|
||||||
k=batch_size - 1)).astype(np.int)
|
|
||||||
labels = torch.from_numpy(labels)
|
|
||||||
softmax = torch.nn.Softmax(dim=-1)
|
|
||||||
|
|
||||||
if train_gpu:
|
|
||||||
labels = labels.cuda()
|
|
||||||
|
|
||||||
|
|
||||||
def step(xis, xjs):
|
|
||||||
# get the representations and the projections
|
|
||||||
ris, zis = model(xis) # [N,C]
|
|
||||||
|
|
||||||
# get the representations and the projections
|
|
||||||
rjs, zjs = model(xjs) # [N,C]
|
|
||||||
|
|
||||||
if n_iter % config['log_every_n_steps'] == 0:
|
|
||||||
train_writer.add_histogram("xi_repr", ris, global_step=n_iter)
|
|
||||||
train_writer.add_histogram("xi_latent", zis, global_step=n_iter)
|
|
||||||
train_writer.add_histogram("xj_repr", rjs, global_step=n_iter)
|
|
||||||
train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)
|
|
||||||
|
|
||||||
# normalize projection feature vectors
|
|
||||||
zis = F.normalize(zis, dim=1)
|
|
||||||
zjs = F.normalize(zjs, dim=1)
|
|
||||||
|
|
||||||
negatives = torch.cat([zjs, zis], dim=0)
|
|
||||||
|
|
||||||
logits = similarity_func(negatives, negatives)
|
|
||||||
logits = logits[megative_mask.type(torch.bool)].view(2 * batch_size, -1)
|
|
||||||
logits /= temperature
|
|
||||||
# assert logits.shape == (2 * batch_size, 2 * batch_size - 1), "Shape of negatives not expected." + str(
|
|
||||||
# logits.shape)
|
|
||||||
|
|
||||||
probs = softmax(logits)
|
|
||||||
loss = torch.mean(-torch.sum(labels * torch.log(probs), dim=-1))
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
model_checkpoints_folder = os.path.join(train_writer.log_dir, 'checkpoints')
|
|
||||||
if not os.path.exists(model_checkpoints_folder):
|
|
||||||
os.makedirs(model_checkpoints_folder)
|
|
||||||
shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))
|
|
||||||
|
|
||||||
n_iter = 0
|
|
||||||
valid_n_iter = 0
|
|
||||||
best_valid_loss = np.inf
|
|
||||||
|
|
||||||
for epoch_counter in range(config['epochs']):
|
|
||||||
for (xis, xjs), _ in train_loader:
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
if train_gpu:
|
|
||||||
xis = xis.cuda()
|
|
||||||
xjs = xjs.cuda()
|
|
||||||
|
|
||||||
loss = step(xis, xjs)
|
|
||||||
|
|
||||||
train_writer.add_scalar('train_loss', loss, global_step=n_iter)
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
n_iter += 1
|
|
||||||
|
|
||||||
if epoch_counter % config['eval_every_n_epochs'] == 0:
|
|
||||||
|
|
||||||
# validation steps
|
|
||||||
with torch.no_grad():
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
valid_loss = 0.0
|
|
||||||
for counter, ((xis, xjs), _) in enumerate(valid_loader):
|
|
||||||
|
|
||||||
if train_gpu:
|
|
||||||
xis = xis.cuda()
|
|
||||||
xjs = xjs.cuda()
|
|
||||||
loss = (step(xis, xjs))
|
|
||||||
valid_loss += loss.item()
|
|
||||||
|
|
||||||
valid_loss /= counter
|
|
||||||
|
|
||||||
if valid_loss < best_valid_loss:
|
|
||||||
# save the model weights
|
|
||||||
best_valid_loss = valid_loss
|
|
||||||
torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
|
|
||||||
|
|
||||||
train_writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
|
|
||||||
valid_n_iter += 1
|
|
||||||
|
|
||||||
model.train()
|
|
utils.py (deleted, 62 lines)
@@ -1,62 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torch.utils.data.sampler import SubsetRandomSampler
|
|
||||||
|
|
||||||
np.random.seed(0)
|
|
||||||
cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def get_train_validation_data_loaders(train_dataset, batch_size, num_workers, valid_size, **ignored):
|
|
||||||
# obtain training indices that will be used for validation
|
|
||||||
num_train = len(train_dataset)
|
|
||||||
indices = list(range(num_train))
|
|
||||||
np.random.shuffle(indices)
|
|
||||||
|
|
||||||
split = int(np.floor(valid_size * num_train))
|
|
||||||
train_idx, valid_idx = indices[split:], indices[:split]
|
|
||||||
|
|
||||||
# define samplers for obtaining training and validation batches
|
|
||||||
train_sampler = SubsetRandomSampler(train_idx)
|
|
||||||
valid_sampler = SubsetRandomSampler(valid_idx)
|
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
|
|
||||||
num_workers=num_workers, drop_last=True, shuffle=False)
|
|
||||||
|
|
||||||
valid_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=valid_sampler,
|
|
||||||
num_workers=num_workers, drop_last=True)
|
|
||||||
return train_loader, valid_loader
|
|
||||||
|
|
||||||
|
|
||||||
def get_negative_mask(batch_size):
|
|
||||||
# return a mask that removes the similarity score of equal/similar images.
|
|
||||||
# this function ensures that only distinct pair of images get their similarity scores
|
|
||||||
# passed as negative examples
|
|
||||||
negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool)
|
|
||||||
for i in range(batch_size):
|
|
||||||
negative_mask[i, i] = 0
|
|
||||||
negative_mask[i, i + batch_size] = 0
|
|
||||||
return negative_mask
|
|
||||||
|
|
||||||
|
|
||||||
def _dot_simililarity_dim2(x, y):
|
|
||||||
v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
|
|
||||||
# x shape: (N, 1, C)
|
|
||||||
# y shape: (1, C, 2N)
|
|
||||||
# v shape: (N, 2N)
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
def _cosine_simililarity_dim2(x, y):
|
|
||||||
# x shape: (N, 1, C)
|
|
||||||
# y shape: (1, 2N, C)
|
|
||||||
# v shape: (N, 2N)
|
|
||||||
v = cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
def get_similarity_function(use_cosine_similarity):
|
|
||||||
if use_cosine_similarity:
|
|
||||||
return _cosine_simililarity_dim2
|
|
||||||
else:
|
|
||||||
return _dot_simililarity_dim2
|
|