# SimCLR training script (SimCLR/train.py).
import os

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import yaml
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets

from data_aug.data_transform import DataTransform, get_data_transform_opes
from models.resnet_simclr import ResNetSimCLR
from utils import get_negative_mask, get_similarity_function

print(torch.__version__)
# Reproducibility: fix the RNG seed for weight init and data shuffling.
torch.manual_seed(0)

# Hyper-parameters live in an external YAML file. Use a context manager so
# the handle is closed, and safe_load so the config cannot construct
# arbitrary Python objects.
with open("config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

batch_size = config['batch_size']
out_dim = config['out_dim']
temperature = config['temperature']
use_cosine_similarity = config['use_cosine_similarity']

# SimCLR-style stochastic augmentation pipeline; STL10 images are 96x96.
data_augment = get_data_transform_opes(s=config['s'], crop_size=96)

# DataTransform yields two independently augmented views (xi, xj) per image
# — presumably as a tuple; confirm against data_aug.data_transform.
train_dataset = datasets.STL10('./data', split='train', download=True, transform=DataTransform(data_augment))
# train_dataset = datasets.Caltech101(root='./data', target_type="category", transform=transforms.ToTensor(),
#                                     target_transform=None, download=True)

# drop_last=True keeps every batch at exactly `batch_size`, which the
# fixed-size negative mask built below relies on.
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=config['num_workers'], drop_last=True,
                          shuffle=True)
# model = Encoder(out_dim=out_dim)
# ResNet backbone + projection head; forward returns (representation, projection).
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)

# Use the GPU whenever one is available. (This was hard-coded to False with
# the detection commented out — a debug leftover that silently disabled
# GPU training.)
train_gpu = torch.cuda.is_available()
print("Is gpu available:", train_gpu)

# moves the model parameters to gpu
if train_gpu:
    model.cuda()

# Sum the per-example losses here; the training loop divides by 2N to get
# the NT-Xent average over both augmented views.
criterion = torch.nn.CrossEntropyLoss(reduction='sum')

optimizer = optim.Adam(model.parameters(), 3e-4)

train_writer = SummaryWriter()
# sim_func_dim1: similarity of aligned row pairs (positives);
# sim_func_dim2: full cross-similarity matrix (candidates)
# — NOTE(review): semantics inferred from usage below; confirm against utils.
sim_func_dim1, sim_func_dim2 = get_similarity_function(use_cosine_similarity)

# Mask to remove positive examples from the batch of negative samples
negative_mask = get_negative_mask(batch_size)
n_iter = 0

# The positive logit is placed in column 0 of the logits matrix, so the
# cross-entropy target class is always 0. This tensor never changes, so
# build it (and move it to the GPU) once instead of on every iteration.
labels = torch.zeros(batch_size, dtype=torch.long)
if train_gpu:
    labels = labels.cuda()

for e in range(config['epochs']):

    # Each sample arrives as two augmented views (xis, xjs); the class
    # label is ignored — SimCLR is self-supervised.
    for step, ((xis, xjs), _) in enumerate(train_loader):

        optimizer.zero_grad()

        if train_gpu:
            xis = xis.cuda()
            xjs = xjs.cuda()

        # get the representations and the projections for view i
        ris, zis = model(xis)  # [N,C]
        train_writer.add_histogram("xi_repr", ris, global_step=n_iter)
        train_writer.add_histogram("xi_latent", zis, global_step=n_iter)

        # get the representations and the projections for view j
        rjs, zjs = model(xjs)  # [N,C]
        train_writer.add_histogram("xj_repr", rjs, global_step=n_iter)
        train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)

        # normalize projection feature vectors so similarities are cosine-like
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)

        # Temperature-scaled similarity of each aligned positive pair -> [N,1]
        l_pos = sim_func_dim1(zis, zjs).view(batch_size, 1)
        l_pos /= temperature

        # Both views act as negative candidates for the other view's anchors.
        negatives = torch.cat([zjs, zis], dim=0)

        loss = 0
        for positives in [zis, zjs]:
            l_neg = sim_func_dim2(positives, negatives)

            # Drop self- and positive-pair entries, leaving 2(N-1) negatives
            # per anchor.
            l_neg = l_neg[negative_mask].view(l_neg.shape[0], -1)
            l_neg /= temperature

            logits = torch.cat([l_pos, l_neg], dim=1)  # [N,K+1]
            loss += criterion(logits, labels)

        # NT-Xent: average the summed cross-entropy over all 2N examples.
        loss = loss / (2 * batch_size)

        train_writer.add_scalar('loss', loss, global_step=n_iter)

        loss.backward()
        optimizer.step()

        n_iter += 1
        # print("Step {}, Loss {}".format(step, loss))
# Persist the final encoder + projection-head weights. Create the target
# directory first so a completed training run cannot fail at the very end
# with a FileNotFoundError.
os.makedirs('./checkpoints', exist_ok=True)
torch.save(model.state_dict(), './checkpoints/checkpoint.pth')