diff --git a/README.md b/README.md index a2fc8fe..a9c142e 100644 --- a/README.md +++ b/README.md @@ -1 +1,2 @@ -# SimCLR \ No newline at end of file +# PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations + diff --git a/config.yaml b/config.yaml index 69e365e..11b9ca4 100644 --- a/config.yaml +++ b/config.yaml @@ -2,5 +2,6 @@ batch_size: 256 out_dim: 64 s: 1 temperature: 0.5 +base_convnet: "resnet18" # one of: "resnet18 or resnet50" use_cosine_similarity: True epochs: 30 \ No newline at end of file diff --git a/model.py b/model.py index 3cdfcc8..2a6cd7a 100644 --- a/model.py +++ b/model.py @@ -43,11 +43,14 @@ class Encoder(nn.Module): return h, x -class ResNet18(nn.Module): +class ResNetSimCLR(nn.Module): - def __init__(self, out_dim=64): - super(ResNet18, self).__init__() - resnet = models.resnet18(pretrained=False) + def __init__(self, base_model="resnet18", out_dim=64): + super(ResNetSimCLR, self).__init__() + self.resnet_dict = {"resnet18": models.resnet18(pretrained=False), + "resnet50": models.resnet50(pretrained=False)} + + resnet = self._get_basemodel(base_model) num_ftrs = resnet.fc.in_features self.features = nn.Sequential(*list(resnet.children())[:-1]) @@ -56,6 +59,12 @@ class ResNet18(nn.Module): self.l1 = nn.Linear(num_ftrs, num_ftrs) self.l2 = nn.Linear(num_ftrs, out_dim) + def _get_basemodel(self, model_name): + try: + return self.resnet_dict[model_name] + except: + raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50") + def forward(self, x): h = self.features(x) h = h.squeeze() diff --git a/train.py b/train.py index 8548522..5cd3372 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,7 @@ from torchvision import datasets from torch.utils.tensorboard import SummaryWriter import torch.nn.functional as F -from model import ResNet18 +from model import ResNetSimCLR from utils import GaussianBlur torch.manual_seed(0) @@ -36,7 +36,7 @@ train_dataset = datasets.STL10('data', split='train', download=True, transform=t train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, drop_last=True, shuffle=True) # model = Encoder(out_dim=out_dim) -model = ResNet18(out_dim=out_dim) +model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim) print(model) train_gpu = torch.cuda.is_available()