mirror of https://github.com/sthalles/SimCLR.git
complete the loss function description from the paper
parent
186fb0159d
commit
304376b6a8
|
@ -1 +1,2 @@
|
||||||
# SimCLR
|
# PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
|
||||||
|
|
||||||
|
|
|
@ -2,5 +2,6 @@ batch_size: 256
|
||||||
out_dim: 64
|
out_dim: 64
|
||||||
s: 1
|
s: 1
|
||||||
temperature: 0.5
|
temperature: 0.5
|
||||||
|
base_convnet: "resnet18" # one of: "resnet18" or "resnet50"
|
||||||
use_cosine_similarity: True
|
use_cosine_similarity: True
|
||||||
epochs: 30
|
epochs: 30
|
17
model.py
17
model.py
|
@ -43,11 +43,14 @@ class Encoder(nn.Module):
|
||||||
return h, x
|
return h, x
|
||||||
|
|
||||||
|
|
||||||
class ResNet18(nn.Module):
|
class ResNetSimCLR(nn.Module):
|
||||||
|
|
||||||
def __init__(self, out_dim=64):
|
def __init__(self, base_model="resnet18", out_dim=64):
|
||||||
super(ResNet18, self).__init__()
|
super(ResNetSimCLR, self).__init__()
|
||||||
resnet = models.resnet18(pretrained=False)
|
self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
|
||||||
|
"resnet50": models.resnet50(pretrained=False)}
|
||||||
|
|
||||||
|
resnet = self._get_basemodel(base_model)
|
||||||
num_ftrs = resnet.fc.in_features
|
num_ftrs = resnet.fc.in_features
|
||||||
|
|
||||||
self.features = nn.Sequential(*list(resnet.children())[:-1])
|
self.features = nn.Sequential(*list(resnet.children())[:-1])
|
||||||
|
@ -56,6 +59,12 @@ class ResNet18(nn.Module):
|
||||||
self.l1 = nn.Linear(num_ftrs, num_ftrs)
|
self.l1 = nn.Linear(num_ftrs, num_ftrs)
|
||||||
self.l2 = nn.Linear(num_ftrs, out_dim)
|
self.l2 = nn.Linear(num_ftrs, out_dim)
|
||||||
|
|
||||||
|
def _get_basemodel(self, model_name):
|
||||||
|
try:
|
||||||
|
return self.resnet_dict[model_name]
|
||||||
|
except:
|
||||||
|
raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h = self.features(x)
|
h = self.features(x)
|
||||||
h = h.squeeze()
|
h = h.squeeze()
|
||||||
|
|
4
train.py
4
train.py
|
@ -9,7 +9,7 @@ from torchvision import datasets
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from model import ResNet18
|
from model import ResNetSimCLR
|
||||||
from utils import GaussianBlur
|
from utils import GaussianBlur
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
@ -36,7 +36,7 @@ train_dataset = datasets.STL10('data', split='train', download=True, transform=t
|
||||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, drop_last=True, shuffle=True)
|
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, drop_last=True, shuffle=True)
|
||||||
|
|
||||||
# model = Encoder(out_dim=out_dim)
|
# model = Encoder(out_dim=out_dim)
|
||||||
model = ResNet18(out_dim=out_dim)
|
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
train_gpu = torch.cuda.is_available()
|
train_gpu = torch.cuda.is_available()
|
||||||
|
|
Loading…
Reference in New Issue