2020-02-18 03:05:44 +08:00
|
|
|
import torch.nn as nn
|
2020-02-20 06:12:07 +08:00
|
|
|
import torchvision.models as models
|
2020-02-18 03:05:44 +08:00
|
|
|
|
2021-01-18 01:12:17 +08:00
|
|
|
from exceptions.exceptions import InvalidBackboneError
|
|
|
|
|
2020-02-18 03:05:44 +08:00
|
|
|
|
2020-02-25 02:36:10 +08:00
|
|
|
class ResNetSimCLR(nn.Module):
    """ResNet backbone with an MLP projection head, as used for SimCLR.

    A torchvision ResNet is built with ``num_classes=out_dim`` and its final
    ``fc`` layer is wrapped into ``Linear(d, d) -> ReLU -> fc``, where ``d`` is
    the input width of the original ``fc``. ``forward`` returns the output of
    that head.

    Args:
        base_model: Backbone name; one of ``"resnet18"`` or ``"resnet50"``.
        out_dim: Output width of the projection head (passed to torchvision as
            ``num_classes``, so it sizes the final ``fc`` layer).

    Raises:
        InvalidBackboneError: If ``base_model`` is not a supported name.
    """

    def __init__(self, base_model, out_dim):
        super().__init__()

        # Build only the requested architecture. Constructing every supported
        # ResNet up front would allocate (and, if kept in a dict, retain)
        # networks that are never used.
        self.backbone = self._get_basemodel(base_model, out_dim)
        dim_mlp = self.backbone.fc.in_features

        # add mlp projection head: a hidden Linear + ReLU placed in front of
        # the existing fc (dim_mlp -> out_dim), which stays as the last layer.
        self.backbone.fc = nn.Sequential(
            nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc
        )

    def _get_basemodel(self, model_name, out_dim):
        """Construct the torchvision ResNet named *model_name*.

        Args:
            model_name: ``"resnet18"`` or ``"resnet50"``.
            out_dim: Forwarded to the constructor as ``num_classes``.

        Returns:
            A freshly constructed model without pretrained weights.

        Raises:
            InvalidBackboneError: If *model_name* is not supported.
        """
        # Constructors (not instances), so that the lookup itself allocates
        # nothing and only the selected network is built below.
        factories = {
            "resnet18": models.resnet18,
            "resnet50": models.resnet50,
        }
        try:
            factory = factories[model_name]
        except KeyError:
            # ``from None``: the underlying KeyError adds nothing to this
            # message and would otherwise show up as a confusing chained trace.
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50"
            ) from None
        # No ``pretrained``/``weights`` argument: torchvision's default is "no
        # pretrained weights" in all versions, and omitting it avoids the
        # ``pretrained=`` deprecation warning introduced in torchvision 0.13.
        return factory(num_classes=out_dim)

    def forward(self, x):
        """Run *x* through the backbone and its projection head.

        Returns the projection-head output, whose last layer has ``out_dim``
        output features. The accepted input layout is whatever torchvision
        ResNets accept — presumably an ``(N, 3, H, W)`` image batch; confirm
        against the data pipeline.
        """
        return self.backbone(x)
|