From 68d57a13c70dfa21da70d03bf47035ebc5bf1eb6 Mon Sep 17 00:00:00 2001
From: Thalles
Date: Sun, 15 Mar 2020 21:55:00 -0300
Subject: [PATCH] added mixed precision training

---
 config.yaml     |  1 +
 loss/nt_xent.py | 40 +++++++++++++++++++++++-----------------
 simclr.py       | 23 +++++++++++++++--------
 3 files changed, 39 insertions(+), 25 deletions(-)

diff --git a/config.yaml b/config.yaml
index 5e60f58..c357897 100644
--- a/config.yaml
+++ b/config.yaml
@@ -4,6 +4,7 @@ eval_every_n_epochs: 1
 fine_tune_from: None
 log_every_n_steps: 50
 weight_decay: 10e-6
+opt_level: 'O0'
 
 model:
   out_dim: 256
diff --git a/loss/nt_xent.py b/loss/nt_xent.py
index 365dd24..ff2baff 100644
--- a/loss/nt_xent.py
+++ b/loss/nt_xent.py
@@ -10,9 +10,9 @@ class NTXentLoss(torch.nn.Module):
         self.temperature = temperature
         self.device = device
         self.softmax = torch.nn.Softmax(dim=-1)
-        self.mask_samples_from_same_repr = self._get_correlated_mask()
+        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
         self.similarity_function = self._get_similarity_function(use_cosine_similarity)
-        self.labels = self._get_labels()
+        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
 
     def _get_similarity_function(self, use_cosine_similarity):
         if use_cosine_similarity:
@@ -21,14 +21,13 @@ class NTXentLoss(torch.nn.Module):
         else:
             return self._dot_simililarity
 
-    def _get_labels(self):
-        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size - 1, k=-self.batch_size)
-        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size - 1, k=self.batch_size - 1)
-        labels = torch.from_numpy((l1 + l2).astype(np.int))
-        return labels.to(self.device)
-
     def _get_correlated_mask(self):
-        return (1 - torch.eye(2 * self.batch_size)).type(torch.bool)
+        diag = np.eye(2 * self.batch_size)
+        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
+        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
+        mask = torch.from_numpy((diag + l1 + l2))
+        mask = (1 - mask).type(torch.bool)
+        return mask.to(self.device)
 
     @staticmethod
     def _dot_simililarity(x, y):
@@ -46,14 +45,21 @@ class NTXentLoss(torch.nn.Module):
         return v
 
     def forward(self, zis, zjs):
-        negatives = torch.cat([zjs, zis], dim=0)
+        representations = torch.cat([zjs, zis], dim=0)
 
-        logits = self.similarity_function(negatives, negatives)
-        logits = logits[self.mask_samples_from_same_repr.type(torch.bool)].view(2 * self.batch_size, -1)
+        similarity_matrix = self.similarity_function(representations, representations)
+
+        # filter out the scores from the positive samples
+        l_pos = torch.diag(similarity_matrix, self.batch_size)
+        r_pos = torch.diag(similarity_matrix, -self.batch_size)
+        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
+
+        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
+
+        logits = torch.cat((positives, negatives), dim=1)
         logits /= self.temperature
 
-        # assert logits.shape == (2 * self.batch_size, 2 * self.batch_size - 1), "Shape of negatives not expected." + str(
-        #     logits.shape)
-        probs = self.softmax(logits)
-        loss = torch.mean(-torch.sum(self.labels * torch.log(probs), dim=-1))
-        return loss
+        labels = torch.zeros(2 * self.batch_size).to(self.device).long()
+        loss = self.criterion(logits, labels)
+
+        return loss / (2 * self.batch_size)
diff --git a/simclr.py b/simclr.py
index be66a32..b2a58dc 100644
--- a/simclr.py
+++ b/simclr.py
@@ -5,6 +5,14 @@ import torch.nn.functional as F
 from loss.nt_xent import NTXentLoss
 import os
 import shutil
+import sys
+
+try:
+    sys.path.append('./apex')
+    from apex import amp
+except:
+    raise ("Please install apex for mixed precision training")
+
 import numpy as np
 
 torch.manual_seed(0)
@@ -43,13 +51,6 @@ class SimCLR(object):
         zjs = F.normalize(zjs, dim=1)
 
         loss = self.nt_xent_criterion(zis, zjs)
-
-        if n_iter % self.config['log_every_n_steps'] == 0:
-            self.writer.add_histogram("xi_repr", ris, global_step=n_iter)
-            self.writer.add_histogram("xi_latent", zis, global_step=n_iter)
-            self.writer.add_histogram("xj_repr", rjs, global_step=n_iter)
-            self.writer.add_histogram("xj_latent", zjs, global_step=n_iter)
-
         return loss
 
     def train(self):
@@ -64,6 +65,10 @@ class SimCLR(object):
         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
                                                                last_epoch=-1)
 
+        model, optimizer = amp.initialize(model, optimizer,
+                                          opt_level=self.config['opt_level'],
+                                          keep_batchnorm_fp32=True)
+
         model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
 
         # save config file
@@ -85,7 +90,9 @@ class SimCLR(object):
                 if n_iter % self.config['log_every_n_steps'] == 0:
                     self.writer.add_scalar('train_loss', loss, global_step=n_iter)
 
-                loss.backward()
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+
                 optimizer.step()
                 n_iter += 1
 
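For reference, the reworked forward() in loss/nt_xent.py reduces to the self-contained sketch below. Names follow the patch; the plain dot product (valid once zis and zjs are L2-normalized) stands in for self.similarity_function, and the temperature value is only illustrative, not taken from config.yaml:

    import torch
    import torch.nn.functional as F

    def nt_xent(zis, zjs, temperature=0.5):
        # zis, zjs: [B, dim] projections of the two augmented views, assumed L2-normalized
        batch_size = zis.shape[0]
        representations = torch.cat([zjs, zis], dim=0)              # [2B, dim]
        similarity_matrix = representations @ representations.t()   # [2B, 2B]

        # positives sit on the +B and -B diagonals: row i is paired with row i +/- B
        l_pos = torch.diag(similarity_matrix, batch_size)
        r_pos = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1)

        # drop self-similarities and the positive pairs, keeping 2B - 2 negatives per row
        diag = torch.eye(2 * batch_size)
        l1 = torch.diag(torch.ones(batch_size), batch_size)
        l2 = torch.diag(torch.ones(batch_size), -batch_size)
        mask = (1 - (diag + l1 + l2)).bool()
        negatives = similarity_matrix[mask].view(2 * batch_size, -1)

        # the positive occupies column 0, so every row's target class is 0
        logits = torch.cat((positives, negatives), dim=1) / temperature
        labels = torch.zeros(2 * batch_size, dtype=torch.long)
        return F.cross_entropy(logits, labels, reduction="sum") / (2 * batch_size)

With B = batch_size, each row of logits holds the positive score in column 0 followed by the 2B - 2 negatives, which is why the CrossEntropyLoss targets in the patch are all zeros.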
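The training-loop change follows the standard apex amp pattern: wrap the model and optimizer once with amp.initialize(), then backpropagate through amp.scale_loss(). Below is a minimal stand-alone sketch under stated assumptions: NVIDIA apex (https://github.com/NVIDIA/apex) is installed, a CUDA device is available, the Linear model, Adam optimizer and random batches are toy stand-ins for the SimCLR encoder and loader, and opt_level='O2' with keep_batchnorm_fp32=True is an illustrative mixed-precision setting, whereas the patch reads opt_level from config.yaml (default 'O0', i.e. pure FP32):

    import torch
    from apex import amp  # toy example; the patch wraps the apex import in try/except

    # toy stand-ins for the SimCLR encoder, optimizer and data loader
    model = torch.nn.Linear(128, 64).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

    # wrap model and optimizer once, before the training loop
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level='O2',
                                      keep_batchnorm_fp32=True)

    for _ in range(10):
        optimizer.zero_grad()
        x = torch.randn(32, 128).cuda()
        loss = model(x).pow(2).mean()
        # backward on the scaled loss so FP16 gradients do not underflow
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()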