import os

import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import yaml
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets

from models.resnet_simclr import ResNetSimCLR
from utils import get_negative_mask, get_augmentation_transform

print(torch.__version__)

# Fix the RNG seed so runs are reproducible.
torch.manual_seed(0)

# Read the training hyperparameters from config.yaml.
with open("config.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
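
# Keys this script reads from config.yaml: batch_size, out_dim, temperature,
# use_cosine_similarity, s, num_workers, epochs, base_convnet.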

batch_size = config['batch_size']
out_dim = config['out_dim']
temperature = config['temperature']
use_cosine_similarity = config['use_cosine_similarity']

# SimCLR-style data augmentation; s is assumed to scale the strength of the
# color distortion, as in the SimCLR paper.
data_augment = get_augmentation_transform(s=config['s'])

train_dataset = datasets.STL10('./data', split='train', download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=config['num_workers'],
                          drop_last=True, shuffle=True)
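
# drop_last=True keeps every batch at exactly batch_size examples, which the
# fixed-size negative_mask and the .view calls in the loss computation require.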

# model = Encoder(out_dim=out_dim)
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
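
# base_convnet names the backbone architecture defined in models/resnet_simclr.py;
# out_dim is the output size of the projection head.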

train_gpu = torch.cuda.is_available()
print("Is GPU available:", train_gpu)

# move the model parameters to the GPU
if train_gpu:
    model.cuda()

criterion = torch.nn.CrossEntropyLoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), 3e-4)
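
# reduction='sum' makes the criterion return the summed loss over the batch;
# the accumulated total is divided by 2 * batch_size at the end of each step.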

train_writer = SummaryWriter()

if use_cosine_similarity:
    cos_similarity_dim1 = torch.nn.CosineSimilarity(dim=1)
    cos_similarity_dim2 = torch.nn.CosineSimilarity(dim=2)

# Mask to remove positive examples from the batch of negative samples
negative_mask = get_negative_mask(batch_size)
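
# A minimal sketch of what utils.get_negative_mask is assumed to compute (the
# imported implementation is the one actually used): a boolean mask of shape
# (batch_size, 2 * batch_size) that is False at the two columns holding each
# example's own views, so masked indexing keeps 2 * (batch_size - 1) negatives per row.
#
#   def get_negative_mask(batch_size):
#       mask = torch.ones((batch_size, 2 * batch_size), dtype=torch.bool)
#       for i in range(batch_size):
#           mask[i, i] = False
#           mask[i, i + batch_size] = False
#       return mask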

n_iter = 0
for e in range(config['epochs']):
    for step, (batch_x, _) in enumerate(train_loader):

        optimizer.zero_grad()

        xis = []
        xjs = []

        # draw two augmentation functions t, t' and apply them separately to each input example
        for k in range(len(batch_x)):
            xis.append(data_augment(batch_x[k]))  # the first augmentation
            xjs.append(data_augment(batch_x[k]))  # the second augmentation
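
        # Each data_augment call samples fresh random parameters, so xis[k] and
        # xjs[k] are two different random views of the same image batch_x[k].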

        xis = torch.stack(xis)
        xjs = torch.stack(xjs)

        if train_gpu:
            xis = xis.cuda()
            xjs = xjs.cuda()

        # get the representations and the projections for the first view
        ris, zis = model(xis)  # [N,C]
        train_writer.add_histogram("xi_repr", ris, global_step=n_iter)
        train_writer.add_histogram("xi_latent", zis, global_step=n_iter)

        # get the representations and the projections for the second view
        rjs, zjs = model(xjs)  # [N,C]
        train_writer.add_histogram("xj_repr", rjs, global_step=n_iter)
        train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)

        # normalize the projection feature vectors
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)
        # assert zis.shape == (batch_size, out_dim), "Shape not expected: " + str(zis.shape)
        # assert zjs.shape == (batch_size, out_dim), "Shape not expected: " + str(zjs.shape)

        # positive pairs: similarity between the two views of the same example
        if use_cosine_similarity:
            l_pos = cos_similarity_dim1(zis.view(batch_size, out_dim),
                                        zjs.view(batch_size, out_dim)).view(batch_size, 1)
        else:
            l_pos = torch.bmm(zis.view(batch_size, 1, out_dim),
                              zjs.view(batch_size, out_dim, 1)).view(batch_size, 1)
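
        # Since zis and zjs are already L2-normalized, the bmm dot product equals
        # cosine similarity here, so both branches compute the same score.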

        l_pos /= temperature
        # assert l_pos.shape == (batch_size, 1), "l_pos shape not valid" + str(l_pos.shape)  # [N,1]

        negatives = torch.cat([zjs, zis], dim=0)
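
        # The candidate pool for this batch: both augmented views, 2 * batch_size rows.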

        loss = 0

        # accumulate the loss twice, once with each view's projections as anchors
        for positives in [zis, zjs]:

            if use_cosine_similarity:
                negatives = negatives.view(1, (2 * batch_size), out_dim)
                l_neg = cos_similarity_dim2(positives.view(batch_size, 1, out_dim), negatives)
            else:
                l_neg = torch.tensordot(positives.view(batch_size, 1, out_dim),
                                        negatives.T.view(1, out_dim, (2 * batch_size)),
                                        dims=2)
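                # dims=2 contracts the trailing (1, out_dim) axes of the first
                # argument with the leading (1, out_dim) axes of the second,
                # yielding an [N, 2N] similarity matrix.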

            # the positive logit is placed in column 0 below, so the target class is 0 for every row
            labels = torch.zeros(batch_size, dtype=torch.long)
            if train_gpu:
                labels = labels.cuda()

            # drop each anchor's own two columns, keeping only true negatives
            l_neg = l_neg[negative_mask].view(l_neg.shape[0], -1)
            l_neg /= temperature

            # assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(l_neg.shape)

            logits = torch.cat([l_pos, l_neg], dim=1)  # [N, K+1] with K = 2 * (batch_size - 1)
            loss += criterion(logits, labels)

        loss = loss / (2 * batch_size)
        train_writer.add_scalar('loss', loss, global_step=n_iter)

        loss.backward()
        optimizer.step()
        n_iter += 1
        # print("Step {}, Loss {}".format(step, loss))

torch.save(model.state_dict(), './checkpoints/checkpoint.pth')