mirror of https://github.com/sthalles/SimCLR.git
124 lines
4.9 KiB
Python
124 lines
4.9 KiB
Python
import logging
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from tqdm import tqdm
|
|
|
|
from utils import save_config_file, accuracy, save_checkpoint
|
|
|
|
# Seed the default torch RNG so runs are reproducible.
torch.manual_seed(0)

# Optional dependency: NVIDIA apex enables mixed precision training.
# `apex_support` records whether `amp` could be imported.
apex_support = False
try:
    # Allow a local checkout of apex living next to this file.
    sys.path.append('./apex')
    from apex import amp

    apex_support = True
except ImportError:
    # Only a failed import means "apex is missing"; a bare `except` would also
    # swallow KeyboardInterrupt / SystemExit.
    print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex")
|
|
|
|
|
|
class SimCLR(object):
    """Training wrapper that optimises a model with the InfoNCE contrastive loss.

    The constructor takes keyword arguments only:
        args:      namespace-like config object; this class reads ``device``,
                   ``n_views``, ``temperature``, ``epochs``, ``fp16_precision``,
                   ``disable_cuda``, ``log_every_n_steps`` and ``arch``.
        model:     the network to train (moved to ``args.device``).
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: learning-rate scheduler stepped once per epoch after warmup.
    """

    def __init__(self, *args, **kwargs):
        """Store the training components and set up TensorBoard / file logging.

        Raises:
            KeyError: if any of ``args``, ``model``, ``optimizer`` or
                ``scheduler`` is missing from ``kwargs``.
        """
        self.args = kwargs['args']
        self.model = kwargs['model'].to(self.args.device)
        self.optimizer = kwargs['optimizer']
        self.scheduler = kwargs['scheduler']
        self.writer = SummaryWriter()
        # The text log lives next to the TensorBoard event files.
        logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)

    def info_nce_loss(self, features):
        """Build the logits and targets for the InfoNCE loss.

        Args:
            features: 2-D tensor with ``n_views * batch`` rows, laid out as
                ``n_views`` consecutive groups of ``batch`` rows, where row
                ``i`` of every group comes from the same source sample.

        Returns:
            ``(logits, labels)``: ``logits`` has one row per feature with the
            positive similarities in the leading column(s) followed by the
            negatives, already divided by ``args.temperature``; ``labels`` is
            an int64 vector of zeros (the positive sits at column 0), as
            required by ``CrossEntropyLoss`` for class-index targets.
        """
        device = self.args.device
        n_views = self.args.n_views
        # Derive the batch size from the data so that a shorter final batch
        # (loader without drop_last) does not break the mask/view shapes.
        batch_size = features.shape[0] // n_views

        # sample_ids[k] is the source-sample index of feature row k.
        sample_ids = torch.arange(batch_size).repeat(n_views)
        # labels[i, j] == 1 when rows i and j are views of the same sample.
        labels = (sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)).float()
        labels = labels.to(device)

        features = F.normalize(features, dim=1)

        # Pairwise cosine similarities; shape (n_views * batch, n_views * batch),
        # the same shape as `labels`.
        similarity_matrix = torch.matmul(features, features.T)

        # discard the main diagonal from both: labels and similarities matrix
        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
        labels = labels[~mask].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

        # select and combine multiple positives
        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

        # select only the negatives
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

        logits = torch.cat([positives, negatives], dim=1)
        # Class-index targets must be int64; float zeros are rejected by
        # CrossEntropyLoss for a 1-D target.
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)

        logits = logits / self.args.temperature
        return logits, labels

    def train(self, train_loader):
        """Run the full training loop and save a final checkpoint.

        Args:
            train_loader: iterable yielding ``(images, _)`` where ``images``
                is a sequence of tensors (one per view) that is concatenated
                along dim 0 before the forward pass.
        """
        use_amp = apex_support and self.args.fp16_precision

        if use_amp:
            logging.debug("Using apex for fp16 precision training.")
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer,
                                                        opt_level='O2',
                                                        keep_batchnorm_fp32=True)
        # save config file
        save_config_file(self.writer.log_dir, self.args)

        n_iter = 0
        logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
        # `disable_cuda` is the negation of "training with gpu".
        logging.info(f"Training with gpu: {not self.args.disable_cuda}.")

        for epoch_counter in range(self.args.epochs):
            for images, _ in tqdm(train_loader):
                images = torch.cat(images, dim=0)

                images = images.to(self.args.device)

                features = self.model(images)
                logits, labels = self.info_nce_loss(features)
                loss = self.criterion(logits, labels)

                self.optimizer.zero_grad()
                if use_amp:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                self.optimizer.step()

                if n_iter % self.args.log_every_n_steps == 0:
                    top1, top5 = accuracy(logits, labels, topk=(1, 5))
                    self.writer.add_scalar('loss', loss, global_step=n_iter)
                    self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
                    self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
                    # get_last_lr() reports the rate currently in effect;
                    # get_lr() is not meant to be called outside step().
                    self.writer.add_scalar('learning_rate', self.scheduler.get_last_lr()[0], global_step=n_iter)

                n_iter += 1

            # warmup for the first 10 epochs
            if epoch_counter >= 10:
                self.scheduler.step()
            # `loss` is from the last batch; `top1` is from the most recent
            # logging step, which may be an earlier batch.
            logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")

        logging.info("Training has finished.")
        # save model checkpoints
        checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
        save_checkpoint({
            'epoch': self.args.epochs,
            'arch': self.args.arch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
        logging.info(f"Model checkpoint and metadata has been saved at {self.writer.log_dir}.")
|