new pythonic implementation

pull/5/head
Thalles 2020-03-13 22:56:04 -03:00
parent 18b070f3ef
commit 45e3b3b7ef
11 changed files with 345 additions and 303 deletions

config.yaml
View File

@@ -1,12 +1,19 @@
batch_size: 512
out_dim: 256
s: 1
temperature: 0.5
base_convnet: "resnet18" # one of: "resnet18 or resnet50"
use_cosine_similarity: True
epochs: 50
num_workers: 0
valid_size: 0.05
eval_every_n_epochs: 2
continue_training: Mar10_21-50-05_thallessilva
epochs: 33
eval_every_n_epochs: 1
fine_tune_from: 'Mar13_20-12-09_thallessilva'
log_every_n_steps: 50
model:
out_dim: 256
base_model: "resnet18"
dataset:
s: 1
input_shape: (96,96,3)
num_workers: 0
valid_size: 0.05
loss:
temperature: 0.5
use_cosine_similarity: True
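
The flat keys are regrouped into per-component sections, which simclr.py unpacks straight into constructor keyword arguments. A minimal sketch of how the new layout is consumed (illustrative, not part of this commit):

import yaml

from data_aug.dataset_wrapper import DataSetWrapper
from loss.nt_xent import NTXentLoss
from models.resnet_simclr import ResNetSimCLR

config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)

# each top-level section maps onto one component's keyword arguments
dataset = DataSetWrapper(config['batch_size'], **config['dataset'])
model = ResNetSimCLR(**config['model'])
criterion = NTXentLoss('cpu', config['batch_size'], **config['loss'])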

data_aug/data_transform.py
View File

@@ -1,46 +0,0 @@
import torchvision.transforms as transforms
import cv2
import numpy as np
class DataTransform(object):
def __init__(self, transform):
self.transform = transform
def __call__(self, sample):
xi = self.transform(sample)
xj = self.transform(sample)
return xi, xj
class GaussianBlur(object):
# Implements Gaussian blur as described in the SimCLR paper
def __init__(self, kernel_size, min=0.1, max=2.0):
self.min = min
self.max = max
# kernel size is set to be 10% of the image height/width
self.kernel_size = kernel_size
def __call__(self, sample):
sample = np.array(sample)
# blur the image with a 50% chance
prob = np.random.random_sample()
if prob < 0.5:
sigma = (self.max - self.min) * np.random.random_sample() + self.min
sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
return sample
def get_simclr_data_transform(s, crop_size):
# get a set of data augmentation transformations as described in the SimCLR paper.
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=crop_size),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([color_jitter], p=0.8),
transforms.RandomGrayscale(p=0.2),
GaussianBlur(kernel_size=int(0.1 * crop_size)),
transforms.ToTensor()])
return data_transforms

data_aug/dataset_wrapper.py
View File

@@ -0,0 +1,68 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
from data_aug.gaussian_blur import GaussianBlur
from torchvision import datasets
np.random.seed(0)
class DataSetWrapper(object):
def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
self.batch_size = batch_size
self.num_workers = num_workers
self.valid_size = valid_size
self.s = s
        self.input_shape = eval(input_shape)  # parse the "(96,96,3)" string from config.yaml into a tuple
def get_data_loaders(self):
data_augment = self._get_simclr_pipeline_transform()
train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True,
transform=SimCLRDataTransform(data_augment))
train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset)
return train_loader, valid_loader
def _get_simclr_pipeline_transform(self):
# get a set of data augmentation transformations as described in the SimCLR paper.
color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=self.input_shape[0]),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([color_jitter], p=0.8),
transforms.RandomGrayscale(p=0.2),
GaussianBlur(kernel_size=int(0.1 * self.input_shape[0])),
transforms.ToTensor()])
return data_transforms
def get_train_validation_data_loaders(self, train_dataset):
# obtain training indices that will be used for validation
num_train = len(train_dataset)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(self.valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]
# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
num_workers=self.num_workers, drop_last=True, shuffle=False)
valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
num_workers=self.num_workers, drop_last=True)
return train_loader, valid_loader
class SimCLRDataTransform(object):
def __init__(self, transform):
self.transform = transform
def __call__(self, sample):
xi = self.transform(sample)
xj = self.transform(sample)
return xi, xj
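
Each batch from the wrapper yields two independently augmented views of the same images via SimCLRDataTransform. A quick sanity check (illustrative, not part of this commit):

wrapper = DataSetWrapper(batch_size=512, num_workers=0, valid_size=0.05,
                         input_shape="(96,96,3)", s=1)
train_loader, valid_loader = wrapper.get_data_loaders()  # downloads STL10 on first use

(xis, xjs), _ = next(iter(train_loader))
print(xis.shape, xjs.shape)  # both torch.Size([512, 3, 96, 96])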

data_aug/gaussian_blur.py
View File

@@ -0,0 +1,25 @@
import cv2
import numpy as np
np.random.seed(0)
class GaussianBlur(object):
# Implements Gaussian blur as described in the SimCLR paper
def __init__(self, kernel_size, min=0.1, max=2.0):
self.min = min
self.max = max
# kernel size is set to be 10% of the image height/width
self.kernel_size = kernel_size
def __call__(self, sample):
sample = np.array(sample)
# blur the image with a 50% chance
prob = np.random.random_sample()
if prob < 0.5:
sigma = (self.max - self.min) * np.random.random_sample() + self.min
sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
return sample
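
The transform takes a PIL image (anything np.array can consume) and returns a NumPy array, so it has to sit before ToTensor in the pipeline, as dataset_wrapper.py does. Illustrative usage (the input path is hypothetical):

from PIL import Image

blur = GaussianBlur(kernel_size=9)  # int(0.1 * 96); cv2 requires an odd kernel size
img = Image.open("example.png")     # hypothetical input image
out = blur(img)                     # NumPy array, blurred with probability 0.5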

View File

@@ -33,7 +33,7 @@
"metadata": {},
"outputs": [],
"source": [
"folder_name = 'Mar10_21-50-05_thallessilva'\n",
"folder_name = 'Mar13_20-12-09_thallessilva'\n",
"checkpoints_folder = os.path.join('../runs', folder_name, 'checkpoints')\n",
"config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"), Loader=yaml.FullLoader)"
]
@@ -52,10 +52,13 @@
" 'temperature': 0.5,\n",
" 'base_convnet': 'resnet18',\n",
" 'use_cosine_similarity': True,\n",
" 'epochs': 50,\n",
" 'num_workers': 4,\n",
" 'epochs': 40,\n",
" 'num_workers': 0,\n",
" 'valid_size': 0.05,\n",
" 'eval_every_n_epochs': 2}"
" 'eval_every_n_epochs': 2,\n",
" 'continue_training': 'Mar10_21-50-05_thallessilva',\n",
" 'log_every_n_steps': 50,\n",
" 'input_shape': '(96,96,3)'}"
]
},
"execution_count": 4,
@@ -204,17 +207,19 @@
"output_type": "stream",
"text": [
"Logistic Regression feature eval\n",
"Train score: 0.4966\n",
"Test score: 0.35\n",
"Train score: 0.495\n",
"Test score: 0.34725\n",
"-------------------------------\n",
"KNN feature eval\n",
"Train score: 0.4036\n",
"Test score: 0.300125\n"
"Train score: 0.406\n",
"Test score: 0.297875\n"
]
}
],
"source": [
"linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)"
"linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)\n",
"del X_train_pca\n",
"del X_test_pca"
]
},
{
@@ -226,16 +231,23 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Feature extractor: resnet18\n"
]
},
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -250,7 +262,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -263,7 +275,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -290,7 +302,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"metadata": {},
"outputs": [
{
@@ -317,34 +329,20 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/thalles/anaconda3/envs/pytorch/lib/python3.6/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Logistic Regression feature eval\n",
"Train score: 0.9628\n",
"Test score: 0.75\n",
"Train score: 0.829\n",
"Test score: 0.5445\n",
"-------------------------------\n",
"KNN feature eval\n",
"Train score: 0.7764\n",
"Test score: 0.709125\n"
"Train score: 0.5934\n",
"Test score: 0.47675\n"
]
}
],
@@ -352,12 +350,15 @@
"scaler = preprocessing.StandardScaler()\n",
"scaler.fit(X_train_feature)\n",
"\n",
"linear_model_eval(scaler.transform(X_train_feature), y_train, scaler.transform(X_test_feature), y_test)"
"linear_model_eval(scaler.transform(X_train_feature), y_train, scaler.transform(X_test_feature), y_test)\n",
"\n",
"del X_train_feature\n",
"del X_test_feature"
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [

62 loss/nt_xent.py 100644
View File

@@ -0,0 +1,62 @@
import torch
import numpy as np
class NTXentLoss(torch.nn.Module):
def __init__(self, device, batch_size, temperature, use_cosine_similarity):
super(NTXentLoss, self).__init__()
self.batch_size = batch_size
self.temperature = temperature
self.device = device
self.softmax = torch.nn.Softmax(dim=-1)
self.mask_samples_from_same_repr = self._get_correlated_mask()
self.similarity_function = self._get_similarity_function(use_cosine_similarity)
self.labels = self._get_labels()
def _get_similarity_function(self, use_cosine_similarity):
if use_cosine_similarity:
self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
return self._cosine_simililarity
else:
return self._dot_simililarity
def _get_labels(self):
labels = (np.eye((2 * self.batch_size), 2 * self.batch_size - 1, k=-self.batch_size) + np.eye(
(2 * self.batch_size),
2 * self.batch_size - 1,
            k=self.batch_size - 1)).astype(int)
labels = torch.from_numpy(labels)
labels = labels.to(self.device)
return labels
def _get_correlated_mask(self):
mask_samples_from_same_repr = (1 - torch.eye(2 * self.batch_size)).type(torch.bool)
return mask_samples_from_same_repr
def _dot_simililarity(self, x, y):
v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
# x shape: (N, 1, C)
# y shape: (1, C, 2N)
# v shape: (N, 2N)
return v
def _cosine_simililarity(self, x, y):
# x shape: (N, 1, C)
# y shape: (1, 2N, C)
# v shape: (N, 2N)
v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
return v
def forward(self, zis, zjs):
negatives = torch.cat([zjs, zis], dim=0)
logits = self.similarity_function(negatives, negatives)
logits = logits[self.mask_samples_from_same_repr.type(torch.bool)].view(2 * self.batch_size, -1)
logits /= self.temperature
        assert logits.shape == (2 * self.batch_size, 2 * self.batch_size - 1), "Unexpected logits shape: " + str(
            logits.shape)
probs = self.softmax(logits)
loss = torch.mean(-torch.sum(self.labels * torch.log(probs), dim=-1))
return loss
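
A minimal smoke test for the criterion (illustrative, not part of this commit):

import torch
import torch.nn.functional as F

batch_size = 4
criterion = NTXentLoss(device='cpu', batch_size=batch_size,
                       temperature=0.5, use_cosine_similarity=True)

zis = F.normalize(torch.randn(batch_size, 256), dim=1)  # projections of view i
zjs = F.normalize(torch.randn(batch_size, 256), dim=1)  # projections of view j
loss = criterion(zis, zjs)  # scalar NT-Xent loss over 2N = 8 views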

models/resnet_simclr.py
View File

@@ -1,4 +1,3 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
@@ -6,7 +5,7 @@ import torchvision.models as models
class ResNetSimCLR(nn.Module):
def __init__(self, base_model="resnet18", out_dim=64):
def __init__(self, base_model, out_dim):
super(ResNetSimCLR, self).__init__()
self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
"resnet50": models.resnet50(pretrained=False)}
@@ -22,7 +21,9 @@ class ResNetSimCLR(nn.Module):
def _get_basemodel(self, model_name):
try:
return self.resnet_dict[model_name]
model = self.resnet_dict[model_name]
print("Feature extractor:", model_name)
return model
        except KeyError:
            raise ValueError("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
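
The forward pass returns a (representation, projection) pair, which is how simclr.py unpacks it. Illustrative (not part of this commit):

import torch

model = ResNetSimCLR(base_model="resnet18", out_dim=256)
h, z = model(torch.randn(2, 3, 96, 96))  # h: backbone representation, z: projection-head output used by the loss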

13 run.py 100644
View File

@@ -0,0 +1,13 @@
from simclr import SimCLR
import yaml
def main():
config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
simclr = SimCLR(config)
simclr.train()
if __name__ == "__main__":
main()

120 simclr.py 100644
View File

@@ -0,0 +1,120 @@
import torch
from models.resnet_simclr import ResNetSimCLR
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from loss.nt_xent import NTXentLoss
import os
import shutil
from data_aug.dataset_wrapper import DataSetWrapper
import numpy as np
torch.manual_seed(0)
class SimCLR(object):
def __init__(self, config):
self.config = config
self.device = self._get_device()
self.writer = SummaryWriter()
self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss'])
def _get_device(self):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Running on:", device)
return device
def _step(self, model, xis, xjs, n_iter):
# get the representations and the projections
ris, zis = model(xis) # [N,C]
# get the representations and the projections
rjs, zjs = model(xjs) # [N,C]
# normalize projection feature vectors
zis = F.normalize(zis, dim=1)
zjs = F.normalize(zjs, dim=1)
loss = self.nt_xent_criterion(zis, zjs)
if n_iter % self.config['log_every_n_steps'] == 0:
self.writer.add_histogram("xi_repr", ris, global_step=n_iter)
self.writer.add_histogram("xi_latent", zis, global_step=n_iter)
self.writer.add_histogram("xj_repr", rjs, global_step=n_iter)
self.writer.add_histogram("xj_latent", zjs, global_step=n_iter)
self.writer.add_scalar('train_loss', loss, global_step=n_iter)
return loss
def train(self):
dataset = DataSetWrapper(self.config['batch_size'], **self.config['dataset'])
train_loader, valid_loader = dataset.get_data_loaders()
model = ResNetSimCLR(**self.config["model"]).to(self.device)
model = self._load_pre_trained_weights(model)
optimizer = torch.optim.Adam(model.parameters(), 3e-4)
model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
# save config file
self._save_config_file(model_checkpoints_folder)
n_iter = 0
valid_n_iter = 0
best_valid_loss = np.inf
for epoch_counter in range(self.config['epochs']):
for (xis, xjs), _ in train_loader:
optimizer.zero_grad()
xis = xis.to(self.device)
xjs = xjs.to(self.device)
loss = self._step(model, xis, xjs, n_iter)
loss.backward()
optimizer.step()
n_iter += 1
if epoch_counter % self.config['eval_every_n_epochs'] == 0:
# validation steps
with torch.no_grad():
model.eval()
valid_loss = 0.0
for counter, ((xis, xjs), _) in enumerate(valid_loader):
xis = xis.to(self.device)
xjs = xjs.to(self.device)
loss = self._step(model, xis, xjs, n_iter)
valid_loss += loss.item()
                    valid_loss /= (counter + 1)  # enumerate is zero-based, so counter + 1 batches were summed
if valid_loss < best_valid_loss:
# save the model weights
best_valid_loss = valid_loss
torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
valid_n_iter += 1
model.train()
def _save_config_file(self, model_checkpoints_folder):
if not os.path.exists(model_checkpoints_folder):
os.makedirs(model_checkpoints_folder)
shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))
def _load_pre_trained_weights(self, model):
try:
checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints')
state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
model.load_state_dict(state_dict)
print("Loaded pre-trained model with success.")
except FileNotFoundError:
print("Pre-trained weights not found. Training from scratch.")
return model
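
Resuming or fine-tuning only requires pointing fine_tune_from at an existing run folder under ./runs, matching what _load_pre_trained_weights expects. Illustrative (folder name taken from this commit's config.yaml):

import yaml
from simclr import SimCLR

config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
config['fine_tune_from'] = 'Mar13_20-12-09_thallessilva'  # looks for runs/<name>/checkpoints/model.pth
SimCLR(config).train()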

147 train.py
View File

@@ -1,147 +0,0 @@
import shutil
import torch
import yaml
print(torch.__version__)
import torch.optim as optim
import os
from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import numpy as np
from models.resnet_simclr import ResNetSimCLR
from utils import get_similarity_function, get_train_validation_data_loaders
from data_aug.data_transform import DataTransform, get_simclr_data_transform
torch.manual_seed(0)
np.random.seed(0)
config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
batch_size = config['batch_size']
out_dim = config['out_dim']
temperature = config['temperature']
use_cosine_similarity = config['use_cosine_similarity']
data_augment = get_simclr_data_transform(s=config['s'], crop_size=96)
train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=DataTransform(data_augment))
train_loader, valid_loader = get_train_validation_data_loaders(train_dataset, **config)
# model = Encoder(out_dim=out_dim)
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
if config['continue_training']:
checkpoints_folder = os.path.join('./runs', config['continue_training'], 'checkpoints')
state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
model.load_state_dict(state_dict)
print("Loaded pre-trained model with success.")
train_gpu = torch.cuda.is_available()
print("Is gpu available:", train_gpu)
# moves the model parameters to gpu
if train_gpu:
model = model.cuda()
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), 3e-4)
train_writer = SummaryWriter()
similarity_func = get_similarity_function(use_cosine_similarity)
megative_mask = (1 - torch.eye(2 * batch_size)).type(torch.bool)
labels = (np.eye((2 * batch_size), 2 * batch_size - 1, k=-batch_size) + np.eye((2 * batch_size), 2 * batch_size - 1,
k=batch_size - 1)).astype(np.int)
labels = torch.from_numpy(labels)
softmax = torch.nn.Softmax(dim=-1)
if train_gpu:
labels = labels.cuda()
def step(xis, xjs):
# get the representations and the projections
ris, zis = model(xis) # [N,C]
# get the representations and the projections
rjs, zjs = model(xjs) # [N,C]
if n_iter % config['log_every_n_steps'] == 0:
train_writer.add_histogram("xi_repr", ris, global_step=n_iter)
train_writer.add_histogram("xi_latent", zis, global_step=n_iter)
train_writer.add_histogram("xj_repr", rjs, global_step=n_iter)
train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)
# normalize projection feature vectors
zis = F.normalize(zis, dim=1)
zjs = F.normalize(zjs, dim=1)
negatives = torch.cat([zjs, zis], dim=0)
logits = similarity_func(negatives, negatives)
logits = logits[megative_mask.type(torch.bool)].view(2 * batch_size, -1)
logits /= temperature
# assert logits.shape == (2 * batch_size, 2 * batch_size - 1), "Shape of negatives not expected." + str(
# logits.shape)
probs = softmax(logits)
loss = torch.mean(-torch.sum(labels * torch.log(probs), dim=-1))
return loss
model_checkpoints_folder = os.path.join(train_writer.log_dir, 'checkpoints')
if not os.path.exists(model_checkpoints_folder):
os.makedirs(model_checkpoints_folder)
shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))
n_iter = 0
valid_n_iter = 0
best_valid_loss = np.inf
for epoch_counter in range(config['epochs']):
for (xis, xjs), _ in train_loader:
optimizer.zero_grad()
if train_gpu:
xis = xis.cuda()
xjs = xjs.cuda()
loss = step(xis, xjs)
train_writer.add_scalar('train_loss', loss, global_step=n_iter)
loss.backward()
optimizer.step()
n_iter += 1
if epoch_counter % config['eval_every_n_epochs'] == 0:
# validation steps
with torch.no_grad():
model.eval()
valid_loss = 0.0
for counter, ((xis, xjs), _) in enumerate(valid_loader):
if train_gpu:
xis = xis.cuda()
xjs = xjs.cuda()
loss = (step(xis, xjs))
valid_loss += loss.item()
valid_loss /= counter
if valid_loss < best_valid_loss:
# save the model weights
best_valid_loss = valid_loss
torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
train_writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
valid_n_iter += 1
model.train()

utils.py
View File

@@ -1,62 +0,0 @@
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
np.random.seed(0)
cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
def get_train_validation_data_loaders(train_dataset, batch_size, num_workers, valid_size, **ignored):
# obtain training indices that will be used for validation
num_train = len(train_dataset)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]
# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
num_workers=num_workers, drop_last=True, shuffle=False)
valid_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=valid_sampler,
num_workers=num_workers, drop_last=True)
return train_loader, valid_loader
def get_negative_mask(batch_size):
# return a mask that removes the similarity score of equal/similar images.
# this function ensures that only distinct pair of images get their similarity scores
# passed as negative examples
negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool)
for i in range(batch_size):
negative_mask[i, i] = 0
negative_mask[i, i + batch_size] = 0
return negative_mask
def _dot_simililarity_dim2(x, y):
v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
# x shape: (N, 1, C)
# y shape: (1, C, 2N)
# v shape: (N, 2N)
return v
def _cosine_simililarity_dim2(x, y):
# x shape: (N, 1, C)
# y shape: (1, 2N, C)
# v shape: (N, 2N)
v = cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
return v
def get_similarity_function(use_cosine_similarity):
if use_cosine_similarity:
return _cosine_simililarity_dim2
else:
return _dot_simililarity_dim2