mirror of https://github.com/sthalles/SimCLR.git
support to use other pytorch datasets
parent 2b6bfd9933
commit 88dcdf6d06
data_aug/data_transform.py (new file; the path follows from the import added in train.py below)
@@ -0,0 +1,46 @@
+import torchvision.transforms as transforms
+import cv2
+import numpy as np
+
+
+class DataTransform(object):
+
+    def __init__(self, transform):
+        self.transform = transform
+
+    def __call__(self, sample):
+        xi = self.transform(sample)
+        xj = self.transform(sample)
+        return xi, xj
+
+
+class GaussianBlur(object):
+    # Implements Gaussian blur as described in the SimCLR paper
+    def __init__(self, kernel_size, min=0.1, max=2.0):
+        self.min = min
+        self.max = max
+        # kernel size is set to be 10% of the image height/width
+        self.kernel_size = kernel_size
+
+    def __call__(self, sample):
+        sample = np.array(sample)
+
+        # blur the image with a 50% chance
+        prob = np.random.random_sample()
+
+        if prob < 0.5:
+            sigma = (self.max - self.min) * np.random.random_sample() + self.min
+            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
+
+        return sample
+
+
+def get_data_transform_opes(s, crop_size):
+    # get a set of data augmentation transformations as described in the SimCLR paper.
+    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+    data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=crop_size),
+                                          transforms.RandomHorizontalFlip(),
+                                          transforms.RandomApply([color_jitter], p=0.8),
+                                          transforms.RandomGrayscale(p=0.2),
+                                          GaussianBlur(kernel_size=int(0.1 * crop_size)),
+                                          transforms.ToTensor()])
+    return data_transforms
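With this module in place, two-view generation lives in the dataset itself: DataTransform runs the same stochastic pipeline twice per sample, so any torchvision-style dataset can feed SimCLR. A minimal usage sketch (the s=1.0 strength value and the direct indexing at the end are illustrative, not taken from the repo):

    from torchvision import datasets
    from data_aug.data_transform import DataTransform, get_data_transform_opes

    # Build the SimCLR augmentation pipeline for 96x96 STL10 crops.
    augment = get_data_transform_opes(s=1.0, crop_size=96)

    # Wrapping it in DataTransform makes __getitem__ return a pair of
    # independently augmented views instead of a single tensor.
    dataset = datasets.STL10('./data', split='train', download=True,
                             transform=DataTransform(augment))

    (xi, xj), label = dataset[0]  # two correlated views of the same image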
feature evaluation notebook (.ipynb)
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -13,7 +13,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 18,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -26,17 +26,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [],
    "source": [
     "batch_size = 256\n",
-    "out_dim = 64"
+    "out_dim = 128"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 20,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -54,7 +54,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 21,
    "metadata": {},
    "outputs": [
     {
@@ -74,7 +74,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [
     {
@@ -101,7 +101,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 23,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -112,7 +112,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 24,
    "metadata": {},
    "outputs": [
     {
@@ -120,8 +120,8 @@
      "output_type": "stream",
      "text": [
       "PCA features\n",
-      "(5000, 64)\n",
-      "(8000, 64)\n"
+      "(5000, 128)\n",
+      "(8000, 128)\n"
      ]
     }
    ],
@@ -129,7 +129,7 @@
    "scaler = preprocessing.StandardScaler()\n",
    "scaler.fit(X_train.reshape((X_train.shape[0],-1)))\n",
    "\n",
-   "pca = PCA(n_components=64)\n",
+   "pca = PCA(n_components=128)\n",
    "\n",
    "X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))\n",
    "X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))\n",
@@ -141,7 +141,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 25,
    "metadata": {},
    "outputs": [
     {
@@ -149,8 +149,8 @@
      "output_type": "stream",
      "text": [
       "PCA feature evaluation\n",
-      "Train score: 0.396\n",
-      "Test score: 0.3565\n"
+      "Train score: 0.4306\n",
+      "Test score: 0.3625\n"
      ]
     }
    ],
@@ -164,7 +164,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 26,
    "metadata": {},
    "outputs": [
     {
@@ -185,7 +185,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 27,
    "metadata": {},
    "outputs": [
     {
@@ -203,7 +203,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 28,
    "metadata": {},
    "outputs": [
     {
@@ -295,7 +295,7 @@
    " (8): AdaptiveAvgPool2d(output_size=(1, 1))\n",
    " )\n",
    " (l1): Linear(in_features=512, out_features=512, bias=True)\n",
-   " (l2): Linear(in_features=512, out_features=64, bias=True)\n",
+   " (l2): Linear(in_features=512, out_features=128, bias=True)\n",
    ")\n",
    "odict_keys(['features.0.weight', 'features.1.weight', 'features.1.bias', 'features.1.running_mean', 'features.1.running_var', 'features.1.num_batches_tracked', 'features.4.0.conv1.weight', 'features.4.0.bn1.weight', 'features.4.0.bn1.bias', 'features.4.0.bn1.running_mean', 'features.4.0.bn1.running_var', 'features.4.0.bn1.num_batches_tracked', 'features.4.0.conv2.weight', 'features.4.0.bn2.weight', 'features.4.0.bn2.bias', 'features.4.0.bn2.running_mean', 'features.4.0.bn2.running_var', 'features.4.0.bn2.num_batches_tracked', 'features.4.1.conv1.weight', 'features.4.1.bn1.weight', 'features.4.1.bn1.bias', 'features.4.1.bn1.running_mean', 'features.4.1.bn1.running_var', 'features.4.1.bn1.num_batches_tracked', 'features.4.1.conv2.weight', 'features.4.1.bn2.weight', 'features.4.1.bn2.bias', 'features.4.1.bn2.running_mean', 'features.4.1.bn2.running_var', 'features.4.1.bn2.num_batches_tracked', 'features.5.0.conv1.weight', 'features.5.0.bn1.weight', 'features.5.0.bn1.bias', 'features.5.0.bn1.running_mean', 'features.5.0.bn1.running_var', 'features.5.0.bn1.num_batches_tracked', 'features.5.0.conv2.weight', 'features.5.0.bn2.weight', 'features.5.0.bn2.bias', 'features.5.0.bn2.running_mean', 'features.5.0.bn2.running_var', 'features.5.0.bn2.num_batches_tracked', 'features.5.0.downsample.0.weight', 'features.5.0.downsample.1.weight', 'features.5.0.downsample.1.bias', 'features.5.0.downsample.1.running_mean', 'features.5.0.downsample.1.running_var', 'features.5.0.downsample.1.num_batches_tracked', 'features.5.1.conv1.weight', 'features.5.1.bn1.weight', 'features.5.1.bn1.bias', 'features.5.1.bn1.running_mean', 'features.5.1.bn1.running_var', 'features.5.1.bn1.num_batches_tracked', 'features.5.1.conv2.weight', 'features.5.1.bn2.weight', 'features.5.1.bn2.bias', 'features.5.1.bn2.running_mean', 'features.5.1.bn2.running_var', 'features.5.1.bn2.num_batches_tracked', 'features.6.0.conv1.weight', 'features.6.0.bn1.weight', 'features.6.0.bn1.bias', 'features.6.0.bn1.running_mean', 'features.6.0.bn1.running_var', 'features.6.0.bn1.num_batches_tracked', 'features.6.0.conv2.weight', 'features.6.0.bn2.weight', 'features.6.0.bn2.bias', 'features.6.0.bn2.running_mean', 'features.6.0.bn2.running_var', 'features.6.0.bn2.num_batches_tracked', 'features.6.0.downsample.0.weight', 'features.6.0.downsample.1.weight', 'features.6.0.downsample.1.bias', 'features.6.0.downsample.1.running_mean', 'features.6.0.downsample.1.running_var', 'features.6.0.downsample.1.num_batches_tracked', 'features.6.1.conv1.weight', 'features.6.1.bn1.weight', 'features.6.1.bn1.bias', 'features.6.1.bn1.running_mean', 'features.6.1.bn1.running_var', 'features.6.1.bn1.num_batches_tracked', 'features.6.1.conv2.weight', 'features.6.1.bn2.weight', 'features.6.1.bn2.bias', 'features.6.1.bn2.running_mean', 'features.6.1.bn2.running_var', 'features.6.1.bn2.num_batches_tracked', 'features.7.0.conv1.weight', 'features.7.0.bn1.weight', 'features.7.0.bn1.bias', 'features.7.0.bn1.running_mean', 'features.7.0.bn1.running_var', 'features.7.0.bn1.num_batches_tracked', 'features.7.0.conv2.weight', 'features.7.0.bn2.weight', 'features.7.0.bn2.bias', 'features.7.0.bn2.running_mean', 'features.7.0.bn2.running_var', 'features.7.0.bn2.num_batches_tracked', 'features.7.0.downsample.0.weight', 'features.7.0.downsample.1.weight', 'features.7.0.downsample.1.bias', 'features.7.0.downsample.1.running_mean', 'features.7.0.downsample.1.running_var', 'features.7.0.downsample.1.num_batches_tracked', 'features.7.1.conv1.weight', 'features.7.1.bn1.weight', 'features.7.1.bn1.bias', 'features.7.1.bn1.running_mean', 'features.7.1.bn1.running_var', 'features.7.1.bn1.num_batches_tracked', 'features.7.1.conv2.weight', 'features.7.1.bn2.weight', 'features.7.1.bn2.bias', 'features.7.1.bn2.running_mean', 'features.7.1.bn2.running_var', 'features.7.1.bn2.num_batches_tracked', 'l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'])\n"
   ]
@@ -306,7 +306,7 @@
    "<All keys matched successfully>"
   ]
  },
- "execution_count": 12,
+ "execution_count": 28,
  "metadata": {},
  "output_type": "execute_result"
 }
@@ -331,7 +331,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 29,
    "metadata": {},
    "outputs": [
     {
@@ -358,7 +358,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 30,
    "metadata": {},
    "outputs": [
     {
@@ -385,7 +385,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 31,
    "metadata": {},
    "outputs": [
     {
@@ -398,7 +398,7 @@
    " warm_start=False)"
   ]
  },
- "execution_count": 15,
+ "execution_count": 31,
  "metadata": {},
  "output_type": "execute_result"
 }
@@ -414,7 +414,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 32,
    "metadata": {},
    "outputs": [
     {
@@ -422,8 +422,8 @@
      "output_type": "stream",
      "text": [
       "SimCLR feature evaluation\n",
-      "Train score: 0.8948\n",
-      "Test score: 0.639625\n"
+      "Train score: 0.8914\n",
+      "Test score: 0.6425\n"
      ]
     }
    ],
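The notebook hunks track one substantive change: the projection head's output dimension out_dim rises from 64 to 128, so the PCA baseline is re-run with 128 components and the scores shift (PCA test accuracy 0.3565 to 0.3625, SimCLR test accuracy 0.639625 to 0.6425). A sketch of the baseline those cells implement, assuming X_train/X_test are raw image arrays with labels y_train/y_test, and that the classifier is scikit-learn's LogisticRegression (its repr, ending in warm_start=False, appears in the unchanged output above):

    from sklearn import preprocessing
    from sklearn.decomposition import PCA
    from sklearn.linear_model import LogisticRegression

    def pca_baseline(X_train, y_train, X_test, y_test, n_components=128):
        # Standardize the flattened images, as the notebook does.
        scaler = preprocessing.StandardScaler()
        scaler.fit(X_train.reshape((X_train.shape[0], -1)))

        # Reduce to n_components dimensions (now 128, matching out_dim).
        pca = PCA(n_components=n_components)
        X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))
        X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))

        # Linear probe on the reduced features.
        clf = LogisticRegression().fit(X_train_pca, y_train)
        return clf.score(X_train_pca, y_train), clf.score(X_test_pca, y_test)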
train.py (24 lines changed)
@@ -8,9 +8,10 @@ from torch.utils.data import DataLoader
 from torchvision import datasets
 from torch.utils.tensorboard import SummaryWriter
 import torch.nn.functional as F
+import matplotlib.pyplot as plt
 from models.resnet_simclr import ResNetSimCLR
-from utils import get_negative_mask, get_augmentation_transform, get_similarity_function
+from utils import get_negative_mask, get_similarity_function
+from data_aug.data_transform import DataTransform, get_data_transform_opes

 torch.manual_seed(0)

@@ -21,9 +22,9 @@ out_dim = config['out_dim']
 temperature = config['temperature']
 use_cosine_similarity = config['use_cosine_similarity']

-data_augment = get_augmentation_transform(s=config['s'], crop_size=96)
+data_augment = get_data_transform_opes(s=config['s'], crop_size=96)

-train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=transforms.ToTensor())
+train_dataset = datasets.STL10('./data', split='train', download=True, transform=DataTransform(data_augment))
 # train_dataset = datasets.Caltech101(root='./data', target_type="category", transform=transforms.ToTensor(),
 #                                     target_transform=None, download=True)

@@ -52,20 +53,7 @@ negative_mask = get_negative_mask(batch_size)

 n_iter = 0
 for e in range(config['epochs']):
-    for step, (batch_x, _) in enumerate(train_loader):
+    for step, ((xis, xjs), _) in enumerate(train_loader):

-        optimizer.zero_grad()
-
-        xis = []
-        xjs = []
-
-        # draw two augmentation functions t , t' and apply separately for each input example
-        for k in range(len(batch_x)):
-            xis.append(data_augment(batch_x[k]))  # the first augmentation
-            xjs.append(data_augment(batch_x[k]))  # the second augmentation
-
-        xis = torch.stack(xis)
-        xjs = torch.stack(xjs)
-
         if train_gpu:
             xis = xis.cuda()
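This loop change is the heart of the commit: the per-sample augmentation loop and the torch.stack calls move out of train.py, because DataTransform already returns view pairs that the DataLoader's default collate function stacks into batches. A sketch of what each iteration now receives (the forward pass and loss sit outside the hunk and are elided):

    for step, ((xis, xjs), _) in enumerate(train_loader):
        # Default collate stacks the tuple elements, so xis and xjs arrive
        # as (batch_size, C, H, W) tensors holding the two augmented views.
        if train_gpu:
            xis = xis.cuda()
            xjs = xjs.cuda()
        # ... forward both views through ResNetSimCLR and compute the
        # contrastive (NT-Xent) loss (not shown in this hunk).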
utils.py (35 lines changed)
@@ -2,6 +2,7 @@ import cv2
 import numpy as np
 import torch
 import torchvision.transforms as transforms
+import matplotlib.pyplot as plt

 np.random.seed(0)
 cos1d = torch.nn.CosineSimilarity(dim=1)
@@ -19,40 +20,6 @@ def get_negative_mask(batch_size):
     return negative_mask


-class GaussianBlur(object):
-    # Implements Gaussian blur as described in the SimCLR paper
-    def __init__(self, kernel_size, min=0.1, max=2.0):
-        self.min = min
-        self.max = max
-        # kernel size is set to be 10% of the image height/width
-        self.kernel_size = kernel_size
-
-    def __call__(self, sample):
-        sample = np.array(sample)
-
-        # blur the image with a 50% chance
-        prob = np.random.random_sample()
-
-        if prob < 0.5:
-            sigma = (self.max - self.min) * np.random.random_sample() + self.min
-            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
-
-        return sample
-
-
-def get_augmentation_transform(s, crop_size):
-    # get a set of data augmentation transformations as described in the SimCLR paper.
-    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
-    data_aug_ope = transforms.Compose([transforms.ToPILImage(),
-                                       transforms.RandomResizedCrop(size=crop_size),
-                                       transforms.RandomHorizontalFlip(),
-                                       transforms.RandomApply([color_jitter], p=0.8),
-                                       transforms.RandomGrayscale(p=0.2),
-                                       GaussianBlur(kernel_size=int(0.1 * crop_size)),
-                                       transforms.ToTensor()])
-    return data_aug_ope
-
-
 def _dot_simililarity_dim1(x, y):
     # x shape: (N, 1, C)
     # y shape: (N, C, 1)
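The hunk's trailing context shows only the signature and shape comments of _dot_simililarity_dim1; its body lies outside the diff. A batched dot product consistent with those (N, 1, C) and (N, C, 1) shapes would look like the following assumed sketch (not the file's actual implementation):

    import torch

    def dot_similarity_dim1(x, y):
        # x: (N, 1, C), y: (N, C, 1) -> (N, 1, 1): one scalar per row pair.
        return torch.bmm(x, y)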