# SimCLR/train.py


import shutil
import torch
import yaml
print(torch.__version__)
import torch.optim as optim
import os
from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import numpy as np
from models.resnet_simclr import ResNetSimCLR
from utils import get_similarity_function, get_train_validation_data_loaders
from data_aug.data_transform import DataTransform, get_simclr_data_transform
torch.manual_seed(0)
np.random.seed(0)

with open("config.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
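# Keys this script reads from config.yaml (the values below are illustrative,
# not the repository's defaults; get_train_validation_data_loaders(**config)
# may read additional keys):
#   batch_size: 256
#   out_dim: 64
#   s: 1
#   temperature: 0.5
#   base_convnet: resnet18
#   use_cosine_similarity: True
#   epochs: 40
#   log_every_n_steps: 50
#   eval_every_n_epochs: 1
#   continue_training: None   # or the name of a previous run folder under ./runs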
batch_size = config['batch_size']
out_dim = config['out_dim']
temperature = config['temperature']
use_cosine_similarity = config['use_cosine_similarity']
data_augment = get_simclr_data_transform(s=config['s'], crop_size=96)
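# `s` is the color-distortion strength from the SimCLR paper; crop_size=96
# matches the native 96x96 resolution of STL-10 images.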
train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True, transform=DataTransform(data_augment))
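# DataTransform applies the augmentation pipeline twice per sample, so the
# loader yields two correlated views (xis, xjs) of each image.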
train_loader, valid_loader = get_train_validation_data_loaders(train_dataset, **config)
# model = Encoder(out_dim=out_dim)
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
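# ResNetSimCLR returns a tuple per forward pass: the backbone representation h
# and the projection z that the contrastive loss is computed on (see step()).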

if config['continue_training']:
    checkpoints_folder = os.path.join('./runs', config['continue_training'], 'checkpoints')
    state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
    model.load_state_dict(state_dict)
    print("Loaded pre-trained model successfully.")
train_gpu = torch.cuda.is_available()
print("Is gpu available:", train_gpu)

# move the model parameters to the GPU
if train_gpu:
    model = model.cuda()
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
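# Note: `criterion` is defined but never used below; step() computes the
# NT-Xent loss manually from softmax probabilities instead.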
optimizer = optim.Adam(model.parameters(), lr=3e-4)
train_writer = SummaryWriter()
similarity_func = get_similarity_function(use_cosine_similarity)

# mask that drops each row's self-similarity from the (2N x 2N) similarity matrix
negative_mask = (1 - torch.eye(2 * batch_size)).type(torch.bool)

# one-hot targets marking the positive pair for each of the 2N views
labels = (np.eye(2 * batch_size, 2 * batch_size - 1, k=-batch_size)
          + np.eye(2 * batch_size, 2 * batch_size - 1, k=batch_size - 1)).astype(int)
labels = torch.from_numpy(labels)
softmax = torch.nn.Softmax(dim=-1)
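# Shape check (worked example with batch_size N = 2, so 2N = 4):
#   negative_mask is a 4x4 boolean matrix with the diagonal False; indexing the
#   4x4 similarity matrix with it leaves 4x3 logits, i.e. each of the 2N views
#   scored against the 2N - 1 others.
#   labels is 4x3 with exactly one 1 per row marking the positive pair: row
#   i < N has its positive at column i + N - 1, row i >= N at column i - N
#   (rows follow the torch.cat([zjs, zis]) ordering used in step()).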
if train_gpu:
    labels = labels.cuda()

def step(xis, xjs):
    # get the representations and the projections for both augmented views
    ris, zis = model(xis)  # [N, C]
    rjs, zjs = model(xjs)  # [N, C]

    if n_iter % config['log_every_n_steps'] == 0:
        train_writer.add_histogram("xi_repr", ris, global_step=n_iter)
        train_writer.add_histogram("xi_latent", zis, global_step=n_iter)
        train_writer.add_histogram("xj_repr", rjs, global_step=n_iter)
        train_writer.add_histogram("xj_latent", zjs, global_step=n_iter)

    # normalize the projection feature vectors
    zis = F.normalize(zis, dim=1)
    zjs = F.normalize(zjs, dim=1)

    # stack all 2N projections and score every view against every other one
    negatives = torch.cat([zjs, zis], dim=0)
    logits = similarity_func(negatives, negatives)
    logits = logits[negative_mask].view(2 * batch_size, -1)
    logits /= temperature
    # assert logits.shape == (2 * batch_size, 2 * batch_size - 1), \
    #     "Shape of negatives not expected: " + str(logits.shape)

    # NT-Xent: mean cross-entropy of each view against its positive pair
    probs = softmax(logits)
    loss = torch.mean(-torch.sum(labels * torch.log(probs), dim=-1))
    return loss
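
# step() reads module-level state (model, batch_size, temperature, labels,
# negative_mask, n_iter), so it must only be called after the setup above.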
model_checkpoints_folder = os.path.join(train_writer.log_dir, 'checkpoints')
os.makedirs(model_checkpoints_folder, exist_ok=True)
shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))
n_iter = 0
valid_n_iter = 0
best_valid_loss = np.inf
for epoch_counter in range(config['epochs']):
    for (xis, xjs), _ in train_loader:
        optimizer.zero_grad()

        if train_gpu:
            xis = xis.cuda()
            xjs = xjs.cuda()

        loss = step(xis, xjs)
        train_writer.add_scalar('train_loss', loss, global_step=n_iter)

        loss.backward()
        optimizer.step()
        n_iter += 1
    if epoch_counter % config['eval_every_n_epochs'] == 0:
        # validation steps
        with torch.no_grad():
            model.eval()

            valid_loss = 0.0
            for counter, ((xis, xjs), _) in enumerate(valid_loader):
                if train_gpu:
                    xis = xis.cuda()
                    xjs = xjs.cuda()

                loss = step(xis, xjs)
                valid_loss += loss.item()
            # enumerate() starts at 0, so the number of batches is counter + 1
            valid_loss /= (counter + 1)

            if valid_loss < best_valid_loss:
                # save the weights with the best validation loss seen so far
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))

            train_writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
            valid_n_iter += 1

        model.train()