import logging
import os
import shutil
import sys

import torch
import torch.nn.functional as F
import yaml
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

torch.manual_seed(0)

apex_support = False
try:
    sys.path.append('./apex')
    from apex import amp

    apex_support = True
except ImportError:
    print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex")
    apex_support = False


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


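# Restoring a checkpoint written by save_checkpoint (illustrative sketch only;
# the run directory and epoch number are placeholders, while the dict keys
# match the checkpoint assembled at the end of SimCLR.train below):
#
#     checkpoint = torch.load('runs/<run_dir>/checkpoint_0100.pth.tar', map_location='cpu')
#     model.load_state_dict(checkpoint['state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer'])
#     start_epoch = checkpoint['epoch']

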
def _save_config_file(model_checkpoints_folder, args):
    if not os.path.exists(model_checkpoints_folder):
        os.makedirs(model_checkpoints_folder)
        with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
            yaml.dump(args, outfile, default_flow_style=False)


class SimCLR(object):

    def __init__(self, *args, **kwargs):
        self.args = kwargs['args']
        self.model = kwargs['model'].to(self.args.device)
        self.optimizer = kwargs['optimizer']
        self.scheduler = kwargs['scheduler']
        self.writer = SummaryWriter()
        logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)

    def info_nce_loss(self, features):
        # Rows of `features` are ordered view-by-view: the first batch_size rows
        # are view 0 of each image, the next batch_size rows are view 1, etc.
        # `labels[i, j]` is 1 when rows i and j come from the same source image.
        labels = torch.cat([torch.arange(self.args.batch_size) for _ in range(self.args.n_views)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float().to(self.args.device)

        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)
        # assert similarity_matrix.shape == (
        #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)

        # Discard the main diagonal (self-similarity) from both matrices.
        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
        labels = labels[~mask].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

        # Reorder each row so its positive similarities come first and the
        # negatives follow; the cross-entropy target for every row is then 0.
        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
        logits = torch.cat([positives, negatives], dim=1) / self.args.temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)
        return logits, labels

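    # Worked example of the layout above (illustrative numbers only): with
    # batch_size=2 and n_views=2, `features` holds four rows [a0, a1, b0, b1],
    # where a/b denote the two views of images 0 and 1. Dropping the diagonal
    # leaves 3 similarities per row; the single positive (the other view of the
    # same image) is moved to column 0, so `logits` has shape [4, 3] and every
    # entry of the returned `labels` is 0.
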
    def train(self, train_loader):

        if apex_support and self.args.fp16_precision:
            logging.debug("Using apex for fp16 precision training.")
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer,
                                                        opt_level='O2',
                                                        keep_batchnorm_fp32=True)

        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

        # save config file
        _save_config_file(model_checkpoints_folder, self.args)

        n_iter = 0
        logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
        logging.info(f"Training with gpu: {not self.args.disable_cuda}.")

        for epoch_counter in range(self.args.epochs):
            for images, _ in tqdm(train_loader):
                images = torch.cat(images, dim=0)
                images = images.to(self.args.device)

                features = self.model(images)
                logits, labels = self.info_nce_loss(features)
                loss = self.criterion(logits, labels)

                self.optimizer.zero_grad()
                if apex_support and self.args.fp16_precision:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                self.optimizer.step()

                if n_iter % self.args.log_every_n_steps == 0:
                    predictions = torch.argmax(logits, dim=1)
                    acc = 100 * (predictions == labels).float().mean()
                    self.writer.add_scalar('loss', loss, global_step=n_iter)
                    self.writer.add_scalar('acc/top1', acc, global_step=n_iter)
                    self.writer.add_scalar('learning_rate', self.scheduler.get_last_lr()[0], global_step=n_iter)

                n_iter += 1

            # warmup for the first 10 epochs
            if epoch_counter >= 10:
                self.scheduler.step()
            logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {acc}")

        logging.info("Training has finished.")
        # save model checkpoints
        checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
        save_checkpoint({
            'epoch': self.args.epochs,
            'arch': self.args.arch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
        logging.info(f"Model checkpoint and metadata have been saved at {self.writer.log_dir}.")
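

# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical, for illustration only): wires a toy
# encoder and a synthetic two-view dataset into SimCLR as a quick smoke test.
# The dataset, model, hyperparameters and names below are stand-ins invented
# for this example; the real project builds its model, optimizer, scheduler
# and data loader in its own entry-point script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    from torch import nn
    from torch.utils.data import DataLoader, Dataset

    class TwoViewRandomDataset(Dataset):
        """Yields two noisy 'views' of the same random vector, mimicking augmentation."""

        def __init__(self, n_samples=64, dim=32):
            self.data = torch.randn(n_samples, dim)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            x = self.data[idx]
            views = [x + 0.1 * torch.randn_like(x), x + 0.1 * torch.randn_like(x)]
            return views, 0  # (list of n_views tensors, dummy label), as train() expects

    toy_args = argparse.Namespace(
        device='cuda' if torch.cuda.is_available() else 'cpu',
        disable_cuda=not torch.cuda.is_available(),
        batch_size=8, n_views=2, temperature=0.07,
        epochs=1, log_every_n_steps=10, fp16_precision=False,
        arch='toy-mlp',
    )

    toy_model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))
    toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=3e-4)
    toy_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(toy_optimizer, T_max=toy_args.epochs)

    # drop_last keeps every batch at exactly `batch_size`, which info_nce_loss assumes.
    toy_loader = DataLoader(TwoViewRandomDataset(), batch_size=toy_args.batch_size,
                            shuffle=True, drop_last=True)

    simclr = SimCLR(model=toy_model, optimizer=toy_optimizer,
                    scheduler=toy_scheduler, args=toy_args)
    simclr.train(toy_loader)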