SimCLR/models/resnet_simclr.py

31 lines
1.0 KiB
Python
Raw Normal View History

2020-02-18 03:05:44 +08:00
import torch.nn as nn
2020-02-20 06:12:07 +08:00
import torchvision.models as models
2020-02-18 03:05:44 +08:00
2021-01-18 01:12:17 +08:00
from exceptions.exceptions import InvalidBackboneError
2020-02-18 03:05:44 +08:00
class ResNetSimCLR(nn.Module):
    """ResNet encoder with a SimCLR-style MLP projection head.

    The torchvision ResNet selected by ``base_model`` is built with
    ``num_classes=out_dim``. Its final ``fc`` layer is then wrapped as
    ``Linear(dim_mlp, dim_mlp) -> ReLU -> original fc``, where ``dim_mlp`` is
    the input width of the backbone's original ``fc``. As a result, ``forward``
    yields ``out_dim`` features per input.

    Args:
        base_model: Backbone name, one of ``"resnet18"`` or ``"resnet50"``.
        out_dim: Output dimension of the projection head.

    Raises:
        InvalidBackboneError: If ``base_model`` is not a supported name.
    """

    def __init__(self, base_model: str, out_dim: int) -> None:
        # Function-scope import keeps this class self-contained. ``partial``
        # (unlike a lambda) is picklable, so ``torch.save(model)`` still works.
        import functools

        super().__init__()

        # Map each supported name to a factory rather than a built network.
        # Building every backbone eagerly wastes time, and this plain dict is
        # not a registered submodule, so unused networks would linger here.
        # ``pretrained`` is omitted: random initialisation is the default, and
        # the keyword is deprecated in newer torchvision releases.
        self.resnet_dict = {
            "resnet18": functools.partial(models.resnet18, num_classes=out_dim),
            "resnet50": functools.partial(models.resnet50, num_classes=out_dim),
        }

        self.backbone = self._get_basemodel(base_model)
        dim_mlp = self.backbone.fc.in_features

        # add mlp projection head: the original fc (dim_mlp -> out_dim) becomes
        # the last layer, preceded by a hidden Linear + ReLU.
        self.backbone.fc = nn.Sequential(
            nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc
        )

    def _get_basemodel(self, model_name: str) -> nn.Module:
        """Build and return a fresh backbone registered under ``model_name``.

        Raises:
            InvalidBackboneError: If ``model_name`` is not a key of
                ``self.resnet_dict``.
        """
        try:
            factory = self.resnet_dict[model_name]
        except KeyError:
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50") from None
        # Build outside the ``try``: a KeyError raised during construction must
        # not be misreported as an invalid backbone name.
        return factory()

    def forward(self, x):
        """Run ``x`` through the backbone and projection head.

        NOTE(review): ``x`` is presumably an (N, C, H, W) image batch as
        expected by torchvision ResNets; confirm against callers.
        """
        return self.backbone(x)