SimCLR/models/resnet_simclr.py

31 lines
1.0 KiB
Python
Raw Normal View History

2020-02-17 16:05:44 -03:00
import torch.nn as nn
2020-02-19 19:12:07 -03:00
import torchvision.models as models
2020-02-17 16:05:44 -03:00
2021-01-17 14:12:17 -03:00
from exceptions.exceptions import InvalidBackboneError
2020-02-17 16:05:44 -03:00
class ResNetSimCLR(nn.Module):
    """Torchvision ResNet encoder with a SimCLR-style MLP projection head.

    The backbone's final ``fc`` layer is replaced by
    ``Linear(dim_mlp, dim_mlp) -> ReLU -> fc``, where ``dim_mlp`` is the
    number of input features of the original ``fc`` and that original ``fc``
    (built with ``num_classes=out_dim``) remains the last layer.
    """

    def __init__(self, base_model, out_dim):
        """Build the requested backbone and attach the projection head.

        Args:
            base_model: Backbone name; one of "resnet18" or "resnet50".
            out_dim: Width of the final projection layer (passed to the
                torchvision constructor as ``num_classes``).

        Raises:
            InvalidBackboneError: If ``base_model`` is not supported.
        """
        from functools import partial

        super(ResNetSimCLR, self).__init__()

        # Map names to constructors, not to built models, so only the
        # requested backbone is instantiated. Building every ResNet up front
        # wastes time and memory and keeps unused networks referenced from
        # this dict. ``partial`` (unlike a lambda) stays picklable, so saving
        # the whole module with torch.save still works.
        # ``pretrained``/``weights`` is omitted: torchvision's default is
        # random initialisation, and passing ``pretrained=False`` triggers a
        # deprecation warning on torchvision >= 0.13.
        self.resnet_dict = {"resnet18": partial(models.resnet18, num_classes=out_dim),
                            "resnet50": partial(models.resnet50, num_classes=out_dim)}

        self.backbone = self._get_basemodel(base_model)
        dim_mlp = self.backbone.fc.in_features

        # add mlp projection head: a hidden Linear + ReLU in front of the
        # existing fc, so the head still ends in the out_dim-wide layer.
        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

    def _get_basemodel(self, model_name):
        """Instantiate and return the backbone registered under ``model_name``.

        Raises:
            InvalidBackboneError: If ``model_name`` is not a key of
                ``self.resnet_dict``.
        """
        try:
            model_fn = self.resnet_dict[model_name]
        except KeyError:
            # ``from None`` hides the internal KeyError; the message below
            # already explains what went wrong.
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50") from None
        return model_fn()

    def forward(self, x):
        """Run ``x`` through the backbone and projection head.

        ``x`` is passed unchanged to the torchvision ResNet. It is presumably
        an image batch of shape (N, C, H, W); confirm this against the data
        pipeline.
        """
        return self.backbone(x)