mirror of https://github.com/sthalles/SimCLR.git
138 lines
5.2 KiB
Python
138 lines
5.2 KiB
Python
import logging
|
|
import os
|
|
import shutil
|
|
import sys
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import yaml
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from tqdm import tqdm
|
|
|
|
# Seed PyTorch's global RNG with a fixed value.
torch.manual_seed(0)

# apex is an optional dependency that is only needed for mixed-precision
# training; `apex_support` records whether `amp` could be imported.
apex_support = False
try:
    # Allow a local checkout of apex (./apex) to be picked up.
    sys.path.append('./apex')
    from apex import amp
except Exception:
    # Any import-time failure simply disables apex. `Exception` (rather than
    # a bare `except`) lets KeyboardInterrupt / SystemExit propagate.
    print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex")
else:
    apex_support = True
|
|
|
|
|
|
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    When ``is_best`` is truthy, the file that was just written is also copied
    to ``model_best.pth.tar`` (a path relative to the current working
    directory, not to ``filename``).
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')
|
|
|
|
|
|
def _save_config_file(model_checkpoints_folder, args):
    """Dump ``args`` as YAML to ``<model_checkpoints_folder>/config.yml``.

    The folder (including missing parents) is created when it does not exist.
    """
    # exist_ok=True avoids the check-then-create race of testing
    # os.path.exists() before calling os.makedirs().
    os.makedirs(model_checkpoints_folder, exist_ok=True)
    config_path = os.path.join(model_checkpoints_folder, 'config.yml')
    with open(config_path, 'w') as outfile:
        yaml.dump(args, outfile, default_flow_style=False)
|
|
|
|
|
|
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: 2-D tensor of scores; the top-k search runs along dim 1.
        target: 1-D tensor with one target column index per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``,
        holding the percentage of rows whose target is among the k
        highest-scoring columns. A k larger than the number of columns is
        treated as "all columns" instead of raising.
    """
    with torch.no_grad():
        # Tensor.topk raises if k exceeds the size of the dimension, so never
        # ask for more candidates than there are columns.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        # pred becomes (maxk, batch): row r holds every sample's rank-r guess.
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # Slicing past maxk just keeps every available row.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
|
|
|
|
|
|
class SimCLR(object):
    """Training driver for SimCLR-style contrastive learning.

    Wraps a model, optimizer and LR scheduler. Scalars are logged to a
    TensorBoard ``SummaryWriter``; a ``training.log`` file, the run config and
    the final checkpoint are written into the writer's log directory.
    """

    def __init__(self, *args, **kwargs):
        """Build the trainer from keyword arguments.

        Required keyword arguments:
            args: configuration object; this class reads ``device``,
                ``n_views``, ``temperature``, ``fp16_precision``, ``epochs``,
                ``disable_cuda``, ``log_every_n_steps`` and ``arch`` from it.
            model: the network to train; it is moved to ``args.device``.
            optimizer: optimizer stepped once per batch.
            scheduler: LR scheduler stepped once per epoch after warmup.

        Positional arguments are accepted but ignored.
        """
        self.args = kwargs['args']
        self.model = kwargs['model'].to(self.args.device)
        self.optimizer = kwargs['optimizer']
        self.scheduler = kwargs['scheduler']
        self.writer = SummaryWriter()
        logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)

    def info_nce_loss(self, features):
        """Build the logits and targets for the contrastive cross-entropy.

        Args:
            features: 2-D tensor with ``n_views * batch`` rows. Assumes the
                rows are ordered view by view (all of view 0, then all of
                view 1, ...) with samples in the same order inside each view,
                which is what ``torch.cat(images, dim=0)`` in ``train``
                produces -- confirm against the data loader.

        Returns:
            ``(logits, labels)``: ``logits`` has shape
            ``(n_views * batch, n_views * batch - 1)`` and holds the
            temperature-scaled cosine similarities of every row to all
            *other* rows, with the similarities to the other view(s) of the
            same sample in the leading column(s). ``labels`` is an all-zero
            ``long`` vector, i.e. it points at the first positive column.
            NOTE(review): with more than two views only the first positive is
            the target; the remaining positives act as negatives.

        Raises:
            ValueError: if ``n_views < 2`` or the number of rows is not a
                multiple of ``n_views``.
        """
        n_views = self.args.n_views
        total = features.shape[0]
        if n_views < 2 or total % n_views != 0:
            raise ValueError(
                f"features has {total} rows, which cannot be split into {n_views} views (need n_views >= 2)")
        # Derive the batch size from the data so that a short final batch
        # does not break the mask shapes.
        batch_size = total // n_views
        device = features.device

        # same_sample[i, j] is True when rows i and j come from the same sample.
        sample_ids = torch.arange(batch_size, device=device).repeat(n_views)
        same_sample = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)

        # Drop the diagonal (self-similarity) from both matrices. Removing a
        # column shifts later columns left, so the positive's position has to
        # be tracked through the mask rather than assumed.
        off_diagonal = ~torch.eye(total, dtype=torch.bool, device=device)
        same_sample = same_sample[off_diagonal].view(total, total - 1)
        similarity_matrix = similarity_matrix[off_diagonal].view(total, total - 1)

        positives = similarity_matrix[same_sample].view(total, n_views - 1)
        negatives = similarity_matrix[~same_sample].view(total, total - n_views)

        logits = torch.cat([positives, negatives], dim=1) / self.args.temperature
        labels = torch.zeros(total, dtype=torch.long, device=device)
        return logits, labels

    def train(self, train_loader):
        """Run the training loop and save a final checkpoint.

        Args:
            train_loader: iterable yielding ``(views, _)`` pairs, where
                ``views`` is a sequence of tensors that are concatenated
                along dim 0 before being fed to the model.
        """
        use_amp = apex_support and self.args.fp16_precision
        if use_amp:
            logging.debug("Using apex for fp16 precision training.")
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer,
                                                        opt_level='O2',
                                                        keep_batchnorm_fp32=True)
        # save config file
        _save_config_file(self.writer.log_dir, self.args)

        # get_lr() is meant to be called from inside scheduler.step() and may
        # report a wrong value elsewhere; prefer get_last_lr() when available.
        current_lr = getattr(self.scheduler, 'get_last_lr', None) or self.scheduler.get_lr

        n_iter = 0
        logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
        # The flag is "disable_cuda", so it has to be negated here.
        logging.info(f"Training with gpu: {not self.args.disable_cuda}.")

        for epoch_counter in range(self.args.epochs):
            for images, _ in tqdm(train_loader):
                images = torch.cat(images, dim=0)
                images = images.to(self.args.device)

                features = self.model(images)
                logits, labels = self.info_nce_loss(features)
                loss = self.criterion(logits, labels)

                self.optimizer.zero_grad()
                if use_amp:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                self.optimizer.step()

                if n_iter % self.args.log_every_n_steps == 0:
                    top1, top5 = accuracy(logits, labels, topk=(1, 5))
                    self.writer.add_scalar('loss', loss, global_step=n_iter)
                    self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
                    self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
                    self.writer.add_scalar('learning_rate', current_lr()[0], global_step=n_iter)

                n_iter += 1

            # warmup for the first 10 epochs
            if epoch_counter >= 10:
                self.scheduler.step()
            # top1 is the value from the most recent logging step, not
            # necessarily from the last batch of this epoch.
            logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")

        logging.info("Training has finished.")
        # save model checkpoints
        checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
        save_checkpoint({
            'epoch': self.args.epochs,
            'arch': self.args.arch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
        logging.info(f"Model checkpoint and metadata has been saved at {self.writer.log_dir}.")
|