mirror of https://github.com/sthalles/SimCLR.git (synced 2025-06-03 15:03:00 +08:00)
added mixed precision training
This commit is contained in: commit 68d57a13c7 (parent 89abbabd56)
@@ -4,6 +4,7 @@ eval_every_n_epochs: 1
 fine_tune_from: None
 log_every_n_steps: 50
 weight_decay: 10e-6
+opt_level: 'O0'

 model:
   out_dim: 256
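The new opt_level key selects an Apex AMP optimization level: 'O0' keeps everything in FP32 (a baseline), 'O1' patches selected ops to FP16 with loss scaling, 'O2' casts the model to FP16 while keeping batch norm and master weights in FP32, and 'O3' is pure FP16. A small sketch of how such a key is typically read from the config (the config.yaml file name and the yaml loading below are assumptions, not shown in this diff):

# Sketch only: how the new config key is typically consumed (file name assumed).
import yaml

with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Apex optimization levels, for reference:
#   'O0' - pure FP32 (what this config selects)
#   'O1' - mixed precision via op patching (Apex's recommended default)
#   'O2' - "almost FP16": FP16 model, FP32 batch norm and master weights
#   'O3' - pure FP16
opt_level = config['opt_level']   # later forwarded to amp.initialize(...)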
@@ -10,9 +10,9 @@ class NTXentLoss(torch.nn.Module):
         self.temperature = temperature
         self.device = device
         self.softmax = torch.nn.Softmax(dim=-1)
-        self.mask_samples_from_same_repr = self._get_correlated_mask()
+        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
         self.similarity_function = self._get_similarity_function(use_cosine_similarity)
-        self.labels = self._get_labels()
+        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

     def _get_similarity_function(self, use_cosine_similarity):
         if use_cosine_similarity:
@@ -21,14 +21,13 @@ class NTXentLoss(torch.nn.Module):
         else:
             return self._dot_simililarity

-    def _get_labels(self):
-        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size - 1, k=-self.batch_size)
-        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size - 1, k=self.batch_size - 1)
-        labels = torch.from_numpy((l1 + l2).astype(np.int))
-        return labels.to(self.device)
-
     def _get_correlated_mask(self):
-        return (1 - torch.eye(2 * self.batch_size)).type(torch.bool)
+        diag = np.eye(2 * self.batch_size)
+        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
+        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
+        mask = torch.from_numpy((diag + l1 + l2))
+        mask = (1 - mask).type(torch.bool)
+        return mask.to(self.device)

     @staticmethod
     def _dot_simililarity(x, y):
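For a batch of N pairs, the similarity matrix over the 2N concatenated projections has self-similarities on the main diagonal and the positive-pair similarities on the diagonals at offsets +N and -N; the rewritten _get_correlated_mask sets exactly those three diagonals to False so that boolean indexing keeps the 2N - 2 negatives per row. A quick standalone check of the same construction for batch_size = 2 (outside the class, names local to the sketch):

import numpy as np
import torch

batch_size = 2  # tiny example; real runs use a much larger batch

diag = np.eye(2 * batch_size)                                # self-similarities
l1 = np.eye(2 * batch_size, 2 * batch_size, k=-batch_size)   # positives below the diagonal
l2 = np.eye(2 * batch_size, 2 * batch_size, k=batch_size)    # positives above the diagonal
mask = (1 - torch.from_numpy(diag + l1 + l2)).type(torch.bool)

print(mask.int())
# tensor([[0, 1, 0, 1],
#         [1, 0, 1, 0],
#         [0, 1, 0, 1],
#         [1, 0, 1, 0]], dtype=torch.int32)
print(mask.sum(dim=1))  # tensor([2, 2, 2, 2]) -> 2N - 2 negatives kept per row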
@@ -46,14 +45,21 @@ class NTXentLoss(torch.nn.Module):
         return v

     def forward(self, zis, zjs):
-        negatives = torch.cat([zjs, zis], dim=0)
+        representations = torch.cat([zjs, zis], dim=0)

-        logits = self.similarity_function(negatives, negatives)
-        logits = logits[self.mask_samples_from_same_repr.type(torch.bool)].view(2 * self.batch_size, -1)
+        similarity_matrix = self.similarity_function(representations, representations)
+
+        # filter out the scores from the positive samples
+        l_pos = torch.diag(similarity_matrix, self.batch_size)
+        r_pos = torch.diag(similarity_matrix, -self.batch_size)
+        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
+
+        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
+
+        logits = torch.cat((positives, negatives), dim=1)
         logits /= self.temperature
-        # assert logits.shape == (2 * self.batch_size, 2 * self.batch_size - 1), "Shape of negatives not expected." + str(
-        # logits.shape)

-        probs = self.softmax(logits)
-        loss = torch.mean(-torch.sum(self.labels * torch.log(probs), dim=-1))
-        return loss
+        labels = torch.zeros(2 * self.batch_size).to(self.device).long()
+        loss = self.criterion(logits, labels)
+
+        return loss / (2 * self.batch_size)
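With representations = [zjs; zis], row i's positive sits at column i + N for the first N rows and at column i - N for the last N, so torch.diag(similarity_matrix, ±N) pulls out all 2N positives at once. Stacking [positive | 2N - 2 negatives] makes column 0 the correct class for every row, which is why labels is a vector of zeros, and CrossEntropyLoss(reduction="sum") divided by 2N gives the NT-Xent loss averaged over the 2N augmented samples. A self-contained shape check with random unit-normalized inputs (cosine similarity written inline, not the repo's similarity_function):

import torch
import torch.nn.functional as F

batch_size, dim, temperature = 4, 8, 0.5
zis = F.normalize(torch.randn(batch_size, dim), dim=1)
zjs = F.normalize(torch.randn(batch_size, dim), dim=1)

representations = torch.cat([zjs, zis], dim=0)                  # (2N, dim)
sim = representations @ representations.t()                     # cosine similarity (inputs are unit norm)

l_pos = torch.diag(sim, batch_size)                              # positives for the first N rows
r_pos = torch.diag(sim, -batch_size)                             # positives for the last N rows
positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1)    # (2N, 1)

mask = (1 - torch.eye(2 * batch_size)
          - torch.diag(torch.ones(batch_size), batch_size)
          - torch.diag(torch.ones(batch_size), -batch_size)).bool()
negatives = sim[mask].view(2 * batch_size, -1)                   # (2N, 2N - 2)

logits = torch.cat((positives, negatives), dim=1) / temperature  # positive is always column 0
labels = torch.zeros(2 * batch_size).long()
loss = F.cross_entropy(logits, labels, reduction="sum") / (2 * batch_size)
print(logits.shape, loss.item())                                 # torch.Size([8, 7]) and a scalar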
simclr.py (23 lines changed)
@@ -5,6 +5,14 @@ import torch.nn.functional as F
 from loss.nt_xent import NTXentLoss
 import os
 import shutil
+import sys
+
+try:
+    sys.path.append('./apex')
+    from apex import amp
+except:
+    raise ("Please install apex for mixed precision training")
+
 import numpy as np

 torch.manual_seed(0)
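One caveat in the committed import guard: in Python 3, raise with a bare string is itself a TypeError ("exceptions must derive from BaseException"), so a missing apex install would surface as that TypeError rather than the intended message. A hedged alternative, not part of this commit, re-raises a proper exception:

# Sketch of a stricter import guard (not what the commit ships).
import sys

try:
    sys.path.append('./apex')   # same local-checkout fallback as in the diff
    from apex import amp
except ImportError as exc:
    raise ImportError("Please install apex for mixed precision training") from exc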
@@ -43,13 +51,6 @@ class SimCLR(object):
         zjs = F.normalize(zjs, dim=1)

         loss = self.nt_xent_criterion(zis, zjs)
-
-        if n_iter % self.config['log_every_n_steps'] == 0:
-            self.writer.add_histogram("xi_repr", ris, global_step=n_iter)
-            self.writer.add_histogram("xi_latent", zis, global_step=n_iter)
-            self.writer.add_histogram("xj_repr", rjs, global_step=n_iter)
-            self.writer.add_histogram("xj_latent", zjs, global_step=n_iter)
-
         return loss

     def train(self):
@@ -64,6 +65,10 @@ class SimCLR(object):
         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
                                                                last_epoch=-1)

+        model, optimizer = amp.initialize(model, optimizer,
+                                          opt_level=self.config['opt_level'],
+                                          keep_batchnorm_fp32=True)
+
         model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

         # save config file
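Apex expects amp.initialize to be called once, after the model and optimizer have been constructed and the model has already been moved to its CUDA device, and before any DataParallel/DistributedDataParallel wrapping. A minimal ordering sketch (the ResNet-18 backbone, learning rate, and plain 'O1' level here are assumptions, not values taken from this commit):

# Ordering sketch only; model and optimizer choices here are assumptions.
import torch
import torchvision
from apex import amp   # requires NVIDIA apex and a CUDA device

device = torch.device('cuda')
model = torchvision.models.resnet18(num_classes=256).to(device)   # on the GPU *before* amp.initialize
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=10e-6)

model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
# ...wrap in DataParallel (if used) and run the training loop after this point.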
@@ -85,7 +90,9 @@ class SimCLR(object):
                 if n_iter % self.config['log_every_n_steps'] == 0:
                     self.writer.add_scalar('train_loss', loss, global_step=n_iter)

-                loss.backward()
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+
                 optimizer.step()
                 n_iter += 1
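Inside the loop, amp.scale_loss multiplies the loss by the current loss scale so that FP16 gradients do not underflow during backward, and the gradients are unscaled on the optimizer before step(). A self-contained sketch of one such step (the tiny linear model and random batch are placeholders, not taken from this repo):

# Minimal end-to-end sketch of a training step with Apex AMP loss scaling.
import torch
from apex import amp   # requires NVIDIA apex and a CUDA device

device = torch.device('cuda')
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

x = torch.randn(8, 16, device=device)
target = torch.randint(0, 4, (8,), device=device)

optimizer.zero_grad()                     # clear gradients from the previous step
loss = torch.nn.functional.cross_entropy(model(x), target)

with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()                # backward on the scaled loss avoids FP16 underflow

optimizer.step()                          # Apex unscales the gradients before the update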
|
Loading…
x
Reference in New Issue
Block a user