import os
import shutil
import sys

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from loss.nt_xent import NTXentLoss
from models.resnet_simclr import ResNetSimCLR

apex_support = False
try:
    sys.path.append('./apex')
    from apex import amp

    apex_support = True
except ImportError:
    print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex")
    apex_support = False

torch.manual_seed(0)


def _save_config_file(model_checkpoints_folder):
    # copy the active config next to the checkpoints so the run can be reproduced
    if not os.path.exists(model_checkpoints_folder):
        os.makedirs(model_checkpoints_folder)
        shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))


class SimCLR(object):
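    """Trainer for SimCLR contrastive pre-training.

    Config keys read in this file: 'batch_size', 'loss' (kwargs for
    NTXentLoss), 'model' (kwargs for ResNetSimCLR), 'weight_decay', 'epochs',
    'log_every_n_steps', 'fp16_precision', 'eval_every_n_epochs' and
    'fine_tune_from'.
    """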

    def __init__(self, dataset, config):
        self.config = config
        self.device = self._get_device()
        self.writer = SummaryWriter()  # logs to a fresh directory under ./runs by default
        self.dataset = dataset
        self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss'])

    def _get_device(self):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Running on:", device)
        return device

    def _step(self, model, xis, xjs, n_iter):
        # representations and projections for the first augmented view
        ris, zis = model(xis)  # [N,C]

        # representations and projections for the second augmented view
        rjs, zjs = model(xjs)  # [N,C]

        # normalize the projection feature vectors before computing the loss
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)

        loss = self.nt_xent_criterion(zis, zjs)
        return loss
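
    # For reference, a sketch of what NTXentLoss computes: the normalized
    # temperature-scaled cross-entropy loss from the SimCLR paper (Chen et
    # al., 2020). For a positive pair (i, j) among the 2N augmented views in a
    # batch, assuming config['loss'] supplies the temperature tau:
    #
    #     l(i, j) = -log( exp(sim(z_i, z_j) / tau)
    #                     / sum_{k != i} exp(sim(z_i, z_k) / tau) )
    #
    # where sim is cosine similarity and the sum runs over the other 2N - 1
    # views. See loss/nt_xent.py for the actual implementation.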

    def train(self):

        train_loader, valid_loader = self.dataset.get_data_loaders()

        model = ResNetSimCLR(**self.config["model"]).to(self.device)
        model = self._load_pre_trained_weights(model)

        # float() is safer than eval() for parsing a numeric string from the config
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4,
                                     weight_decay=float(self.config['weight_decay']))

        # note: scheduler.step() is called once per epoch below, so T_max is
        # counted in epochs; here it is set to the number of training batches
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
                                                               last_epoch=-1)

        if apex_support and self.config['fp16_precision']:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level='O2',
                                              keep_batchnorm_fp32=True)

        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

        # save config file
        _save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_loss = np.inf

        for epoch_counter in range(self.config['epochs']):
            for (xis, xjs), _ in train_loader:
                optimizer.zero_grad()

                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                loss = self._step(model, xis, xjs, n_iter)

                if n_iter % self.config['log_every_n_steps'] == 0:
                    self.writer.add_scalar('train_loss', loss, global_step=n_iter)

                if apex_support and self.config['fp16_precision']:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()
                n_iter += 1

            # validate the model if requested
            if epoch_counter % self.config['eval_every_n_epochs'] == 0:
                valid_loss = self._validate(model, valid_loader)
                if valid_loss < best_valid_loss:
                    # save the model weights
                    best_valid_loss = valid_loss
                    torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))

                self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
                valid_n_iter += 1

            # hold the learning rate at its initial value for the first 10
            # epochs, then apply cosine decay
            if epoch_counter >= 10:
                scheduler.step()
            # get_last_lr() returns the LR set by the last scheduler.step()
            # (get_lr() is meant for internal scheduler use)
            self.writer.add_scalar('cosine_lr_decay', scheduler.get_last_lr()[0], global_step=n_iter)
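
    # If apex is unavailable, PyTorch >= 1.6 ships native mixed precision in
    # torch.cuda.amp. A minimal sketch of how the fp16 branch of train() could
    # be written with it (an alternative shown for reference, not what this
    # file uses):
    #
    #     scaler = torch.cuda.amp.GradScaler()
    #     ...
    #     with torch.cuda.amp.autocast():
    #         loss = self._step(model, xis, xjs, n_iter)
    #     scaler.scale(loss).backward()
    #     scaler.step(optimizer)
    #     scaler.update()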

    def _load_pre_trained_weights(self, model):
        try:
            checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints')
            # map_location lets GPU-saved checkpoints load on CPU-only machines
            state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=self.device)
            model.load_state_dict(state_dict)
            print("Loaded pre-trained model successfully.")
        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

        return model

    def _validate(self, model, valid_loader):

        # validation steps
        with torch.no_grad():
            model.eval()

            valid_loss = 0.0
            counter = 0
            for (xis, xjs), _ in valid_loader:
                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                loss = self._step(model, xis, xjs, counter)
                valid_loss += loss.item()
                counter += 1
            # average the contrastive loss over all validation batches
            valid_loss /= counter
        model.train()
        return valid_loss
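

# Minimal usage sketch. The dataset wrapper import and the config layout below
# are assumptions about the surrounding repo; any object whose
# get_data_loaders() returns (train_loader, valid_loader) of paired augmented
# views will work.
#
#     import yaml
#     from data_aug.dataset_wrapper import DataSetWrapper  # hypothetical path
#
#     config = yaml.safe_load(open("config.yaml", "r"))
#     dataset = DataSetWrapper(config['batch_size'], **config['dataset'])
#     simclr = SimCLR(dataset, config)
#     simclr.train()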