complete the loss function description from the paper

pull/5/head
Thalles Silva 2020-02-24 15:36:10 -03:00
parent 186fb0159d
commit 304376b6a8
4 changed files with 18 additions and 7 deletions

View File

@ -1 +1,2 @@
# SimCLR
# PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

View File

@ -2,5 +2,6 @@ batch_size: 256
out_dim: 64
s: 1
temperature: 0.5
base_convnet: "resnet18" # one of: "resnet18" or "resnet50"
use_cosine_similarity: True
epochs: 30

View File

@ -43,11 +43,14 @@ class Encoder(nn.Module):
return h, x
class ResNet18(nn.Module):
class ResNetSimCLR(nn.Module):
def __init__(self, out_dim=64):
super(ResNet18, self).__init__()
resnet = models.resnet18(pretrained=False)
def __init__(self, base_model="resnet18", out_dim=64):
super(ResNetSimCLR, self).__init__()
self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
"resnet50": models.resnet50(pretrained=False)}
resnet = self._get_basemodel(base_model)
num_ftrs = resnet.fc.in_features
self.features = nn.Sequential(*list(resnet.children())[:-1])
@ -56,6 +59,12 @@ class ResNet18(nn.Module):
self.l1 = nn.Linear(num_ftrs, num_ftrs)
self.l2 = nn.Linear(num_ftrs, out_dim)
def _get_basemodel(self, model_name):
    """Return the backbone registered under ``model_name``.

    Args:
        model_name: Key to look up in ``self.resnet_dict``.

    Returns:
        The object stored in ``self.resnet_dict`` for ``model_name``.

    Raises:
        ValueError: If ``model_name`` is not a key of ``self.resnet_dict``.
    """
    try:
        return self.resnet_dict[model_name]
    except KeyError as err:
        # Raise a real exception type: raising a bare string is itself a
        # TypeError in Python 3, which would hide this message from the user.
        raise ValueError(
            "Invalid model name. Check the config file and pass one of: resnet18 or resnet50"
        ) from err
def forward(self, x):
h = self.features(x)
h = h.squeeze()

View File

@ -9,7 +9,7 @@ from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from model import ResNet18
from model import ResNetSimCLR
from utils import GaussianBlur
torch.manual_seed(0)
@ -36,7 +36,7 @@ train_dataset = datasets.STL10('data', split='train', download=True, transform=t
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, drop_last=True, shuffle=True)
# model = Encoder(out_dim=out_dim)
model = ResNet18(out_dim=out_dim)
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
print(model)
train_gpu = torch.cuda.is_available()