mirror of https://github.com/sthalles/SimCLR.git

commit 88dcdf6d06
parent 2b6bfd9933

support to use other pytorch datasets
data_aug/data_transform.py (new file)
@@ -0,0 +1,46 @@
+import torchvision.transforms as transforms
+import cv2
+import numpy as np
+
+
+class DataTransform(object):
+    def __init__(self, transform):
+        self.transform = transform
+
+    def __call__(self, sample):
+        xi = self.transform(sample)
+        xj = self.transform(sample)
+        return xi, xj
+
+
+class GaussianBlur(object):
+    # Implements Gaussian blur as described in the SimCLR paper
+    def __init__(self, kernel_size, min=0.1, max=2.0):
+        self.min = min
+        self.max = max
+        # kernel size is set to be 10% of the image height/width
+        self.kernel_size = kernel_size
+
+    def __call__(self, sample):
+        sample = np.array(sample)
+
+        # blur the image with a 50% chance
+        prob = np.random.random_sample()
+
+        if prob < 0.5:
+            sigma = (self.max - self.min) * np.random.random_sample() + self.min
+            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
+
+        return sample
+
+
+def get_data_transform_opes(s, crop_size):
+    # get a set of data augmentation transformations as described in the SimCLR paper.
+    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+    data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=crop_size),
+                                          transforms.RandomHorizontalFlip(),
+                                          transforms.RandomApply([color_jitter], p=0.8),
+                                          transforms.RandomGrayscale(p=0.2),
+                                          GaussianBlur(kernel_size=int(0.1 * crop_size)),
+                                          transforms.ToTensor()])
+    return data_transforms
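For reference, a minimal usage sketch (not part of the commit) showing how the new module pairs with a torchvision dataset; s=1.0 stands in for config['s'], which this view does not show:

# Hypothetical usage: each dataset sample now yields two augmented views.
from torchvision import datasets
from data_aug.data_transform import DataTransform, get_data_transform_opes

data_augment = get_data_transform_opes(s=1.0, crop_size=96)  # s assumed
train_dataset = datasets.STL10('./data', split='train', download=True,
                               transform=DataTransform(data_augment))
(xi, xj), label = train_dataset[0]  # two independently augmented views

# Caveat: cv2.GaussianBlur requires odd kernel sizes. int(0.1 * 96) == 9 is
# odd, but other crop sizes (e.g. 64 -> ksize 6) would raise cv2.error.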
(feature-evaluation notebook — file name not captured)
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -13,7 +13,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 18,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -26,17 +26,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [],
    "source": [
     "batch_size = 256\n",
-    "out_dim = 64"
+    "out_dim = 128"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 20,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -54,7 +54,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 21,
    "metadata": {},
    "outputs": [
     {
@@ -74,7 +74,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [
     {
@@ -101,7 +101,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 23,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -112,7 +112,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 24,
    "metadata": {},
    "outputs": [
     {
@@ -120,8 +120,8 @@
      "output_type": "stream",
      "text": [
       "PCA features\n",
-      "(5000, 64)\n",
-      "(8000, 64)\n"
+      "(5000, 128)\n",
+      "(8000, 128)\n"
      ]
     }
    ],
@@ -129,7 +129,7 @@
    "scaler = preprocessing.StandardScaler()\n",
    "scaler.fit(X_train.reshape((X_train.shape[0],-1)))\n",
    "\n",
-   "pca = PCA(n_components=64)\n",
+   "pca = PCA(n_components=128)\n",
    "\n",
    "X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))\n",
    "X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))\n",
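Taken together with the cells above, the notebook's PCA baseline now looks roughly like the sketch below. The synthetic arrays only stand in for the STL10 splits (shapes taken from the printed output), and the LogisticRegression settings are assumptions beyond what the diff shows:

import numpy as np
from sklearn import preprocessing
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X_train = rng.normal(size=(5000, 96, 96, 3))  # stand-in for the STL10 train split
y_train = rng.integers(0, 10, size=5000)
X_test = rng.normal(size=(8000, 96, 96, 3))   # stand-in for the STL10 test split
y_test = rng.integers(0, 10, size=8000)

# Standardize flattened pixels, then project to 128 dims (raised from 64
# to match the new out_dim) before fitting a linear classifier.
scaler = preprocessing.StandardScaler()
scaler.fit(X_train.reshape((X_train.shape[0], -1)))

pca = PCA(n_components=128)
X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))
X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))

clf = LogisticRegression(max_iter=1000)  # assumed settings
clf.fit(X_train_pca, y_train)
print("Train score:", clf.score(X_train_pca, y_train))
print("Test score:", clf.score(X_test_pca, y_test))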
@ -141,7 +141,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -149,8 +149,8 @@
|
|||
"output_type": "stream",
|
||||
"text": [
|
||||
"PCA feature evaluation\n",
|
||||
"Train score: 0.396\n",
|
||||
"Test score: 0.3565\n"
|
||||
"Train score: 0.4306\n",
|
||||
"Test score: 0.3625\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -164,7 +164,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -185,7 +185,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -203,7 +203,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -295,7 +295,7 @@
|
|||
" (8): AdaptiveAvgPool2d(output_size=(1, 1))\n",
|
||||
" )\n",
|
||||
" (l1): Linear(in_features=512, out_features=512, bias=True)\n",
|
||||
" (l2): Linear(in_features=512, out_features=64, bias=True)\n",
|
||||
" (l2): Linear(in_features=512, out_features=128, bias=True)\n",
|
||||
")\n",
|
||||
"odict_keys(['features.0.weight', 'features.1.weight', 'features.1.bias', 'features.1.running_mean', 'features.1.running_var', 'features.1.num_batches_tracked', 'features.4.0.conv1.weight', 'features.4.0.bn1.weight', 'features.4.0.bn1.bias', 'features.4.0.bn1.running_mean', 'features.4.0.bn1.running_var', 'features.4.0.bn1.num_batches_tracked', 'features.4.0.conv2.weight', 'features.4.0.bn2.weight', 'features.4.0.bn2.bias', 'features.4.0.bn2.running_mean', 'features.4.0.bn2.running_var', 'features.4.0.bn2.num_batches_tracked', 'features.4.1.conv1.weight', 'features.4.1.bn1.weight', 'features.4.1.bn1.bias', 'features.4.1.bn1.running_mean', 'features.4.1.bn1.running_var', 'features.4.1.bn1.num_batches_tracked', 'features.4.1.conv2.weight', 'features.4.1.bn2.weight', 'features.4.1.bn2.bias', 'features.4.1.bn2.running_mean', 'features.4.1.bn2.running_var', 'features.4.1.bn2.num_batches_tracked', 'features.5.0.conv1.weight', 'features.5.0.bn1.weight', 'features.5.0.bn1.bias', 'features.5.0.bn1.running_mean', 'features.5.0.bn1.running_var', 'features.5.0.bn1.num_batches_tracked', 'features.5.0.conv2.weight', 'features.5.0.bn2.weight', 'features.5.0.bn2.bias', 'features.5.0.bn2.running_mean', 'features.5.0.bn2.running_var', 'features.5.0.bn2.num_batches_tracked', 'features.5.0.downsample.0.weight', 'features.5.0.downsample.1.weight', 'features.5.0.downsample.1.bias', 'features.5.0.downsample.1.running_mean', 'features.5.0.downsample.1.running_var', 'features.5.0.downsample.1.num_batches_tracked', 'features.5.1.conv1.weight', 'features.5.1.bn1.weight', 'features.5.1.bn1.bias', 'features.5.1.bn1.running_mean', 'features.5.1.bn1.running_var', 'features.5.1.bn1.num_batches_tracked', 'features.5.1.conv2.weight', 'features.5.1.bn2.weight', 'features.5.1.bn2.bias', 'features.5.1.bn2.running_mean', 'features.5.1.bn2.running_var', 'features.5.1.bn2.num_batches_tracked', 'features.6.0.conv1.weight', 'features.6.0.bn1.weight', 'features.6.0.bn1.bias', 'features.6.0.bn1.running_mean', 'features.6.0.bn1.running_var', 'features.6.0.bn1.num_batches_tracked', 'features.6.0.conv2.weight', 'features.6.0.bn2.weight', 'features.6.0.bn2.bias', 'features.6.0.bn2.running_mean', 'features.6.0.bn2.running_var', 'features.6.0.bn2.num_batches_tracked', 'features.6.0.downsample.0.weight', 'features.6.0.downsample.1.weight', 'features.6.0.downsample.1.bias', 'features.6.0.downsample.1.running_mean', 'features.6.0.downsample.1.running_var', 'features.6.0.downsample.1.num_batches_tracked', 'features.6.1.conv1.weight', 'features.6.1.bn1.weight', 'features.6.1.bn1.bias', 'features.6.1.bn1.running_mean', 'features.6.1.bn1.running_var', 'features.6.1.bn1.num_batches_tracked', 'features.6.1.conv2.weight', 'features.6.1.bn2.weight', 'features.6.1.bn2.bias', 'features.6.1.bn2.running_mean', 'features.6.1.bn2.running_var', 'features.6.1.bn2.num_batches_tracked', 'features.7.0.conv1.weight', 'features.7.0.bn1.weight', 'features.7.0.bn1.bias', 'features.7.0.bn1.running_mean', 'features.7.0.bn1.running_var', 'features.7.0.bn1.num_batches_tracked', 'features.7.0.conv2.weight', 'features.7.0.bn2.weight', 'features.7.0.bn2.bias', 'features.7.0.bn2.running_mean', 'features.7.0.bn2.running_var', 'features.7.0.bn2.num_batches_tracked', 'features.7.0.downsample.0.weight', 'features.7.0.downsample.1.weight', 'features.7.0.downsample.1.bias', 'features.7.0.downsample.1.running_mean', 'features.7.0.downsample.1.running_var', 'features.7.0.downsample.1.num_batches_tracked', 'features.7.1.conv1.weight', 'features.7.1.bn1.weight', 
'features.7.1.bn1.bias', 'features.7.1.bn1.running_mean', 'features.7.1.bn1.running_var', 'features.7.1.bn1.num_batches_tracked', 'features.7.1.conv2.weight', 'features.7.1.bn2.weight', 'features.7.1.bn2.bias', 'features.7.1.bn2.running_mean', 'features.7.1.bn2.running_var', 'features.7.1.bn2.num_batches_tracked', 'l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'])\n"
|
||||
]
|
||||
|
@ -306,7 +306,7 @@
|
|||
"<All keys matched successfully>"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
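The "<All keys matched successfully>" result above comes from load_state_dict. A hedged sketch of that step — the constructor arguments and checkpoint path are assumptions, though the key names printed above do match a ResNet-18 backbone with l1/l2 projection layers:

import torch
from models.resnet_simclr import ResNetSimCLR

model = ResNetSimCLR(base_model="resnet18", out_dim=128)  # args assumed
state_dict = torch.load("./checkpoint.pth", map_location="cpu")  # path assumed
print(model.load_state_dict(state_dict))  # -> <All keys matched successfully>
model.eval()  # freeze batch-norm statistics before extracting features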
@ -331,7 +331,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -358,7 +358,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -385,7 +385,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -398,7 +398,7 @@
|
|||
" warm_start=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -414,7 +414,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -422,8 +422,8 @@
|
|||
"output_type": "stream",
|
||||
"text": [
|
||||
"SimCLR feature evaluation\n",
|
||||
"Train score: 0.8948\n",
|
||||
"Test score: 0.639625\n"
|
||||
"Train score: 0.8914\n",
|
||||
"Test score: 0.6425\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
train.py (24 changes)
@@ -8,9 +8,10 @@ from torch.utils.data import DataLoader
 from torchvision import datasets
 from torch.utils.tensorboard import SummaryWriter
 import torch.nn.functional as F
 import matplotlib.pyplot as plt
 from models.resnet_simclr import ResNetSimCLR
-from utils import get_negative_mask, get_augmentation_transform, get_similarity_function
+from utils import get_negative_mask, get_similarity_function
+from data_aug.data_transform import DataTransform, get_data_transform_opes

 torch.manual_seed(0)

@@ -21,9 +22,9 @@ out_dim = config['out_dim']
 temperature = config['temperature']
 use_cosine_similarity = config['use_cosine_similarity']

-data_augment = get_augmentation_transform(s=config['s'], crop_size=96)
+data_augment = get_data_transform_opes(s=config['s'], crop_size=96)

-train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=transforms.ToTensor())
+train_dataset = datasets.STL10('./data', split='train', download=True, transform=DataTransform(data_augment))
 # train_dataset = datasets.Caltech101(root='./data', target_type="category", transform=transforms.ToTensor(),
 #                                     target_transform=None, download=True)

@@ -52,20 +53,7 @@ negative_mask = get_negative_mask(batch_size)

 n_iter = 0
 for e in range(config['epochs']):
-    for step, (batch_x, _) in enumerate(train_loader):
-
-        optimizer.zero_grad()
-
-        xis = []
-        xjs = []
-
-        # draw two augmentation functions t, t' and apply them separately to each input example
-        for k in range(len(batch_x)):
-            xis.append(data_augment(batch_x[k]))  # the first augmentation
-            xjs.append(data_augment(batch_x[k]))  # the second augmentation
-
-        xis = torch.stack(xis)
-        xjs = torch.stack(xjs)
+    for step, ((xis, xjs), _) in enumerate(train_loader):

         if train_gpu:
             xis = xis.cuda()
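The deleted inner loop is no longer needed because DataTransform produces the two views inside the dataset itself, and DataLoader's default collate transposes per-sample pairs into paired batches. A small demonstration (illustrative, not from the repo):

import torch
from torch.utils.data.dataloader import default_collate

# Four fake samples shaped like DataTransform output: ((xi, xj), label)
samples = [((torch.randn(3, 96, 96), torch.randn(3, 96, 96)), 0) for _ in range(4)]
(xis, xjs), ys = default_collate(samples)
print(xis.shape, xjs.shape)  # torch.Size([4, 3, 96, 96]) each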
utils.py (35 changes)
@@ -2,6 +2,7 @@ import cv2
 import numpy as np
 import torch
 import torchvision.transforms as transforms
+import matplotlib.pyplot as plt

 np.random.seed(0)
 cos1d = torch.nn.CosineSimilarity(dim=1)

@@ -19,40 +20,6 @@ def get_negative_mask(batch_size):
     return negative_mask


-class GaussianBlur(object):
-    # Implements Gaussian blur as described in the SimCLR paper
-    def __init__(self, kernel_size, min=0.1, max=2.0):
-        self.min = min
-        self.max = max
-        # kernel size is set to be 10% of the image height/width
-        self.kernel_size = kernel_size
-
-    def __call__(self, sample):
-        sample = np.array(sample)
-
-        # blur the image with a 50% chance
-        prob = np.random.random_sample()
-
-        if prob < 0.5:
-            sigma = (self.max - self.min) * np.random.random_sample() + self.min
-            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
-
-        return sample
-
-
-def get_augmentation_transform(s, crop_size):
-    # get a set of data augmentation transformations as described in the SimCLR paper.
-    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
-    data_aug_ope = transforms.Compose([transforms.ToPILImage(),
-                                       transforms.RandomResizedCrop(size=crop_size),
-                                       transforms.RandomHorizontalFlip(),
-                                       transforms.RandomApply([color_jitter], p=0.8),
-                                       transforms.RandomGrayscale(p=0.2),
-                                       GaussianBlur(kernel_size=int(0.1 * crop_size)),
-                                       transforms.ToTensor()])
-    return data_aug_ope
-
-

 def _dot_simililarity_dim1(x, y):
     # x shape: (N, 1, C)
     # y shape: (N, C, 1)
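The shape comments above document the similarity contract; here is a sketch of what a batched dot product over those shapes computes (the bmm call is an assumption about the function body, which this hunk does not show):

import torch

x = torch.randn(8, 1, 64)       # (N, 1, C)
y = torch.randn(8, 64, 1)       # (N, C, 1)
sim = torch.bmm(x, y).view(-1)  # (N,) dot product per row pair
print(sim.shape)                # torch.Size([8])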