mirror of https://github.com/sthalles/SimCLR.git
first commit
commit
8d28ca5d31
|
@ -0,0 +1,42 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Small CNN backbone with a 2-layer projection head (SimCLR-style).

    ``forward`` returns ``(h, z)``: ``h`` is the globally-pooled backbone
    feature (64-d), ``z`` is the projection of ``h`` (``out_dim``-d) used
    for the contrastive loss.
    """

    def __init__(self, out_dim=64):
        super(Encoder, self).__init__()
        # Four 3x3 conv stages: 3 -> 16 -> 32 -> 64 -> 64 channels.
        # (Creation order preserved so parameter init / state_dict match.)
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # projection MLP: Linear -> ReLU -> Linear
        self.l1 = nn.Linear(64, 64)
        self.l2 = nn.Linear(64, out_dim)

    def forward(self, x):
        # Backbone: each stage is conv -> ReLU -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(F.relu(conv(x)))

        # Global average pooling over the spatial dimensions -> [N, 64].
        h = torch.mean(x, dim=[2, 3])

        # Projection head applied to the pooled feature.
        z = self.l2(F.relu(self.l1(h)))
        return h, z
|
Binary file not shown.
|
@ -0,0 +1,97 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import datasets
|
||||
|
||||
from model import Encoder
|
||||
from utils import GaussianBlur
|
||||
|
||||
import os  # needed for checkpoint-directory creation below

# ---------------- hyper-parameters ----------------
batch_size = 32
out_dim = 64
s = 1  # colour-jitter strength multiplier (the "s" from the SimCLR paper)

color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)

# SimCLR augmentation pipeline. Input tensors come from the dataset's
# ToTensor transform, hence the leading ToPILImage.
data_augment = transforms.Compose([transforms.ToPILImage(),
                                   transforms.RandomResizedCrop(96),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.RandomApply([color_jitter], p=0.8),
                                   transforms.RandomGrayscale(p=0.2),
                                   GaussianBlur(),
                                   transforms.ToTensor()])

train_dataset = datasets.STL10('data', download=True, transform=transforms.ToTensor())
# drop_last=True keeps every batch exactly batch_size, which the shape
# asserts on the loss terms below rely on.
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, drop_last=True, shuffle=True)

model = Encoder(out_dim=out_dim)
print(model)

train_gpu = False  ## torch.cuda.is_available()
print("Is gpu available:", train_gpu)
# moves the model parameters to gpu
if train_gpu:
    model.cuda()

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), 3e-4)

for e in range(20):
    for step, (batch_x, _) in enumerate(train_loader):
        optimizer.zero_grad()

        # Build two independently-augmented views of every image.
        xis = []
        xjs = []
        for k in range(len(batch_x)):
            xis.append(data_augment(batch_x[k]))
            xjs.append(data_augment(batch_x[k]))

        xis = torch.stack(xis)
        xjs = torch.stack(xjs)

        # BUG FIX: the inputs must live on the same device as the model;
        # previously only the model (and labels) were moved to the GPU,
        # which crashed with a device mismatch when train_gpu was True.
        if train_gpu:
            xis = xis.cuda()
            xjs = xjs.cuda()

        _, zis = model(xis)  # [N, C]
        _, zjs = model(xjs)  # [N, C]

        # NOTE(review): the reference SimCLR loss L2-normalizes z and divides
        # the logits by a temperature; this early version uses raw dot
        # products (temperature implicitly 1) — confirm before comparing
        # results with the paper.

        # Positive pairs: row-wise dot product <zis_i, zjs_i> -> [N, 1].
        l_pos = torch.bmm(zis.view(batch_size, 1, out_dim),
                          zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
        assert l_pos.shape == (batch_size, 1)  # [N,1]

        # Negatives: for each anchor i, every other sample from both views.
        l_neg = []
        for i in range(zis.shape[0]):
            mask = np.ones(zjs.shape[0], dtype=bool)
            mask[i] = False
            negs = torch.cat([zjs[mask], zis[mask]], dim=0)  # [2*(N-1), C]
            l_neg.append(torch.mm(zis[i].view(1, zis.shape[-1]), negs.permute(1, 0)))

        l_neg = torch.cat(l_neg)  # [N, 2*(N-1)]
        assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), \
            "Shape of negatives not expected." + str(l_neg.shape)

        # Positive logit first, so the correct class index is always 0.
        logits = torch.cat([l_pos, l_neg], dim=1)  # [N,K+1]
        labels = torch.zeros(batch_size, dtype=torch.long)
        if train_gpu:
            labels = labels.cuda()

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # .item() extracts the Python scalar instead of printing a
        # graph-attached tensor repr.
        print("Step {}, Loss {}".format(step, loss.item()))

# Make sure the checkpoint directory exists before saving; torch.save does
# not create missing parent directories.
os.makedirs('./model', exist_ok=True)
torch.save(model.state_dict(), './model/checkpoint.pth')
|
|
@ -0,0 +1,26 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GaussianBlur(object):
    """Blur a PIL image / ndarray with a Gaussian kernel of random sigma.

    Sigma is drawn uniformly from [min, max) on every call; the kernel is
    square with side ``kernel_size``.
    """

    def __init__(self, min=0.1, max=2.0, kernel_size=9):
        # NOTE: `min`/`max` shadow the builtins, but the names are kept for
        # backward compatibility with keyword callers.
        self.kernel_size = kernel_size
        self.min = min
        self.max = max

    def __call__(self, sample):
        img = np.array(sample)
        # Random sigma in [min, max); expression kept identical to preserve
        # the consumed random stream.
        sigma = (self.max - self.min) * np.random.random_sample() + self.min
        ksize = (self.kernel_size, self.kernel_size)
        return cv2.GaussianBlur(img, ksize, sigma)
|
||||
|
||||
# class ToTensor(object):
|
||||
# """Convert ndarrays in sample to Tensors."""
|
||||
#
|
||||
# def __call__(self, sample):
|
||||
# # swap color axis because
|
||||
# # numpy image: H x W x C
|
||||
# # torch image: C X H X W
|
||||
# sample = sample.transpose((2, 0, 1))
|
||||
# return torch.tensor(sample)
|
Loading…
Reference in New Issue