SimCLR/models/resnet_simclr.py

37 lines
1.0 KiB
Python
Raw Normal View History

2020-02-18 03:05:44 +08:00
import torch
import torch.nn as nn
import torch.nn.functional as F
2020-02-20 06:12:07 +08:00
import torchvision.models as models
2020-02-18 03:05:44 +08:00
class ResNetSimCLR(nn.Module):
    """ResNet backbone followed by a two-layer projection MLP.

    ``forward`` returns both the backbone representation ``h`` and its
    projection ``l2(relu(l1(h)))``.

    Args:
        base_model: Name of the torchvision backbone; one of the keys of
            ``self.resnet_dict`` ("resnet18" or "resnet50"). The backbone is
            built with ``pretrained=False`` (random initialisation).
        out_dim: Output dimension of the projection head.

    Raises:
        ValueError: If ``base_model`` is not a supported backbone name.
    """

    def __init__(self, base_model="resnet18", out_dim=64):
        super().__init__()
        # Names map to constructors, not instances, so only the requested
        # backbone is ever built. Module-level functions (not lambdas) keep
        # the model picklable with torch.save.
        self.resnet_dict = {"resnet18": models.resnet18,
                            "resnet50": models.resnet50}
        resnet = self._get_basemodel(base_model)

        num_ftrs = resnet.fc.in_features
        # Every child of the backbone except the final classifier ``fc``.
        self.features = nn.Sequential(*list(resnet.children())[:-1])

        # projection MLP
        self.l1 = nn.Linear(num_ftrs, num_ftrs)
        self.l2 = nn.Linear(num_ftrs, out_dim)

    def _get_basemodel(self, model_name):
        """Build the backbone registered under ``model_name``.

        Args:
            model_name: Key into ``self.resnet_dict``.

        Returns:
            A freshly constructed backbone (``pretrained=False``).

        Raises:
            ValueError: If ``model_name`` is not a key of ``self.resnet_dict``.
        """
        try:
            factory = self.resnet_dict[model_name]
        except KeyError:
            choices = " or ".join(self.resnet_dict)
            # A string cannot be raised; wrap the message in a real exception.
            raise ValueError(
                f"Invalid model name {model_name!r}. Check the config file "
                f"and pass one of: {choices}"
            ) from None
        return factory(pretrained=False)

    def forward(self, x):
        """Encode ``x`` and project the representation.

        Args:
            x: Input batch; presumably images shaped (B, C, H, W) as expected
                by torchvision ResNets -- TODO confirm against callers.

        Returns:
            Tuple ``(h, z)``: ``h`` is the flattened backbone output of shape
            (B, num_ftrs) and ``z`` the projection of shape (B, out_dim).
        """
        h = self.features(x)
        # Flatten from dim 1 instead of squeeze(): squeeze() would also drop
        # the batch dimension when the batch size is 1.
        h = torch.flatten(h, 1)
        x = self.l1(h)
        x = F.relu(x)
        x = self.l2(x)
        return h, x