mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
added tensorboard support
This commit is contained in:
parent
304376b6a8
commit
5ed8a2e453
@ -7,7 +7,17 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from model import Encoder, ResNet18\n",
|
||||
"import sys\n",
|
||||
"sys.path.insert(1, '../')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from models.resnet_simclr import ResNetSimCLR\n",
|
||||
"import torchvision.transforms as transforms\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"from torchvision import datasets\n",
|
||||
@ -16,7 +26,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -26,54 +36,60 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _load_stl10(prefix=\"train\"):\n",
|
||||
" X_train = np.fromfile('../data/stl10_binary/' + prefix + '_X.bin', dtype=np.uint8)\n",
|
||||
" y_train = np.fromfile('../data/stl10_binary/' + prefix + '_y.bin', dtype=np.uint8)\n",
|
||||
"\n",
|
||||
" X_train = np.reshape(X_train, (-1, 3, 96, 96))\n",
|
||||
" X_train = np.transpose(X_train, (0, 3, 2, 1))\n",
|
||||
" print(\"{} images\".format(prefix))\n",
|
||||
" print(X_train.shape)\n",
|
||||
" print(y_train.shape)\n",
|
||||
" return X_train, y_train"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Training images\n",
|
||||
"train images\n",
|
||||
"(5000, 96, 96, 3)\n",
|
||||
"(5000,)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"X_train = np.fromfile('data/stl10_binary/train_X.bin', dtype=np.uint8)\n",
|
||||
"y_train = np.fromfile('data/stl10_binary/train_y.bin', dtype=np.uint8)\n",
|
||||
"\n",
|
||||
"X_train = np.reshape(X_train, (-1, 3, 96, 96))\n",
|
||||
"X_train = np.transpose(X_train, (0, 3, 2, 1))\n",
|
||||
"print(\"Training images\")\n",
|
||||
"print(X_train.shape)\n",
|
||||
"print(y_train.shape)"
|
||||
"# load STL-10 train data\n",
|
||||
"X_train, y_train = _load_stl10(\"train\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test images\n",
|
||||
"test images\n",
|
||||
"(8000, 96, 96, 3)\n",
|
||||
"(8000,)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"X_test = np.fromfile('data/stl10_binary/test_X.bin', dtype=np.uint8)\n",
|
||||
"y_test = np.fromfile('data/stl10_binary/test_y.bin', dtype=np.uint8)\n",
|
||||
"\n",
|
||||
"X_test = np.reshape(X_test, (-1, 3, 96, 96))\n",
|
||||
"X_test = np.transpose(X_test, (0, 3, 2, 1))\n",
|
||||
"print(\"Test images\")\n",
|
||||
"print(X_test.shape)\n",
|
||||
"print(y_test.shape)"
|
||||
"# load STL-10 test data\n",
|
||||
"X_test, y_test = _load_stl10(\"test\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -85,17 +101,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.decomposition import PCA\n",
|
||||
"from sklearn.linear_model import LogisticRegression"
|
||||
"from sklearn.linear_model import LogisticRegression\n",
|
||||
"from sklearn import preprocessing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -109,9 +126,13 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pca = PCA(n_components=out_dim)\n",
|
||||
"X_train_pca = pca.fit_transform(X_train.reshape((X_train.shape[0],-1)))\n",
|
||||
"X_test_pca = pca.transform(X_test.reshape((X_test.shape[0],-1)))\n",
|
||||
"scaler = preprocessing.StandardScaler()\n",
|
||||
"scaler.fit(X_train.reshape((X_train.shape[0],-1)))\n",
|
||||
"\n",
|
||||
"pca = PCA(n_components=64)\n",
|
||||
"\n",
|
||||
"X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))\n",
|
||||
"X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))\n",
|
||||
"\n",
|
||||
"print(\"PCA features\")\n",
|
||||
"print(X_train_pca.shape)\n",
|
||||
@ -120,7 +141,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -128,27 +149,14 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"PCA feature evaluation\n",
|
||||
"Train score: 0.3984\n",
|
||||
"Test score: 0.353125\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/thalles/anaconda3/envs/pytorch/lib/python3.6/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
||||
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
||||
"\n",
|
||||
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
||||
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
||||
"Please also refer to the documentation for alternative solver options:\n",
|
||||
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
||||
" extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"
|
||||
"Train score: 0.396\n",
|
||||
"Test score: 0.3565\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"clf = LogisticRegression(random_state=0).fit(X_train_pca, y_train)\n",
|
||||
"clf = LogisticRegression(random_state=0, max_iter=1000, solver='lbfgs', C=1.0)\n",
|
||||
"clf.fit(X_train_pca, y_train)\n",
|
||||
"print(\"PCA feature evaluation\")\n",
|
||||
"print(\"Train score:\", clf.score(X_train_pca, y_train))\n",
|
||||
"print(\"Test score:\", clf.score(X_test_pca, y_test))"
|
||||
@ -156,7 +164,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -171,13 +179,13 @@
|
||||
"data_augment = transforms.Compose([transforms.RandomResizedCrop(96),\n",
|
||||
" transforms.ToTensor()])\n",
|
||||
"\n",
|
||||
"train_dataset = datasets.STL10('data', split='train', download=True, transform=data_augment)\n",
|
||||
"train_dataset = datasets.STL10('../data', split='train', download=True, transform=data_augment)\n",
|
||||
"train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, drop_last=False, shuffle=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -189,20 +197,20 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"test_dataset = datasets.STL10('data', split='test', download=True, transform=data_augment)\n",
|
||||
"test_dataset = datasets.STL10('../data', split='test', download=True, transform=data_augment)\n",
|
||||
"test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=1, drop_last=False, shuffle=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ResNet18(\n",
|
||||
"ResNetSimCLR(\n",
|
||||
" (features): Sequential(\n",
|
||||
" (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
|
||||
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
@ -298,17 +306,17 @@
|
||||
"<All keys matched successfully>"
|
||||
]
|
||||
},
|
||||
"execution_count": 23,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = ResNet18(out_dim=out_dim)\n",
|
||||
"model = ResNetSimCLR(out_dim=out_dim)\n",
|
||||
"model.eval()\n",
|
||||
"print(model)\n",
|
||||
"\n",
|
||||
"state_dict = torch.load('model/checkpoint.pth')\n",
|
||||
"state_dict = torch.load('../checkpoints/checkpoint.pth')\n",
|
||||
"print(state_dict.keys())\n",
|
||||
"\n",
|
||||
"model.load_state_dict(state_dict)"
|
||||
@ -323,7 +331,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -350,7 +358,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -377,32 +385,36 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/thalles/anaconda3/envs/pytorch/lib/python3.6/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
||||
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
||||
"\n",
|
||||
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
||||
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
||||
"Please also refer to the documentation for alternative solver options:\n",
|
||||
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
||||
" extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"
|
||||
]
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
|
||||
" intercept_scaling=1, l1_ratio=None, max_iter=1000,\n",
|
||||
" multi_class='auto', n_jobs=None, penalty='l2',\n",
|
||||
" random_state=0, solver='lbfgs', tol=0.0001, verbose=0,\n",
|
||||
" warm_start=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn.linear_model import LogisticRegression\n",
|
||||
"clf = LogisticRegression(random_state=0).fit(X_train_feature, y_train)"
|
||||
"clf = LogisticRegression(random_state=0, max_iter=1000, solver='lbfgs', C=1.0)\n",
|
||||
"\n",
|
||||
"scaler = preprocessing.StandardScaler()\n",
|
||||
"scaler.fit(X_train_feature)\n",
|
||||
"\n",
|
||||
"clf.fit(scaler.transform(X_train_feature), y_train)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -410,19 +422,30 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"SimCLR feature evaluation\n",
|
||||
"Train score: 0.7852\n",
|
||||
"Test score: 0.641625\n"
|
||||
"Train score: 0.8948\n",
|
||||
"Test score: 0.639625\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\"SimCLR feature evaluation\")\n",
|
||||
"print(\"Train score:\", clf.score(X_train_feature, y_train))\n",
|
||||
"print(\"Test score:\", clf.score(X_test_feature, y_test))\n",
|
||||
"# SimCLR feature evaluation\n",
|
||||
"# Train score: 0.7852\n",
|
||||
"# Test score: 0.641625"
|
||||
"print(\"Train score:\", clf.score(scaler.transform(X_train_feature), y_train))\n",
|
||||
"print(\"Test score:\", clf.score(scaler.transform(X_test_feature), y_test))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
Binary file not shown.
@ -41,35 +41,3 @@ class Encoder(nn.Module):
|
||||
x = self.l2(x)
|
||||
|
||||
return h, x
|
||||
|
||||
|
||||
class ResNetSimCLR(nn.Module):
    """ResNet feature extractor followed by a two-layer projection MLP.

    The backbone is a torchvision ResNet with its last child module removed
    (``list(resnet.children())[:-1]``); the projection head maps the backbone
    features to ``out_dim`` dimensions through ``Linear -> ReLU -> Linear``.

    Args:
        base_model: Key selecting the backbone, ``"resnet18"`` or
            ``"resnet50"``.
        out_dim: Size of the projection head output.

    Raises:
        ValueError: If ``base_model`` is not a supported backbone name.
    """

    def __init__(self, base_model="resnet18", out_dim=64):
        super(ResNetSimCLR, self).__init__()
        # NOTE: both backbones are instantiated here although only one is
        # used; kept as-is so the parameter initialisation order (and thus
        # seeded runs) is unchanged.
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
                            "resnet50": models.resnet50(pretrained=False)}

        resnet = self._get_basemodel(base_model)
        # Input width of the discarded ``fc`` layer == backbone feature size.
        num_ftrs = resnet.fc.in_features

        self.features = nn.Sequential(*list(resnet.children())[:-1])

        # projection MLP
        self.l1 = nn.Linear(num_ftrs, num_ftrs)
        self.l2 = nn.Linear(num_ftrs, out_dim)

    def _get_basemodel(self, model_name):
        """Return the backbone registered under ``model_name``.

        Raises:
            ValueError: If ``model_name`` is not a key of ``resnet_dict``.
        """
        try:
            return self.resnet_dict[model_name]
        except KeyError:
            # The original raised a bare string, which is itself a TypeError
            # in Python 3; raise a real exception carrying the same message.
            raise ValueError(
                "Invalid model name. Check the config file and pass one of: resnet18 or resnet50"
            ) from None

    def forward(self, x):
        """Return ``(h, z)``: backbone features and their projection.

        ``h`` always keeps the leading batch dimension, including for a
        batch of one.
        """
        h = self.features(x)
        # Flatten everything after the batch axis. ``squeeze()`` would also
        # drop the batch axis when the batch size is 1.
        h = h.flatten(start_dim=1)

        x = self.l1(h)
        x = F.relu(x)
        x = self.l2(x)
        return h, x
|
36
models/resnet_simclr.py
Normal file
36
models/resnet_simclr.py
Normal file
@ -0,0 +1,36 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as models
|
||||
|
||||
|
||||
class ResNetSimCLR(nn.Module):
    """ResNet feature extractor followed by a two-layer projection MLP.

    The backbone is a torchvision ResNet with its last child module removed
    (``list(resnet.children())[:-1]``); the projection head maps the backbone
    features to ``out_dim`` dimensions through ``Linear -> ReLU -> Linear``.

    Args:
        base_model: Key selecting the backbone, ``"resnet18"`` or
            ``"resnet50"``.
        out_dim: Size of the projection head output.

    Raises:
        ValueError: If ``base_model`` is not a supported backbone name.
    """

    def __init__(self, base_model="resnet18", out_dim=64):
        super(ResNetSimCLR, self).__init__()
        # NOTE: both backbones are instantiated here although only one is
        # used; kept as-is so the parameter initialisation order (and thus
        # seeded runs) is unchanged.
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
                            "resnet50": models.resnet50(pretrained=False)}

        resnet = self._get_basemodel(base_model)
        # Input width of the discarded ``fc`` layer == backbone feature size.
        num_ftrs = resnet.fc.in_features

        self.features = nn.Sequential(*list(resnet.children())[:-1])

        # projection MLP
        self.l1 = nn.Linear(num_ftrs, num_ftrs)
        self.l2 = nn.Linear(num_ftrs, out_dim)

    def _get_basemodel(self, model_name):
        """Return the backbone registered under ``model_name``.

        Raises:
            ValueError: If ``model_name`` is not a key of ``resnet_dict``.
        """
        try:
            return self.resnet_dict[model_name]
        except KeyError:
            # The original raised a bare string, which is itself a TypeError
            # in Python 3; raise a real exception carrying the same message.
            raise ValueError(
                "Invalid model name. Check the config file and pass one of: resnet18 or resnet50"
            ) from None

    def forward(self, x):
        """Return ``(h, z)``: backbone features and their projection.

        ``h`` always keeps the leading batch dimension, including for a
        batch of one.
        """
        h = self.features(x)
        # Flatten everything after the batch axis. ``squeeze()`` would also
        # drop the batch axis when the batch size is 1.
        h = h.flatten(start_dim=1)

        x = self.l1(h)
        x = F.relu(x)
        x = self.l2(x)
        return h, x
|
21
train.py
21
train.py
@ -9,8 +9,9 @@ from torchvision import datasets
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import ResNetSimCLR
|
||||
from utils import GaussianBlur
|
||||
from models.baseline_encoder import Encoder
|
||||
from models.resnet_simclr import ResNetSimCLR
|
||||
from utils import GaussianBlur, get_negative_mask
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@ -52,14 +53,11 @@ optimizer = optim.Adam(model.parameters(), 3e-4)
|
||||
train_writer = SummaryWriter()
|
||||
|
||||
if use_cosine_similarity:
|
||||
similarity_dim1 = torch.nn.CosineSimilarity(dim=1)
|
||||
similarity_dim2 = torch.nn.CosineSimilarity(dim=2)
|
||||
cos_similarity_dim1 = torch.nn.CosineSimilarity(dim=1)
|
||||
cos_similarity_dim2 = torch.nn.CosineSimilarity(dim=2)
|
||||
|
||||
# Mask to remove positive examples from the batch of negative samples
|
||||
negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool)
|
||||
for i in range(batch_size):
|
||||
negative_mask[i, i] = 0
|
||||
negative_mask[i, i + batch_size] = 0
|
||||
negative_mask = get_negative_mask(batch_size)
|
||||
|
||||
n_iter = 0
|
||||
for e in range(config['epochs']):
|
||||
@ -96,7 +94,8 @@ for e in range(config['epochs']):
|
||||
|
||||
# positive pairs
|
||||
if use_cosine_similarity:
|
||||
l_pos = similarity_dim1(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size, 1)
|
||||
l_pos = cos_similarity_dim1(zis.view(batch_size, out_dim), zjs.view(batch_size, out_dim)).view(batch_size,
|
||||
1)
|
||||
else:
|
||||
l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
|
||||
|
||||
@ -111,7 +110,7 @@ for e in range(config['epochs']):
|
||||
|
||||
if use_cosine_similarity:
|
||||
negatives = negatives.view(1, (2 * batch_size), out_dim)
|
||||
l_neg = similarity_dim2(positives.view(batch_size, 1, out_dim), negatives)
|
||||
l_neg = cos_similarity_dim2(positives.view(batch_size, 1, out_dim), negatives)
|
||||
else:
|
||||
l_neg = torch.tensordot(positives.view(batch_size, 1, out_dim),
|
||||
negatives.T.view(1, out_dim, (2 * batch_size)),
|
||||
@ -137,4 +136,4 @@ for e in range(config['epochs']):
|
||||
n_iter += 1
|
||||
# print("Step {}, Loss {}".format(step, loss))
|
||||
|
||||
torch.save(model.state_dict(), './model/checkpoint.pth')
|
||||
torch.save(model.state_dict(), './checkpoints/checkpoint.pth')
|
||||
|
9
utils.py
9
utils.py
@ -1,9 +1,18 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
def get_negative_mask(batch_size):
    """Build the boolean mask that selects negatives for each sample.

    Args:
        batch_size: Number of samples per view.

    Returns:
        A ``torch.bool`` tensor of shape ``(batch_size, 2 * batch_size)``.
        Row ``i`` is ``True`` everywhere except at columns ``i`` and
        ``i + batch_size``, which are ``False``.
    """
    negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=torch.bool)
    rows = torch.arange(batch_size)
    # Clear both excluded columns of every row in two vectorised writes
    # instead of 2 * batch_size scalar assignments in a Python loop.
    negative_mask[rows, rows] = False
    negative_mask[rows, rows + batch_size] = False
    return negative_mask
|
||||
|
||||
|
||||
class GaussianBlur(object):
|
||||
|
||||
def __init__(self, min=0.1, max=2.0, kernel_size=9):
|
||||
|
Loading…
x
Reference in New Issue
Block a user