Support using other PyTorch datasets

pull/5/head
Thalles 2020-02-29 19:32:37 -03:00
parent 2b6bfd9933
commit 88dcdf6d06
4 changed files with 80 additions and 79 deletions

View File

@ -0,0 +1,46 @@
import torchvision.transforms as transforms
import cv2
import numpy as np


class DataTransform(object):
    # Wraps a transform pipeline so that every call returns two independently
    # augmented views (xi, xj) of the same input sample.
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj


class GaussianBlur(object):
    # Implements Gaussian blur as described in the SimCLR paper
    def __init__(self, kernel_size, min=0.1, max=2.0):
        self.min = min
        self.max = max
        # kernel size is set to be 10% of the image height/width
        # (note: cv2.GaussianBlur expects an odd kernel size)
        self.kernel_size = kernel_size

    def __call__(self, sample):
        sample = np.array(sample)

        # blur the image with a 50% chance
        prob = np.random.random_sample()
        if prob < 0.5:
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

        return sample


def get_data_transform_opes(s, crop_size):
    # get a set of data augmentation transformations as described in the SimCLR paper
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=crop_size),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomApply([color_jitter], p=0.8),
                                          transforms.RandomGrayscale(p=0.2),
                                          GaussianBlur(kernel_size=int(0.1 * crop_size)),
                                          transforms.ToTensor()])
    return data_transforms
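
This is what the commit title refers to: any torchvision-style dataset that returns PIL images can now be plugged into the SimCLR pipeline by wrapping the augmentation pipeline in DataTransform. A minimal usage sketch, not part of this commit; CIFAR10 and the parameter values are illustrative assumptions:

# Hypothetical usage sketch: wrap an arbitrary torchvision dataset so that
# each sample yields the two augmented views SimCLR trains on.
from torchvision import datasets

data_augment = get_data_transform_opes(s=1.0, crop_size=32)   # example values only

train_dataset = datasets.CIFAR10('./data', train=True, download=True,
                                 transform=DataTransform(data_augment))

(xi, xj), label = train_dataset[0]   # two independently augmented views of the same image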

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@ -26,17 +26,17 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 256\n",
"out_dim = 64"
"out_dim = 128"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@ -54,7 +54,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 21,
"metadata": {},
"outputs": [
{
@ -74,7 +74,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 22,
"metadata": {},
"outputs": [
{
@ -101,7 +101,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@ -112,7 +112,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 24,
"metadata": {},
"outputs": [
{
@ -120,8 +120,8 @@
"output_type": "stream",
"text": [
"PCA features\n",
"(5000, 64)\n",
"(8000, 64)\n"
"(5000, 128)\n",
"(8000, 128)\n"
]
}
],
@ -129,7 +129,7 @@
"scaler = preprocessing.StandardScaler()\n",
"scaler.fit(X_train.reshape((X_train.shape[0],-1)))\n",
"\n",
"pca = PCA(n_components=64)\n",
"pca = PCA(n_components=128)\n",
"\n",
"X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))\n",
"X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))\n",
@ -141,7 +141,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 25,
"metadata": {},
"outputs": [
{
@ -149,8 +149,8 @@
"output_type": "stream",
"text": [
"PCA feature evaluation\n",
"Train score: 0.396\n",
"Test score: 0.3565\n"
"Train score: 0.4306\n",
"Test score: 0.3625\n"
]
}
],
@ -164,7 +164,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 26,
"metadata": {},
"outputs": [
{
@ -185,7 +185,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 27,
"metadata": {},
"outputs": [
{
@ -203,7 +203,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 28,
"metadata": {},
"outputs": [
{
@ -295,7 +295,7 @@
" (8): AdaptiveAvgPool2d(output_size=(1, 1))\n",
" )\n",
" (l1): Linear(in_features=512, out_features=512, bias=True)\n",
" (l2): Linear(in_features=512, out_features=64, bias=True)\n",
" (l2): Linear(in_features=512, out_features=128, bias=True)\n",
")\n",
"odict_keys(['features.0.weight', 'features.1.weight', 'features.1.bias', 'features.1.running_mean', 'features.1.running_var', 'features.1.num_batches_tracked', 'features.4.0.conv1.weight', 'features.4.0.bn1.weight', 'features.4.0.bn1.bias', 'features.4.0.bn1.running_mean', 'features.4.0.bn1.running_var', 'features.4.0.bn1.num_batches_tracked', 'features.4.0.conv2.weight', 'features.4.0.bn2.weight', 'features.4.0.bn2.bias', 'features.4.0.bn2.running_mean', 'features.4.0.bn2.running_var', 'features.4.0.bn2.num_batches_tracked', 'features.4.1.conv1.weight', 'features.4.1.bn1.weight', 'features.4.1.bn1.bias', 'features.4.1.bn1.running_mean', 'features.4.1.bn1.running_var', 'features.4.1.bn1.num_batches_tracked', 'features.4.1.conv2.weight', 'features.4.1.bn2.weight', 'features.4.1.bn2.bias', 'features.4.1.bn2.running_mean', 'features.4.1.bn2.running_var', 'features.4.1.bn2.num_batches_tracked', 'features.5.0.conv1.weight', 'features.5.0.bn1.weight', 'features.5.0.bn1.bias', 'features.5.0.bn1.running_mean', 'features.5.0.bn1.running_var', 'features.5.0.bn1.num_batches_tracked', 'features.5.0.conv2.weight', 'features.5.0.bn2.weight', 'features.5.0.bn2.bias', 'features.5.0.bn2.running_mean', 'features.5.0.bn2.running_var', 'features.5.0.bn2.num_batches_tracked', 'features.5.0.downsample.0.weight', 'features.5.0.downsample.1.weight', 'features.5.0.downsample.1.bias', 'features.5.0.downsample.1.running_mean', 'features.5.0.downsample.1.running_var', 'features.5.0.downsample.1.num_batches_tracked', 'features.5.1.conv1.weight', 'features.5.1.bn1.weight', 'features.5.1.bn1.bias', 'features.5.1.bn1.running_mean', 'features.5.1.bn1.running_var', 'features.5.1.bn1.num_batches_tracked', 'features.5.1.conv2.weight', 'features.5.1.bn2.weight', 'features.5.1.bn2.bias', 'features.5.1.bn2.running_mean', 'features.5.1.bn2.running_var', 'features.5.1.bn2.num_batches_tracked', 'features.6.0.conv1.weight', 'features.6.0.bn1.weight', 'features.6.0.bn1.bias', 'features.6.0.bn1.running_mean', 'features.6.0.bn1.running_var', 'features.6.0.bn1.num_batches_tracked', 'features.6.0.conv2.weight', 'features.6.0.bn2.weight', 'features.6.0.bn2.bias', 'features.6.0.bn2.running_mean', 'features.6.0.bn2.running_var', 'features.6.0.bn2.num_batches_tracked', 'features.6.0.downsample.0.weight', 'features.6.0.downsample.1.weight', 'features.6.0.downsample.1.bias', 'features.6.0.downsample.1.running_mean', 'features.6.0.downsample.1.running_var', 'features.6.0.downsample.1.num_batches_tracked', 'features.6.1.conv1.weight', 'features.6.1.bn1.weight', 'features.6.1.bn1.bias', 'features.6.1.bn1.running_mean', 'features.6.1.bn1.running_var', 'features.6.1.bn1.num_batches_tracked', 'features.6.1.conv2.weight', 'features.6.1.bn2.weight', 'features.6.1.bn2.bias', 'features.6.1.bn2.running_mean', 'features.6.1.bn2.running_var', 'features.6.1.bn2.num_batches_tracked', 'features.7.0.conv1.weight', 'features.7.0.bn1.weight', 'features.7.0.bn1.bias', 'features.7.0.bn1.running_mean', 'features.7.0.bn1.running_var', 'features.7.0.bn1.num_batches_tracked', 'features.7.0.conv2.weight', 'features.7.0.bn2.weight', 'features.7.0.bn2.bias', 'features.7.0.bn2.running_mean', 'features.7.0.bn2.running_var', 'features.7.0.bn2.num_batches_tracked', 'features.7.0.downsample.0.weight', 'features.7.0.downsample.1.weight', 'features.7.0.downsample.1.bias', 'features.7.0.downsample.1.running_mean', 'features.7.0.downsample.1.running_var', 'features.7.0.downsample.1.num_batches_tracked', 'features.7.1.conv1.weight', 'features.7.1.bn1.weight', 
'features.7.1.bn1.bias', 'features.7.1.bn1.running_mean', 'features.7.1.bn1.running_var', 'features.7.1.bn1.num_batches_tracked', 'features.7.1.conv2.weight', 'features.7.1.bn2.weight', 'features.7.1.bn2.bias', 'features.7.1.bn2.running_mean', 'features.7.1.bn2.running_var', 'features.7.1.bn2.num_batches_tracked', 'l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'])\n"
]
@ -306,7 +306,7 @@
"<All keys matched successfully>"
]
},
"execution_count": 12,
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
@ -331,7 +331,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 29,
"metadata": {},
"outputs": [
{
@ -358,7 +358,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 30,
"metadata": {},
"outputs": [
{
@ -385,7 +385,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 31,
"metadata": {},
"outputs": [
{
@ -398,7 +398,7 @@
" warm_start=False)"
]
},
"execution_count": 15,
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
@ -414,7 +414,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 32,
"metadata": {},
"outputs": [
{
@ -422,8 +422,8 @@
"output_type": "stream",
"text": [
"SimCLR feature evaluation\n",
"Train score: 0.8948\n",
"Test score: 0.639625\n"
"Train score: 0.8914\n",
"Test score: 0.6425\n"
]
}
],
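
The PCA and SimCLR scores above come from fitting a linear classifier on the frozen features. A hedged sketch of that evaluation protocol follows; the classifier choice and its settings are assumptions, not copied from the notebook:

# Hedged sketch of the linear evaluation producing "Train score" / "Test score" above.
from sklearn import preprocessing
from sklearn.linear_model import LogisticRegression

def linear_eval(X_train, y_train, X_test, y_test):
    # X_* are already-flattened feature matrices (PCA or SimCLR features)
    scaler = preprocessing.StandardScaler().fit(X_train)
    clf = LogisticRegression(max_iter=1000).fit(scaler.transform(X_train), y_train)
    print("Train score:", clf.score(scaler.transform(X_train), y_train))
    print("Test score:", clf.score(scaler.transform(X_test), y_test))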

View File

@ -8,9 +8,10 @@ from torch.utils.data import DataLoader
from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import matplotlib.pyplot as plt
from models.resnet_simclr import ResNetSimCLR
from utils import get_negative_mask, get_augmentation_transform, get_similarity_function
from utils import get_negative_mask, get_similarity_function
from data_aug.data_transform import DataTransform, get_data_transform_opes
torch.manual_seed(0)
@ -21,9 +22,9 @@ out_dim = config['out_dim']
temperature = config['temperature']
use_cosine_similarity = config['use_cosine_similarity']
data_augment = get_augmentation_transform(s=config['s'], crop_size=96)
data_augment = get_data_transform_opes(s=config['s'], crop_size=96)
train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=transforms.ToTensor())
train_dataset = datasets.STL10('./data', split='train', download=True, transform=DataTransform(data_augment))
# train_dataset = datasets.Caltech101(root='./data', target_type="category", transform=transforms.ToTensor(),
# target_transform=None, download=True)
@ -52,20 +53,7 @@ negative_mask = get_negative_mask(batch_size)
n_iter = 0
for e in range(config['epochs']):
    for step, (batch_x, _) in enumerate(train_loader):
        optimizer.zero_grad()

        xis = []
        xjs = []

        # draw two augmentation functions t, t' and apply separately for each input example
        for k in range(len(batch_x)):
            xis.append(data_augment(batch_x[k]))  # the first augmentation
            xjs.append(data_augment(batch_x[k]))  # the second augmentation

        xis = torch.stack(xis)
        xjs = torch.stack(xjs)

    for step, ((xis, xjs), _) in enumerate(train_loader):

        if train_gpu:
            xis = xis.cuda()
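
With this change the two views come directly from the DataLoader instead of being built inside the loop. Since the hunk is truncated here, a hedged sketch of how such a loop typically continues; the model outputs and the loss helper are assumptions, not the committed code:

# Hedged sketch, not the committed code: the DataLoader now yields ((xis, xjs), _)
# because DataTransform produces both views inside the dataset.
for e in range(config['epochs']):
    for step, ((xis, xjs), _) in enumerate(train_loader):
        optimizer.zero_grad()

        if train_gpu:
            xis, xjs = xis.cuda(), xjs.cuda()

        # encode both views with the same network
        # (assumed: ResNetSimCLR returns the representation and the projection)
        his, zis = model(xis)
        hjs, zjs = model(xjs)

        loss = nt_xent_loss(zis, zjs)   # hypothetical contrastive-loss helper
        loss.backward()
        optimizer.step()
        n_iter += 1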

View File

@ -2,6 +2,7 @@ import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
np.random.seed(0)
cos1d = torch.nn.CosineSimilarity(dim=1)
@ -19,40 +20,6 @@ def get_negative_mask(batch_size):
    return negative_mask


class GaussianBlur(object):
    # Implements Gaussian blur as described in the SimCLR paper
    def __init__(self, kernel_size, min=0.1, max=2.0):
        self.min = min
        self.max = max
        # kernel size is set to be 10% of the image height/width
        self.kernel_size = kernel_size

    def __call__(self, sample):
        sample = np.array(sample)

        # blur the image with a 50% chance
        prob = np.random.random_sample()
        if prob < 0.5:
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

        return sample


def get_augmentation_transform(s, crop_size):
    # get a set of data augmentation transformations as described in the SimCLR paper.
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    data_aug_ope = transforms.Compose([transforms.ToPILImage(),
                                       transforms.RandomResizedCrop(size=crop_size),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.RandomApply([color_jitter], p=0.8),
                                       transforms.RandomGrayscale(p=0.2),
                                       GaussianBlur(kernel_size=int(0.1 * crop_size)),
                                       transforms.ToTensor()])
    return data_aug_ope


def _dot_simililarity_dim1(x, y):
    # x shape: (N, 1, C)
    # y shape: (N, C, 1)
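
The diff is cut off at this point. For context, a hedged sketch of what a batched dot-product similarity over those shapes typically computes; the body below is an assumption, not the file's code:

# Hedged sketch: one dot-product similarity per row, matching the shapes noted above.
import torch

def _dot_similarity_dim1_sketch(x, y):
    # x: (N, 1, C), y: (N, C, 1) -> (N, 1, 1)
    return torch.bmm(x, y)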