Support using other PyTorch datasets

pull/5/head
Thalles 2020-02-29 19:32:37 -03:00
parent 2b6bfd9933
commit 88dcdf6d06
4 changed files with 80 additions and 79 deletions

View File: data_aug/data_transform.py (new file)

@@ -0,0 +1,46 @@
import torchvision.transforms as transforms
import cv2
import numpy as np


class DataTransform(object):
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj


class GaussianBlur(object):
    # Implements Gaussian blur as described in the SimCLR paper
    def __init__(self, kernel_size, min=0.1, max=2.0):
        self.min = min
        self.max = max
        # kernel size is set to be 10% of the image height/width
        self.kernel_size = kernel_size

    def __call__(self, sample):
        sample = np.array(sample)

        # blur the image with a 50% chance
        prob = np.random.random_sample()
        if prob < 0.5:
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

        return sample


def get_data_transform_opes(s, crop_size):
    # get a set of data augmentation transformations as described in the SimCLR paper.
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=crop_size),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomApply([color_jitter], p=0.8),
                                          transforms.RandomGrayscale(p=0.2),
                                          GaussianBlur(kernel_size=int(0.1 * crop_size)),
                                          transforms.ToTensor()])
    return data_transforms
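For context, a minimal usage sketch of the new module (not part of the commit): wrapping a torchvision dataset with DataTransform so each sample yields the two augmented views that the training loop below expects. The batch size, DataLoader settings, and s=1.0 are illustrative assumptions standing in for the config values.

# Sketch: pair the SimCLR augmentation pipeline with any torchvision dataset that
# returns PIL images. crop_size=96 matches STL-10; s=1.0 and batch_size are assumptions.
from torch.utils.data import DataLoader
from torchvision import datasets

from data_aug.data_transform import DataTransform, get_data_transform_opes

data_augment = get_data_transform_opes(s=1.0, crop_size=96)
train_dataset = datasets.STL10('./data', split='train', download=True,
                               transform=DataTransform(data_augment))
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, drop_last=True)

(xis, xjs), _ = next(iter(train_loader))
print(xis.shape, xjs.shape)  # two independently augmented views, e.g. [256, 3, 96, 96] each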

View File: feature evaluation notebook (.ipynb)

@@ -2,7 +2,7 @@
 "cells": [
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 17,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -13,7 +13,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 18,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -26,17 +26,17 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 19,
 "metadata": {},
 "outputs": [],
 "source": [
 "batch_size = 256\n",
-"out_dim = 64"
+"out_dim = 128"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 20,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -54,7 +54,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 21,
 "metadata": {},
 "outputs": [
 {
@@ -74,7 +74,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 22,
 "metadata": {},
 "outputs": [
 {
@@ -101,7 +101,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 23,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -112,7 +112,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 24,
 "metadata": {},
 "outputs": [
 {
@@ -120,8 +120,8 @@
 "output_type": "stream",
 "text": [
 "PCA features\n",
-"(5000, 64)\n",
-"(8000, 64)\n"
+"(5000, 128)\n",
+"(8000, 128)\n"
 ]
 }
 ],
@@ -129,7 +129,7 @@
 "scaler = preprocessing.StandardScaler()\n",
 "scaler.fit(X_train.reshape((X_train.shape[0],-1)))\n",
 "\n",
-"pca = PCA(n_components=64)\n",
+"pca = PCA(n_components=128)\n",
 "\n",
 "X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))\n",
 "X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))\n",
@@ -141,7 +141,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 25,
 "metadata": {},
 "outputs": [
 {
@@ -149,8 +149,8 @@
 "output_type": "stream",
 "text": [
 "PCA feature evaluation\n",
-"Train score: 0.396\n",
-"Test score: 0.3565\n"
+"Train score: 0.4306\n",
+"Test score: 0.3625\n"
 ]
 }
 ],
@@ -164,7 +164,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": 26,
 "metadata": {},
 "outputs": [
 {
@@ -185,7 +185,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 27,
 "metadata": {},
 "outputs": [
 {
@@ -203,7 +203,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 28,
 "metadata": {},
 "outputs": [
 {
@@ -295,7 +295,7 @@
 " (8): AdaptiveAvgPool2d(output_size=(1, 1))\n",
 " )\n",
 " (l1): Linear(in_features=512, out_features=512, bias=True)\n",
-" (l2): Linear(in_features=512, out_features=64, bias=True)\n",
+" (l2): Linear(in_features=512, out_features=128, bias=True)\n",
 ")\n",
 "odict_keys(['features.0.weight', 'features.1.weight', 'features.1.bias', 'features.1.running_mean', 'features.1.running_var', 'features.1.num_batches_tracked', 'features.4.0.conv1.weight', 'features.4.0.bn1.weight', 'features.4.0.bn1.bias', 'features.4.0.bn1.running_mean', 'features.4.0.bn1.running_var', 'features.4.0.bn1.num_batches_tracked', 'features.4.0.conv2.weight', 'features.4.0.bn2.weight', 'features.4.0.bn2.bias', 'features.4.0.bn2.running_mean', 'features.4.0.bn2.running_var', 'features.4.0.bn2.num_batches_tracked', 'features.4.1.conv1.weight', 'features.4.1.bn1.weight', 'features.4.1.bn1.bias', 'features.4.1.bn1.running_mean', 'features.4.1.bn1.running_var', 'features.4.1.bn1.num_batches_tracked', 'features.4.1.conv2.weight', 'features.4.1.bn2.weight', 'features.4.1.bn2.bias', 'features.4.1.bn2.running_mean', 'features.4.1.bn2.running_var', 'features.4.1.bn2.num_batches_tracked', 'features.5.0.conv1.weight', 'features.5.0.bn1.weight', 'features.5.0.bn1.bias', 'features.5.0.bn1.running_mean', 'features.5.0.bn1.running_var', 'features.5.0.bn1.num_batches_tracked', 'features.5.0.conv2.weight', 'features.5.0.bn2.weight', 'features.5.0.bn2.bias', 'features.5.0.bn2.running_mean', 'features.5.0.bn2.running_var', 'features.5.0.bn2.num_batches_tracked', 'features.5.0.downsample.0.weight', 'features.5.0.downsample.1.weight', 'features.5.0.downsample.1.bias', 'features.5.0.downsample.1.running_mean', 'features.5.0.downsample.1.running_var', 'features.5.0.downsample.1.num_batches_tracked', 'features.5.1.conv1.weight', 'features.5.1.bn1.weight', 'features.5.1.bn1.bias', 'features.5.1.bn1.running_mean', 'features.5.1.bn1.running_var', 'features.5.1.bn1.num_batches_tracked', 'features.5.1.conv2.weight', 'features.5.1.bn2.weight', 'features.5.1.bn2.bias', 'features.5.1.bn2.running_mean', 'features.5.1.bn2.running_var', 'features.5.1.bn2.num_batches_tracked', 'features.6.0.conv1.weight', 'features.6.0.bn1.weight', 'features.6.0.bn1.bias', 'features.6.0.bn1.running_mean', 'features.6.0.bn1.running_var', 'features.6.0.bn1.num_batches_tracked', 'features.6.0.conv2.weight', 'features.6.0.bn2.weight', 'features.6.0.bn2.bias', 'features.6.0.bn2.running_mean', 'features.6.0.bn2.running_var', 'features.6.0.bn2.num_batches_tracked', 'features.6.0.downsample.0.weight', 'features.6.0.downsample.1.weight', 'features.6.0.downsample.1.bias', 'features.6.0.downsample.1.running_mean', 'features.6.0.downsample.1.running_var', 'features.6.0.downsample.1.num_batches_tracked', 'features.6.1.conv1.weight', 'features.6.1.bn1.weight', 'features.6.1.bn1.bias', 'features.6.1.bn1.running_mean', 'features.6.1.bn1.running_var', 'features.6.1.bn1.num_batches_tracked', 'features.6.1.conv2.weight', 'features.6.1.bn2.weight', 'features.6.1.bn2.bias', 'features.6.1.bn2.running_mean', 'features.6.1.bn2.running_var', 'features.6.1.bn2.num_batches_tracked', 'features.7.0.conv1.weight', 'features.7.0.bn1.weight', 'features.7.0.bn1.bias', 'features.7.0.bn1.running_mean', 'features.7.0.bn1.running_var', 'features.7.0.bn1.num_batches_tracked', 'features.7.0.conv2.weight', 'features.7.0.bn2.weight', 'features.7.0.bn2.bias', 'features.7.0.bn2.running_mean', 'features.7.0.bn2.running_var', 'features.7.0.bn2.num_batches_tracked', 'features.7.0.downsample.0.weight', 'features.7.0.downsample.1.weight', 'features.7.0.downsample.1.bias', 'features.7.0.downsample.1.running_mean', 'features.7.0.downsample.1.running_var', 'features.7.0.downsample.1.num_batches_tracked', 'features.7.1.conv1.weight', 'features.7.1.bn1.weight', 'features.7.1.bn1.bias', 'features.7.1.bn1.running_mean', 'features.7.1.bn1.running_var', 'features.7.1.bn1.num_batches_tracked', 'features.7.1.conv2.weight', 'features.7.1.bn2.weight', 'features.7.1.bn2.bias', 'features.7.1.bn2.running_mean', 'features.7.1.bn2.running_var', 'features.7.1.bn2.num_batches_tracked', 'l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'])\n"
 ]
@@ -306,7 +306,7 @@
 "<All keys matched successfully>"
 ]
 },
-"execution_count": 12,
+"execution_count": 28,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -331,7 +331,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 13,
+"execution_count": 29,
 "metadata": {},
 "outputs": [
 {
@@ -358,7 +358,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 14,
+"execution_count": 30,
 "metadata": {},
 "outputs": [
 {
@@ -385,7 +385,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 15,
+"execution_count": 31,
 "metadata": {},
 "outputs": [
 {
@@ -398,7 +398,7 @@
 " warm_start=False)"
 ]
 },
-"execution_count": 15,
+"execution_count": 31,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -414,7 +414,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 16,
+"execution_count": 32,
 "metadata": {},
 "outputs": [
 {
@@ -422,8 +422,8 @@
 "output_type": "stream",
 "text": [
 "SimCLR feature evaluation\n",
-"Train score: 0.8948\n",
-"Test score: 0.639625\n"
+"Train score: 0.8914\n",
+"Test score: 0.6425\n"
 ]
 }
 ],
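The notebook compares a PCA-on-pixels baseline against the frozen SimCLR features using a linear classifier; only diff context is visible above, so the following is a hedged sketch of the baseline step. The data-loading code is not shown in this diff, and classifier settings such as max_iter are assumptions.

# Sketch of the linear-evaluation baseline reported above: standardize raw pixels,
# project to 128 PCA components (matching the new out_dim), fit a logistic regression.
from sklearn import preprocessing
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

def evaluate_pca_baseline(X_train, y_train, X_test, y_test, n_components=128):
    scaler = preprocessing.StandardScaler()
    scaler.fit(X_train.reshape((X_train.shape[0], -1)))

    pca = PCA(n_components=n_components)
    X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))
    X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))

    clf = LogisticRegression(max_iter=1000)  # max_iter is an assumption, not from the notebook
    clf.fit(X_train_pca, y_train)
    print("Train score:", clf.score(X_train_pca, y_train))
    print("Test score:", clf.score(X_test_pca, y_test))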

View File: training script

@@ -8,9 +8,10 @@ from torch.utils.data import DataLoader
 from torchvision import datasets
 from torch.utils.tensorboard import SummaryWriter
 import torch.nn.functional as F
+import matplotlib.pyplot as plt
 from models.resnet_simclr import ResNetSimCLR
-from utils import get_negative_mask, get_augmentation_transform, get_similarity_function
+from utils import get_negative_mask, get_similarity_function
+from data_aug.data_transform import DataTransform, get_data_transform_opes

 torch.manual_seed(0)
@@ -21,9 +22,9 @@ out_dim = config['out_dim']
 temperature = config['temperature']
 use_cosine_similarity = config['use_cosine_similarity']

-data_augment = get_augmentation_transform(s=config['s'], crop_size=96)
+data_augment = get_data_transform_opes(s=config['s'], crop_size=96)

-train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=transforms.ToTensor())
+train_dataset = datasets.STL10('./data', split='train', download=True, transform=DataTransform(data_augment))

 # train_dataset = datasets.Caltech101(root='./data', target_type="category", transform=transforms.ToTensor(),
 #                                     target_transform=None, download=True)
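Since the commit's purpose is supporting other PyTorch datasets, a short hedged sketch of the swap it enables; CIFAR-10 and its crop size are illustrative examples, not part of the diff.

# Sketch: any torchvision dataset that yields PIL images can now be dropped in,
# because DataTransform produces the two views instead of the training loop.
data_augment = get_data_transform_opes(s=config['s'], crop_size=32)  # 32 for CIFAR-10 images
train_dataset = datasets.CIFAR10('./data', train=True, download=True,
                                 transform=DataTransform(data_augment))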
@@ -52,20 +53,7 @@ negative_mask = get_negative_mask(batch_size)
 n_iter = 0
 for e in range(config['epochs']):
-    for step, (batch_x, _) in enumerate(train_loader):
-        optimizer.zero_grad()
-
-        xis = []
-        xjs = []
-
-        # draw two augmentation functions t , t' and apply separately for each input example
-        for k in range(len(batch_x)):
-            xis.append(data_augment(batch_x[k]))  # the first augmentation
-            xjs.append(data_augment(batch_x[k]))  # the second augmentation
-
-        xis = torch.stack(xis)
-        xjs = torch.stack(xjs)
-
+    for step, ((xis, xjs), _) in enumerate(train_loader):
         if train_gpu:
             xis = xis.cuda()
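For orientation, a hedged sketch of how the simplified loop consumes the paired views. Model and optimizer setup and the contrastive-loss computation are unchanged by this commit and only summarized; nt_xent_loss is a hypothetical stand-in, not a function from this repository, and the zero_grad placement and the (h, z) return signature of ResNetSimCLR are assumptions.

# Sketch only: the dataset now hands the two augmented views to the loop directly.
for e in range(config['epochs']):
    for step, ((xis, xjs), _) in enumerate(train_loader):
        optimizer.zero_grad()  # assumed to be called here now that the old augmentation block is gone

        if train_gpu:
            xis, xjs = xis.cuda(), xjs.cuda()

        # encode both views with the same network; ResNetSimCLR is assumed to return
        # (representation h, projection z) for each batch
        his, zis = model(xis)
        hjs, zjs = model(xjs)

        loss = nt_xent_loss(zis, zjs, temperature)  # hypothetical helper for the contrastive loss
        loss.backward()
        optimizer.step()
        n_iter += 1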

View File: utils.py

@@ -2,6 +2,7 @@ import cv2
 import numpy as np
 import torch
 import torchvision.transforms as transforms
+import matplotlib.pyplot as plt

 np.random.seed(0)
 cos1d = torch.nn.CosineSimilarity(dim=1)
@@ -19,40 +20,6 @@ def get_negative_mask(batch_size):
     return negative_mask

-
-class GaussianBlur(object):
-    # Implements Gaussian blur as described in the SimCLR paper
-    def __init__(self, kernel_size, min=0.1, max=2.0):
-        self.min = min
-        self.max = max
-        # kernel size is set to be 10% of the image height/width
-        self.kernel_size = kernel_size
-
-    def __call__(self, sample):
-        sample = np.array(sample)
-
-        # blur the image with a 50% chance
-        prob = np.random.random_sample()
-        if prob < 0.5:
-            sigma = (self.max - self.min) * np.random.random_sample() + self.min
-            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
-
-        return sample
-
-
-def get_augmentation_transform(s, crop_size):
-    # get a set of data augmentation transformations as described in the SimCLR paper.
-    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
-    data_aug_ope = transforms.Compose([transforms.ToPILImage(),
-                                       transforms.RandomResizedCrop(size=crop_size),
-                                       transforms.RandomHorizontalFlip(),
-                                       transforms.RandomApply([color_jitter], p=0.8),
-                                       transforms.RandomGrayscale(p=0.2),
-                                       GaussianBlur(kernel_size=int(0.1 * crop_size)),
-                                       transforms.ToTensor()])
-    return data_aug_ope
-

 def _dot_simililarity_dim1(x, y):
     # x shape: (N, 1, C)
     # y shape: (N, C, 1)
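The hunk ends inside the dot-product similarity helpers, whose bodies are not shown in this diff. The following is only a sketch of how similarities with the annotated shapes can be computed; the cosine variants would instead use the CosineSimilarity modules declared at the top of utils.py.

# Sketch (not from the diff): dot-product similarity for the shapes noted above.
import torch

def dot_similarity_dim1(x, y):
    # x: (N, 1, C), y: (N, C, 1) -> (N, 1, 1): one score per positive pair
    return torch.bmm(x, y)

def dot_similarity_dim2(x, y):
    # x: (N, C), y: (M, C) -> (N, M): every sample scored against every other
    return torch.matmul(x, y.t())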