SimCLR/simclr.py

138 lines
5.2 KiB
Python

import logging
import os
import shutil
import sys
import torch
import torch.nn.functional as F
import yaml
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(0)

# apex (NVIDIA mixed-precision helpers) is an optional dependency. The flag
# below records whether `amp` could be imported; code that uses `amp` must
# check it first because the name is undefined when the import fails.
apex_support = False
# Allow a local checkout of apex located in ./apex to be picked up.
sys.path.append('./apex')
try:
    from apex import amp
except ImportError:
    # Only a failed import means "apex unavailable"; a bare `except:` would
    # also have swallowed KeyboardInterrupt and SystemExit.
    print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex")
else:
    apex_support = True
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` and optionally keep a "best" copy.

    Args:
        state: Object passed to ``torch.save``.
        is_best: When truthy, the file just written is additionally copied to
            ``model_best.pth.tar`` (a path relative to the current working
            directory, not to ``filename``).
        filename: Destination path of the checkpoint.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')
def _save_config_file(model_checkpoints_folder, args):
    """Write the run configuration to ``config.yml`` inside the given folder.

    Args:
        model_checkpoints_folder: Directory that receives ``config.yml``. It
            is created, including missing parents, when it does not exist.
        args: Configuration object handed to ``yaml.dump`` unchanged.
    """
    # exist_ok=True replaces the former exists()-then-makedirs() sequence,
    # which could raise FileExistsError if the folder appeared in between.
    os.makedirs(model_checkpoints_folder, exist_ok=True)
    # NOTE(review): if `args` is an argparse.Namespace, yaml.dump emits a
    # python/object tag that yaml.safe_load cannot read back - confirm whether
    # dumping vars(args) was intended.
    with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
        yaml.dump(args, outfile, default_flow_style=False)
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: 2-D score tensor with one row per sample and one column per
            class.
        target: 1-D tensor holding the correct class index of each sample.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk``, in the
        same order, each holding the percentage of samples whose target is
        among the k highest-scoring classes.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)
        percent_per_hit = 100.0 / n_samples

        # Indices of the `largest_k` best classes per sample, transposed so
        # that row r holds every sample's r-th ranked prediction.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        # A sample counts for top-k when its target shows up in the first k
        # ranks; summing the first k rows counts those samples.
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]
class SimCLR(object):
    """Trainer that runs contrastive (InfoNCE) pre-training of a model.

    The constructor takes keyword arguments only:

        args: Configuration object. The attributes read by this class are
            ``device``, ``n_views``, ``temperature``, ``fp16_precision``,
            ``epochs``, ``disable_cuda``, ``log_every_n_steps`` and ``arch``.
        model: Network that maps a batch of images to 2-D feature vectors.
        optimizer: Optimizer that updates the model parameters.
        scheduler: Learning-rate scheduler attached to ``optimizer``.
    """

    def __init__(self, *args, **kwargs):
        """Store the training components and set up logging.

        Creates a ``SummaryWriter`` (default log directory) and configures the
        root logger to write ``training.log`` into that directory.

        Raises:
            KeyError: If ``args``, ``model``, ``optimizer`` or ``scheduler``
                is missing from the keyword arguments.
        """
        self.args = kwargs['args']
        self.model = kwargs['model'].to(self.args.device)
        self.optimizer = kwargs['optimizer']
        self.scheduler = kwargs['scheduler']
        self.writer = SummaryWriter()
        logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)

    def info_nce_loss(self, features):
        """Build the logits and targets of the InfoNCE objective.

        ``features`` is expected to hold ``n_views`` stacked groups of equal
        size, where row ``i`` of every group belongs to the same source image
        (this is the layout ``train`` produces with ``torch.cat``).

        Args:
            features: 2-D tensor of shape ``(n_views * batch, dim)``.

        Returns:
            A ``(logits, labels)`` pair for ``CrossEntropyLoss``. ``logits``
            has one row per feature vector; its leading ``n_views - 1``
            columns are the similarities to the other views of the same image
            and the remaining columns are the similarities to all other
            images, everything divided by ``args.temperature``. ``labels`` is
            an all-zero ``long`` tensor, i.e. the target is the first
            positive column.

        Raises:
            ValueError: If ``n_views`` is below 2 or the number of rows is
                not a multiple of ``n_views``.
        """
        n_views = self.args.n_views
        n_samples = features.shape[0]
        if n_views < 2 or n_samples % n_views != 0:
            raise ValueError(
                f"Expected features for at least 2 views stacked along dim 0, "
                f"got {n_samples} rows for n_views={n_views}."
            )
        # Derive the batch size from the data instead of args.batch_size so a
        # smaller final batch from the data loader is handled as well.
        batch_size = n_samples // n_views
        device = features.device

        # sample_ids[r] identifies the source image of row r; two rows with
        # the same id are different views of one image.
        sample_ids = torch.arange(batch_size, device=device).repeat(n_views)
        same_image = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)
        if similarity_matrix.shape != (n_samples, n_samples):
            raise ValueError("features must be a 2-D tensor of shape (n_views * batch, dim).")

        # Remove every row's similarity with itself. This shifts the columns
        # to the right of the diagonal by one, which is why the positives are
        # located through the equally-masked `same_image` matrix rather than
        # through fixed column indices.
        off_diagonal = ~torch.eye(n_samples, dtype=torch.bool, device=device)
        same_image = same_image[off_diagonal].view(n_samples, n_samples - 1)
        similarity_matrix = similarity_matrix[off_diagonal].view(n_samples, n_samples - 1)

        positives = similarity_matrix[same_image].view(n_samples, n_views - 1)
        negatives = similarity_matrix[~same_image].view(n_samples, n_samples - n_views)

        # Positives first, so class index 0 is always a correct target.
        # NOTE(review): with n_views > 2 only the first positive is used as
        # the target and the other positives act as negatives - confirm that
        # this is acceptable before training with more than two views.
        logits = torch.cat([positives, negatives], dim=1) / self.args.temperature
        labels = torch.zeros(n_samples, dtype=torch.long, device=device)
        return logits, labels

    def train(self, train_loader):
        """Run the full training schedule and save a final checkpoint.

        Args:
            train_loader: Iterable yielding ``(images, _)`` pairs where
                ``images`` is a sequence of tensors (one per view) that can
                be concatenated along dim 0; the second item is ignored.

        Side effects: writes ``config.yml``, TensorBoard scalars, log records
        and ``checkpoint_<epochs>.pth.tar`` into the writer's log directory.
        """
        use_amp = apex_support and self.args.fp16_precision
        if use_amp:
            logging.debug("Using apex for fp16 precision training.")
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer,
                                                        opt_level='O2',
                                                        keep_batchnorm_fp32=True)

        # save config file
        _save_config_file(self.writer.log_dir, self.args)

        n_iter = 0
        logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
        # `disable_cuda` is the negation of "uses the GPU", so invert it here.
        logging.info(f"Training with gpu: {not self.args.disable_cuda}.")

        for epoch_counter in range(self.args.epochs):
            # Reset per epoch so the summary below never reports values from
            # an earlier epoch (or fails when the loader yields nothing).
            loss = logits = labels = None

            for images, _ in tqdm(train_loader):
                images = torch.cat(images, dim=0)
                images = images.to(self.args.device)

                features = self.model(images)
                logits, labels = self.info_nce_loss(features)
                loss = self.criterion(logits, labels)

                self.optimizer.zero_grad()
                if use_amp:
                    # apex scales the loss to avoid fp16 gradient underflow.
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                self.optimizer.step()

                if n_iter % self.args.log_every_n_steps == 0:
                    top1, top5 = accuracy(logits, labels, topk=(1, 5))
                    self.writer.add_scalar('loss', loss, global_step=n_iter)
                    self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
                    self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
                    # get_last_lr() is the supported way to read the current
                    # rate; get_lr() is only meant to be used inside step().
                    self.writer.add_scalar('learning_rate', self.scheduler.get_last_lr()[0], global_step=n_iter)

                n_iter += 1

            # The scheduler is only stepped from epoch 10 onward, so the
            # learning rate keeps its initial value for the first 10 epochs.
            if epoch_counter >= 10:
                self.scheduler.step()

            if loss is None:
                logging.warning(f"Epoch: {epoch_counter}\tThe data loader produced no batches.")
                continue

            # Report loss and accuracy of the same (final) batch of the epoch.
            top1 = accuracy(logits, labels, topk=(1,))[0]
            logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")

        logging.info("Training has finished.")
        # save model checkpoints
        checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
        save_checkpoint({
            'epoch': self.args.epochs,
            'arch': self.args.arch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
        logging.info(f"Model checkpoint and metadata has been saved at {self.writer.log_dir}.")