diff --git a/config.yaml b/config.yaml
index 785262d..5e60f58 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,13 +1,13 @@
-batch_size: 256
-epochs: 40
+batch_size: 512
+epochs: 80
 eval_every_n_epochs: 1
 fine_tune_from: None
 log_every_n_steps: 50
 weight_decay: 10e-6
 
 model:
-  out_dim: 128
-  base_model: "resnet50"
+  out_dim: 256
+  base_model: "resnet18"
 
 dataset:
   s: 1
diff --git a/simclr.py b/simclr.py
index 9a354bd..be66a32 100644
--- a/simclr.py
+++ b/simclr.py
@@ -49,7 +49,6 @@ class SimCLR(object):
             self.writer.add_histogram("xi_latent", zis, global_step=n_iter)
             self.writer.add_histogram("xj_repr", rjs, global_step=n_iter)
             self.writer.add_histogram("xj_latent", zjs, global_step=n_iter)
-            self.writer.add_scalar('train_loss', loss, global_step=n_iter)
 
         return loss
 
@@ -62,6 +61,9 @@ class SimCLR(object):
 
         optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=eval(self.config['weight_decay']))
 
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
+                                                               last_epoch=-1)
+
         model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
 
         # save config file
@@ -80,35 +82,28 @@ class SimCLR(object):
                 loss = self._step(model, xis, xjs, n_iter)
 
+                if n_iter % self.config['log_every_n_steps'] == 0:
+                    self.writer.add_scalar('train_loss', loss, global_step=n_iter)
+
                 loss.backward()
 
                 optimizer.step()
                 n_iter += 1
 
+            # validate the model if requested
             if epoch_counter % self.config['eval_every_n_epochs'] == 0:
+                valid_loss = self._validate(model, valid_loader)
+                if valid_loss < best_valid_loss:
+                    # save the model weights
+                    best_valid_loss = valid_loss
+                    torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
 
-                # validation steps
-                with torch.no_grad():
-                    model.eval()
+                self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
+                valid_n_iter += 1
 
-                    valid_loss = 0.0
-                    for counter, ((xis, xjs), _) in enumerate(valid_loader):
-                        xis = xis.to(self.device)
-                        xjs = xjs.to(self.device)
-
-                        loss = self._step(model, xis, xjs, n_iter)
-                        valid_loss += loss.item()
-
-                    valid_loss /= counter
-
-                    if valid_loss < best_valid_loss:
-                        # save the model weights
-                        best_valid_loss = valid_loss
-                        torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
-
-                    self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
-                    valid_n_iter += 1
-
-                model.train()
+            # warmup for the first 10 epochs
+            if epoch_counter >= 10:
+                scheduler.step()
+            self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter)
 
     def _load_pre_trained_weights(self, model):
         try:
@@ -120,3 +115,21 @@ class SimCLR(object):
             print("Pre-trained weights not found. Training from scratch.")
 
         return model
+
+    def _validate(self, model, valid_loader):
+
+        # validation steps
+        with torch.no_grad():
+            model.eval()
+
+            valid_loss = 0.0
+            for counter, ((xis, xjs), _) in enumerate(valid_loader):
+                xis = xis.to(self.device)
+                xjs = xjs.to(self.device)
+
+                loss = self._step(model, xis, xjs, counter)
+                valid_loss += loss.item()
+
+            valid_loss /= counter
+            model.train()
+        return valid_loss
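
Note (illustration only, not part of the change above): a minimal standalone sketch of the learning-rate schedule this diff sets up, where the cosine decay is held back for the first 10 epochs as a warmup and then stepped once per epoch. The toy model, the epoch count, and T_max=epochs are placeholder assumptions for the example; the diff itself passes T_max=len(train_loader) and reads the rate via scheduler.get_lr().

import torch

# Stand-in model with the same optimizer settings the diff configures.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=1e-5)

epochs = 80  # matches the new config.yaml value
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)

for epoch_counter in range(epochs):
    # ... one epoch of training would run here ...

    # warmup: keep the initial LR for the first 10 epochs, then decay
    if epoch_counter >= 10:
        scheduler.step()
    print(epoch_counter, optimizer.param_groups[0]['lr'])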