# SimCLR/train.py
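# Minimal SimCLR-style contrastive pre-training on STL-10: each image is
# augmented twice, both views are encoded by the shared Encoder, and a
# cross-entropy loss over [positive, negatives] similarity logits pulls the
# two views of the same image together.
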
import numpy as np
import torch
print(torch.__version__)
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter
from model import Encoder
from utils import GaussianBlur
batch_size = 64
out_dim = 64
s = 1
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
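
# SimCLR augmentation pipeline: random resized crop, horizontal flip,
# color jitter (strength scaled by s), random grayscale, and Gaussian blur.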
data_augment = transforms.Compose([transforms.ToPILImage(),
                                   transforms.RandomResizedCrop(96),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.RandomApply([color_jitter], p=0.8),
                                   transforms.RandomGrayscale(p=0.2),
                                   GaussianBlur(),
                                   transforms.ToTensor()])
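
# STL-10 train+unlabeled split; only ToTensor is applied at load time, the
# SimCLR augmentations above are applied per image inside the training loop.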
train_dataset = datasets.STL10('data', split='train+unlabeled', download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, drop_last=True, shuffle=True)
model = Encoder(out_dim=out_dim)
print(model)

train_gpu = torch.cuda.is_available()
print("Is gpu available:", train_gpu)
# moves the model parameters to gpu
if train_gpu:
    model.cuda()
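
# cross-entropy over the [positive, negatives] similarity logits built in the
# training loop below serves as the contrastive objective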
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), 3e-4)

train_writer = SummaryWriter()
n_iter = 0
for e in range(40):
    for step, (batch_x, _) in enumerate(train_loader):
        # print("Input batch:", batch_x.shape, torch.min(batch_x), torch.max(batch_x))
        optimizer.zero_grad()
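        # build two independently augmented views of every image in the batch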
        xis = []
        xjs = []
        for k in range(len(batch_x)):
            xis.append(data_augment(batch_x[k]))
            xjs.append(data_augment(batch_x[k]))
        # fig, axs = plt.subplots(nrows=1, ncols=6, constrained_layout=False)
        # fig, axs = plt.subplots(nrows=3, ncols=2, constrained_layout=False)
        # for i_ in range(3):
        #     axs[i_, 0].imshow(xis[i_].permute(1, 2, 0))
        #     axs[i_, 1].imshow(xjs[i_].permute(1, 2, 0))
        # plt.show()
        xis = torch.stack(xis)
        xjs = torch.stack(xjs)

        if train_gpu:
            xis = xis.cuda()
            xjs = xjs.cuda()

        # print("Transformed input stats:", torch.min(xis), torch.max(xjs))

        ris, zis = model(xis)  # [N,C]
        train_writer.add_histogram("xi_repr", ris, global_step=n_iter)
        train_writer.add_histogram("xi_latent", zis, global_step=n_iter)
        # print(ris.shape, zis.shape)

        rjs, zjs = model(xjs)  # [N,C]
        train_writer.add_histogram("xj_repr", rjs, global_step=n_iter)
        train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)
        # print(rjs.shape, zjs.shape)
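
        # similarity scores are raw dot products between the projections
        # zis/zjs; no temperature scaling is applied here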
        # positive pairs
        l_pos = torch.bmm(zis.view(batch_size, 1, out_dim), zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
        assert l_pos.shape == (batch_size, 1)  # [N,1]
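        # negatives: for each anchor, every other sample in the batch from
        # both augmented views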
        l_neg = []
        for i in range(zis.shape[0]):
            mask = np.ones(zjs.shape[0], dtype=bool)
            mask[i] = False
            negs = torch.cat([zjs[mask], zis[mask]], dim=0)  # [2*(N-1), C]
            l_neg.append(torch.mm(zis[i].view(1, zis.shape[-1]), negs.permute(1, 0)))
        l_neg = torch.cat(l_neg)  # [N, 2*(N-1)]
        assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected: " + str(l_neg.shape)
        # print("l_neg.shape -->", l_neg.shape)

        logits = torch.cat([l_pos, l_neg], dim=1)  # [N, K+1]
        # print("logits.shape -->", logits.shape)
        labels = torch.zeros(batch_size, dtype=torch.long)
        if train_gpu:
            labels = labels.cuda()
        loss = criterion(logits, labels)

        train_writer.add_scalar('loss', loss, global_step=n_iter)

        loss.backward()
        optimizer.step()

        n_iter += 1
        # print("Step {}, Loss {}".format(step, loss))

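# persist the trained encoder weights (torch.save does not create the
# ./model directory, so it must already exist)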
torch.save(model.state_dict(), './model/checkpoint.pth')