mirror of https://github.com/sthalles/SimCLR.git
31 lines
1.0 KiB
Python
31 lines
1.0 KiB
Python
import torch.nn as nn
|
|
import torchvision.models as models
|
|
|
|
from exceptions.exceptions import InvalidBackboneError
|
|
|
|
|
|
class ResNetSimCLR(nn.Module):
|
|
|
|
def __init__(self, base_model, out_dim):
|
|
super(ResNetSimCLR, self).__init__()
|
|
self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
|
|
"resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}
|
|
|
|
self.backbone = self._get_basemodel(base_model)
|
|
dim_mlp = self.backbone.fc.in_features
|
|
|
|
# add mlp projection head
|
|
self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)
|
|
|
|
def _get_basemodel(self, model_name):
|
|
try:
|
|
model = self.resnet_dict[model_name]
|
|
except KeyError:
|
|
raise InvalidBackboneError(
|
|
"Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
|
|
else:
|
|
return model
|
|
|
|
def forward(self, x):
|
|
return self.backbone(x)
|