PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
 
 
Thalles b1e8f2c3d9 Merge remote-tracking branch 'origin/master' 2020-03-14 18:33:54 -03:00
data_aug new pythonic implementation 2020-03-13 22:56:04 -03:00
feature_eval Created using Colaboratory 2020-03-14 11:34:51 -03:00
loss minor improvements 2020-03-14 07:01:49 -03:00
models new pythonic implementation 2020-03-13 22:56:04 -03:00
README.md Update README.md 2020-03-14 11:35:26 -03:00
config.yaml resnet-50 config file 2020-03-14 18:33:46 -03:00
requirements.txt added requirements.txt file 2020-03-13 23:13:54 -03:00
run.py minor improvements 2020-03-14 07:01:49 -03:00
simclr.py new results 2020-03-14 11:32:59 -03:00


Blog post with full documentation: Exploring SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

For a TensorFlow 2.0 implementation: Tensorflow SimCLR

[Figure: SimCLR architecture]

Installation

$ conda create --name simclr python=3.7 --file requirements.txt
$ conda activate simclr
$ python run.py

Config file

Before running SimCLR, make sure the run configuration in config.yaml is set correctly.


# A batch size of N produces 2 * (N-1) negative samples per positive pair. Original implementation uses a batch size of 8192
batch_size: 512 

# Number of epochs to train
epochs: 40

# Frequency (in epochs) at which the similarity score is evaluated on the validation set
eval_every_n_epochs: 1

# Specify a folder containing a pre-trained model to fine-tune
fine_tune_from: 'Mar13_20-12-09_thallessilva'

# Frequency (in steps) at which TensorBoard is updated
log_every_n_steps: 50

# Model related parameters
model:
  # Output dimensionality of the embedding vector z. The SimCLR paper projects to 128 dimensions by default (2048 is the ResNet-50 feature size)
  out_dim: 256 
  
  # The ConvNet base model. Choose one of: "resnet18" or "resnet50". Original implementation uses resnet50
  base_model: "resnet18"

# Dataset related parameters
dataset:
  # Strength of the color distortion (color jitter) augmentation
  s: 1
  
  # dataset input shape. For datasets containing images of different size, this defines the final 
  input_shape: (96,96,3) 
  
  # Number of workers for the data loader
  num_workers: 0
  
  # Size of the validation set in percentage
  valid_size: 0.05

# NTXent loss related parameters
loss:
  # Temperature parameter for the contrastive objective
  temperature: 0.5 
  
  # Similarity metric for the contrastive loss. If False, uses the dot product. Original implementation uses cosine similarity.
  use_cosine_similarity: True
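
To make the loss parameters above concrete, here is a minimal sketch of an NT-Xent style loss. It is written in plain Python so it runs without dependencies, and it is not the code in this repository's loss module. The names nt_xent and similarity are hypothetical; only temperature and use_cosine_similarity come from config.yaml.

```python
import math


def _dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def similarity(u, v, use_cosine_similarity=True):
    """Cosine similarity, or a raw dot product when the flag is False."""
    d = _dot(u, v)
    if not use_cosine_similarity:
        return d
    return d / (math.sqrt(_dot(u, u)) * math.sqrt(_dot(v, v)))


def nt_xent(zis, zjs, temperature=0.5, use_cosine_similarity=True):
    """Mean NT-Xent loss over the 2N views of a batch of N positive pairs.

    zis[k] and zjs[k] are the projections of two augmentations of image k.
    """
    n = len(zis)
    z = list(zis) + list(zjs)  # 2N views; view i is paired with view (i + n) % 2n
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)
        # Temperature-scaled similarities to every other view (self excluded):
        # 1 positive and 2 * (n - 1) negatives.
        logits = {
            k: similarity(z[i], z[k], use_cosine_similarity) / temperature
            for k in range(2 * n)
            if k != i
        }
        log_denominator = math.log(sum(math.exp(v) for v in logits.values()))
        total += log_denominator - logits[pos]
    return total / (2 * n)


# Toy batch of N = 2 pairs (made-up vectors, for illustration only)
views_a = [[1.0, 0.0], [0.0, 1.0]]
views_b = [[1.0, 0.0], [0.0, 1.0]]
print(nt_xent(views_a, views_b, temperature=0.5))
```

A lower temperature sharpens the softmax over similarities, so hard negatives weigh more in the loss. With use_cosine_similarity set to False, the vector norms also affect the result.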

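Since a wrong setting only surfaces once training starts, a quick sanity check of the values can help. The sketch below is an assumption-laden illustration: check_config is a hypothetical helper, not part of this repository, and the dictionary stands in for the parsed config.yaml. Note that a YAML loader reads (96,96,3) as a string, so it has to be parsed into a tuple before use.

```python
import ast

# Hypothetical in-memory copy of some of the settings shown above
config = {
    "batch_size": 512,
    "model": {"out_dim": 256, "base_model": "resnet18"},
    "dataset": {"input_shape": "(96,96,3)", "valid_size": 0.05},
    "loss": {"temperature": 0.5, "use_cosine_similarity": True},
}


def check_config(cfg):
    """Return the parsed input shape, raising ValueError on an invalid setting."""
    if cfg["batch_size"] < 2:
        raise ValueError("batch_size must be at least 2 to have negative samples")
    if cfg["model"]["base_model"] not in ("resnet18", "resnet50"):
        raise ValueError("base_model must be 'resnet18' or 'resnet50'")
    if not 0.0 < cfg["dataset"]["valid_size"] < 1.0:
        raise ValueError("valid_size must be a fraction between 0 and 1")
    if cfg["loss"]["temperature"] <= 0:
        raise ValueError("temperature must be positive")
    # Turn the string "(96,96,3)" into the tuple (96, 96, 3)
    shape = ast.literal_eval(cfg["dataset"]["input_shape"])
    if len(shape) != 3:
        raise ValueError("input_shape must have three entries")
    return shape


shape = check_config(config)
negatives = 2 * (config["batch_size"] - 1)  # negatives per positive pair
print(shape, negatives)
```
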
Feature Evaluation

Features are evaluated using a linear model protocol: the encoder is frozen and a simple classifier is trained on the features it produces.

Features are learned on the STL10 train+unsupervised set and evaluated on the test set.

See the Colab notebook to reproduce these results.

| Linear Classifier | Feature Extractor | Architecture | Feature dimensionality | Projection Head dimensionality | Epochs | STL10 Top 1 |
|---|---|---|---|---|---|---|
| Logistic Regression | PCA Features | - | 256 | - | | 36.0% |
| KNN | PCA Features | - | 256 | - | | 31.8% |
| Logistic Regression (LBFGS) | SimCLR | ResNet-18 | 512 | 256 | 40 | 70.3% |
| KNN | SimCLR | ResNet-18 | 512 | 256 | 40 | 66.2% |
| Logistic Regression (LBFGS) | SimCLR | ResNet-18 | 512 | 256 | 80 | 72.9% |
| KNN | SimCLR | ResNet-18 | 512 | 256 | 80 | 69.8% |
| Logistic Regression (LBFGS) | SimCLR | ResNet-50 | 2048 | - | 40 | - |
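
The KNN rows in the results above come from classifying frozen features by their nearest neighbours. As a rough sketch of that idea, assuming Euclidean distance and majority voting (the notebook may differ in both), a dependency-free version could look like this. The function names and the toy feature vectors are hypothetical.

```python
from collections import Counter


def squared_distance(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))


def knn_predict(train_features, train_labels, query, k=3):
    """Label a query feature vector by majority vote among its k nearest neighbours."""
    nearest = sorted(
        range(len(train_features)),
        key=lambda i: squared_distance(train_features[i], query),
    )[:k]
    votes = Counter(train_labels[i] for i in nearest)
    return votes.most_common(1)[0][0]


def top1_accuracy(train_features, train_labels, test_features, test_labels, k=3):
    """Fraction of test vectors whose predicted label matches the true label."""
    correct = sum(
        knn_predict(train_features, train_labels, x, k) == y
        for x, y in zip(test_features, test_labels)
    )
    return correct / len(test_labels)


# Made-up 2-D "features" standing in for encoder outputs
train_x = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]]
train_y = [0, 0, 1, 1]
test_x = [[0.2, 0.1], [5.1, 5.2], [4.9, 5.9]]
test_y = [0, 1, 0]
print(top1_accuracy(train_x, train_y, test_x, test_y, k=3))
```

In the real protocol the feature vectors would be the 512-dimensional (ResNet-18) or 2048-dimensional (ResNet-50) encoder outputs, computed once with the encoder weights frozen.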