diff --git a/data_aug/contrastive_learning_dataset.py b/data_aug/contrastive_learning_dataset.py
index 7875fcb..e1777e4 100644
--- a/data_aug/contrastive_learning_dataset.py
+++ b/data_aug/contrastive_learning_dataset.py
@@ -2,6 +2,7 @@ from torchvision.transforms import transforms
from data_aug.gaussian_blur import GaussianBlur
from torchvision import transforms, datasets
from data_aug.view_generator import ContrastiveLearningViewGenerator
+from exceptions.exceptions import InvalidDatasetSelection
class ContrastiveLearningDataset:
@@ -33,5 +34,9 @@ class ContrastiveLearningDataset:
n_views),
download=True)}
- dataset = valid_datasets.get(name, 'Invalid dataset option.')()
- return dataset
+ try:
+ dataset_fn = valid_datasets[name]
+ except KeyError:
+ raise InvalidDatasetSelection()
+ else:
+ return dataset_fn()
diff --git a/exceptions/exceptions.py b/exceptions/exceptions.py
index ae45945..a737084 100644
--- a/exceptions/exceptions.py
+++ b/exceptions/exceptions.py
@@ -3,4 +3,8 @@ class BaseSimCLRException(Exception):
class InvalidBackboneError(BaseSimCLRException):
- """Raised when the choice of backbone Convnet is invalid."""
\ No newline at end of file
+ """Raised when the choice of backbone Convnet is invalid."""
+
+
+class InvalidDatasetSelection(BaseSimCLRException):
+ """Raised when the choice of dataset is invalid."""
diff --git a/feature_eval/mini_batch_logistic_regression_evaluator.ipynb b/feature_eval/mini_batch_logistic_regression_evaluator.ipynb
index 481f6fa..b3c1f1c 100644
--- a/feature_eval/mini_batch_logistic_regression_evaluator.ipynb
+++ b/feature_eval/mini_batch_logistic_regression_evaluator.ipynb
@@ -1,444 +1,442 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "kernelspec": {
- "display_name": "pytorch",
- "language": "python",
- "name": "pytorch"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.6"
- },
- "colab": {
- "name": "mini-batch-logistic-regression-evaluator.ipynb",
- "provenance": [],
- "include_colab_link": true
- },
- "accelerator": "GPU"
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "view-in-github"
+ },
+ "source": [
+ "
"
+ ]
},
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "view-in-github",
- "colab_type": "text"
- },
- "source": [
- "
"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "YUemQib7ZE4D",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "import torch\n",
- "import sys\n",
- "import numpy as np\n",
- "import os\n",
- "from sklearn.neighbors import KNeighborsClassifier\n",
- "import yaml\n",
- "import matplotlib.pyplot as plt\n",
- "from sklearn.decomposition import PCA\n",
- "from sklearn.linear_model import LogisticRegression\n",
- "from sklearn import preprocessing\n",
- "import importlib.util"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "WSgRE1CcLqdS",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "!pip install gdown"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "NOIJEui1ZziV",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "def get_file_id_by_model(folder_name):\n",
- " file_id = {'resnet-18_40-epochs': '1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk',\n",
- " 'resnet-18_80-epochs': '1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT',\n",
- " 'resnet-50_40-epochs': '1TZqBNTFCsO-mxAiR-zJeyupY-J2gA27Q',\n",
- " 'resnet-50_80-epochs': '1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35',\n",
- " 'resnet-18_100-epochs':'1aZ12TITXnajZ6QWmS_SDm8Sp8gXNbeCQ'}\n",
- " return file_id.get(folder_name, \"Model not found.\")"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "G7YMxsvEZMrX",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "folder_name = 'resnet-50_40-epochs'\n",
- "file_id = get_file_id_by_model(folder_name)\n",
- "print(folder_name, file_id)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "PWZ8fet_YoJm",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "# download and extract model files\n",
- "os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n",
- "os.system('unzip {}'.format(folder_name))\n",
- "!ls"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "3_nypQVEv-hn",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "from torch.utils.data import DataLoader\n",
- "import torchvision.transforms as transforms\n",
- "from torchvision import datasets"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "lDfbL3w_Z0Od",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
- "print(\"Using device:\", device)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "IQMIryc6LjQd",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
- "config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n",
- "config"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "BfIPl0G6_RrT",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n",
- " train_dataset = datasets.STL10('./data', split='train', download=download,\n",
- " transform=transforms.ToTensor())\n",
- "\n",
- " train_loader = DataLoader(train_dataset, batch_size=batch_size,\n",
- " num_workers=0, drop_last=False, shuffle=shuffle)\n",
- " \n",
- " test_dataset = datasets.STL10('./data', split='test', download=download,\n",
- " transform=transforms.ToTensor())\n",
- "\n",
- " test_loader = DataLoader(test_dataset, batch_size=batch_size,\n",
- " num_workers=0, drop_last=False, shuffle=shuffle)\n",
- " return train_loader, test_loader"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "a18lPD-tIle6",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "def _load_resnet_model(checkpoints_folder):\n",
- " # Load the neural net module\n",
- " spec = importlib.util.spec_from_file_location(\"model\", os.path.join(checkpoints_folder, 'resnet_simclr.py'))\n",
- " resnet_module = importlib.util.module_from_spec(spec)\n",
- " spec.loader.exec_module(resnet_module)\n",
- "\n",
- " model = resnet_module.ResNetSimCLR(**config['model'])\n",
- " model.eval()\n",
- "\n",
- " state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=torch.device('cpu'))\n",
- " model.load_state_dict(state_dict)\n",
- " model = model.to(device)\n",
- " return model"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "5nf4rDtWLjRE",
- "colab_type": "text"
- },
- "source": [
- "## Protocol #2 Logisitc Regression"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "7jjSxmDnHNQz",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "class ResNetFeatureExtractor(object):\n",
- " def __init__(self, checkpoints_folder):\n",
- " self.checkpoints_folder = checkpoints_folder\n",
- " self.model = _load_resnet_model(checkpoints_folder)\n",
- "\n",
- " def _inference(self, loader):\n",
- " feature_vector = []\n",
- " labels_vector = []\n",
- " for batch_x, batch_y in loader:\n",
- "\n",
- " batch_x = batch_x.to(device)\n",
- " labels_vector.extend(batch_y)\n",
- "\n",
- " features, _ = self.model(batch_x)\n",
- " feature_vector.extend(features.cpu().detach().numpy())\n",
- "\n",
- " feature_vector = np.array(feature_vector)\n",
- " labels_vector = np.array(labels_vector)\n",
- "\n",
- " print(\"Features shape {}\".format(feature_vector.shape))\n",
- " return feature_vector, labels_vector\n",
- "\n",
- " def get_resnet_features(self):\n",
- " train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
- " X_train_feature, y_train = self._inference(train_loader)\n",
- " X_test_feature, y_test = self._inference(test_loader)\n",
- "\n",
- " return X_train_feature, y_train, X_test_feature, y_test"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "kghx1govJq5_",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "resnet_feature_extractor = ResNetFeatureExtractor(checkpoints_folder)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "S_JcznxVJ1Xj",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "X_train_feature, y_train, X_test_feature, y_test = resnet_feature_extractor.get_resnet_features()"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "oftbHXcdLjRM",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "import torch.nn as nn\n",
- "\n",
- "class LogisticRegression(nn.Module):\n",
- " \n",
- " def __init__(self, n_features, n_classes):\n",
- " super(LogisticRegression, self).__init__()\n",
- " self.model = nn.Linear(n_features, n_classes)\n",
- "\n",
- " def forward(self, x):\n",
- " return self.model(x)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "Ks73ePLtNWeV",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "class LogiticRegressionEvaluator(object):\n",
- " def __init__(self, n_features, n_classes):\n",
- " self.log_regression = LogisticRegression(n_features, n_classes).to(device)\n",
- " self.scaler = preprocessing.StandardScaler()\n",
- "\n",
- " def _normalize_dataset(self, X_train, X_test):\n",
- " print(\"Standard Scaling Normalizer\")\n",
- " self.scaler.fit(X_train)\n",
- " X_train = self.scaler.transform(X_train)\n",
- " X_test = self.scaler.transform(X_test)\n",
- " return X_train, X_test\n",
- "\n",
- " @staticmethod\n",
- " def _sample_weight_decay():\n",
- " # We selected the l2 regularization parameter from a range of 45 logarithmically spaced values between 10−6 and 105\n",
- " weight_decay = np.logspace(-6, 5, num=45, base=10.0)\n",
- " weight_decay = np.random.choice(weight_decay)\n",
- " print(\"Sampled weight decay:\", weight_decay)\n",
- " return weight_decay\n",
- "\n",
- " def eval(self, test_loader):\n",
- " correct = 0\n",
- " total = 0\n",
- "\n",
- " with torch.no_grad():\n",
- " self.log_regression.eval()\n",
- " for batch_x, batch_y in test_loader:\n",
- " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
- " logits = self.log_regression(batch_x)\n",
- "\n",
- " predicted = torch.argmax(logits, dim=1)\n",
- " total += batch_y.size(0)\n",
- " correct += (predicted == batch_y).sum().item()\n",
- "\n",
- " final_acc = 100 * correct / total\n",
- " self.log_regression.train()\n",
- " return final_acc\n",
- "\n",
- "\n",
- " def create_data_loaders_from_arrays(self, X_train, y_train, X_test, y_test):\n",
- " X_train, X_test = self._normalize_dataset(X_train, X_test)\n",
- "\n",
- " train = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train).type(torch.long))\n",
- " train_loader = torch.utils.data.DataLoader(train, batch_size=396, shuffle=False)\n",
- "\n",
- " test = torch.utils.data.TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test).type(torch.long))\n",
- " test_loader = torch.utils.data.DataLoader(test, batch_size=512, shuffle=False)\n",
- " return train_loader, test_loader\n",
- "\n",
- " def train(self, X_train, y_train, X_test, y_test):\n",
- " \n",
- " train_loader, test_loader = self.create_data_loaders_from_arrays(X_train, y_train, X_test, y_test)\n",
- "\n",
- " weight_decay = self._sample_weight_decay()\n",
- "\n",
- " optimizer = torch.optim.Adam(self.log_regression.parameters(), 3e-4, weight_decay=weight_decay)\n",
- " criterion = torch.nn.CrossEntropyLoss()\n",
- "\n",
- " best_accuracy = 0\n",
- "\n",
- " for e in range(200):\n",
- " \n",
- " for batch_x, batch_y in train_loader:\n",
- "\n",
- " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
- "\n",
- " optimizer.zero_grad()\n",
- "\n",
- " logits = self.log_regression(batch_x)\n",
- "\n",
- " loss = criterion(logits, batch_y)\n",
- "\n",
- " loss.backward()\n",
- " optimizer.step()\n",
- "\n",
- " epoch_acc = self.eval(test_loader)\n",
- " \n",
- " if epoch_acc > best_accuracy:\n",
- " #print(\"Saving new model with accuracy {}\".format(epoch_acc))\n",
- " best_accuracy = epoch_acc\n",
- " torch.save(self.log_regression.state_dict(), 'log_regression.pth')\n",
- "\n",
- " print(\"--------------\")\n",
- " print(\"Done training\")\n",
- " print(\"Best accuracy:\", best_accuracy)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "NE716m7SOkaK",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- "log_regressor_evaluator = LogiticRegressionEvaluator(n_features=X_train_feature.shape[1], n_classes=10)\n",
- "\n",
- "log_regressor_evaluator.train(X_train_feature, y_train, X_test_feature, y_test)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "_GC0a14uWRr6",
- "colab_type": "code",
- "colab": {}
- },
- "source": [
- ""
- ],
- "execution_count": 0,
- "outputs": []
- }
- ]
-}
\ No newline at end of file
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "YUemQib7ZE4D"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import sys\n",
+ "import numpy as np\n",
+ "import os\n",
+ "from sklearn.neighbors import KNeighborsClassifier\n",
+ "import yaml\n",
+ "import matplotlib.pyplot as plt\n",
+ "from sklearn.decomposition import PCA\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "from sklearn import preprocessing\n",
+ "import importlib.util"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "WSgRE1CcLqdS"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install gdown"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "NOIJEui1ZziV"
+ },
+ "outputs": [],
+ "source": [
+ "def get_file_id_by_model(folder_name):\n",
+ " file_id = {'resnet-18_40-epochs': '1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk',\n",
+ " 'resnet-18_80-epochs': '1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT',\n",
+ " 'resnet-50_40-epochs': '1TZqBNTFCsO-mxAiR-zJeyupY-J2gA27Q',\n",
+ " 'resnet-50_80-epochs': '1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35',\n",
+ " 'resnet-18_100-epochs':'1aZ12TITXnajZ6QWmS_SDm8Sp8gXNbeCQ'}\n",
+ " return file_id.get(folder_name, \"Model not found.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "G7YMxsvEZMrX"
+ },
+ "outputs": [],
+ "source": [
+ "folder_name = 'resnet-50_40-epochs'\n",
+ "file_id = get_file_id_by_model(folder_name)\n",
+ "print(folder_name, file_id)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "PWZ8fet_YoJm"
+ },
+ "outputs": [],
+ "source": [
+ "# download and extract model files\n",
+ "os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n",
+ "os.system('unzip {}'.format(folder_name))\n",
+ "!ls"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "3_nypQVEv-hn"
+ },
+ "outputs": [],
+ "source": [
+ "from torch.utils.data import DataLoader\n",
+ "import torchvision.transforms as transforms\n",
+ "from torchvision import datasets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "lDfbL3w_Z0Od"
+ },
+ "outputs": [],
+ "source": [
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "print(\"Using device:\", device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "IQMIryc6LjQd"
+ },
+ "outputs": [],
+ "source": [
+ "checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
+ "config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n",
+ "config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "BfIPl0G6_RrT"
+ },
+ "outputs": [],
+ "source": [
+ "def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n",
+ " train_dataset = datasets.STL10('./data', split='train', download=download,\n",
+ " transform=transforms.ToTensor())\n",
+ "\n",
+ " train_loader = DataLoader(train_dataset, batch_size=batch_size,\n",
+ " num_workers=0, drop_last=False, shuffle=shuffle)\n",
+ " \n",
+ " test_dataset = datasets.STL10('./data', split='test', download=download,\n",
+ " transform=transforms.ToTensor())\n",
+ "\n",
+ " test_loader = DataLoader(test_dataset, batch_size=batch_size,\n",
+ " num_workers=0, drop_last=False, shuffle=shuffle)\n",
+ " return train_loader, test_loader"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "a18lPD-tIle6"
+ },
+ "outputs": [],
+ "source": [
+ "def _load_resnet_model(checkpoints_folder):\n",
+ " # Load the neural net module\n",
+ " spec = importlib.util.spec_from_file_location(\"model\", os.path.join(checkpoints_folder, 'resnet_simclr.py'))\n",
+ " resnet_module = importlib.util.module_from_spec(spec)\n",
+ " spec.loader.exec_module(resnet_module)\n",
+ "\n",
+ " model = resnet_module.ResNetSimCLR(**config['model'])\n",
+ " model.eval()\n",
+ "\n",
+ " state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=torch.device('cpu'))\n",
+ " model.load_state_dict(state_dict)\n",
+ " model = model.to(device)\n",
+ " return model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "5nf4rDtWLjRE"
+ },
+ "source": [
+ "## Protocol #2 Logisitc Regression"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "7jjSxmDnHNQz"
+ },
+ "outputs": [],
+ "source": [
+ "class ResNetFeatureExtractor(object):\n",
+ " def __init__(self, checkpoints_folder):\n",
+ " self.checkpoints_folder = checkpoints_folder\n",
+ " self.model = _load_resnet_model(checkpoints_folder)\n",
+ "\n",
+ " def _inference(self, loader):\n",
+ " feature_vector = []\n",
+ " labels_vector = []\n",
+ " for batch_x, batch_y in loader:\n",
+ "\n",
+ " batch_x = batch_x.to(device)\n",
+ " labels_vector.extend(batch_y)\n",
+ "\n",
+ " features, _ = self.model(batch_x)\n",
+ " feature_vector.extend(features.cpu().detach().numpy())\n",
+ "\n",
+ " feature_vector = np.array(feature_vector)\n",
+ " labels_vector = np.array(labels_vector)\n",
+ "\n",
+ " print(\"Features shape {}\".format(feature_vector.shape))\n",
+ " return feature_vector, labels_vector\n",
+ "\n",
+ " def get_resnet_features(self):\n",
+ " train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
+ " X_train_feature, y_train = self._inference(train_loader)\n",
+ " X_test_feature, y_test = self._inference(test_loader)\n",
+ "\n",
+ " return X_train_feature, y_train, X_test_feature, y_test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "kghx1govJq5_"
+ },
+ "outputs": [],
+ "source": [
+ "resnet_feature_extractor = ResNetFeatureExtractor(checkpoints_folder)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "S_JcznxVJ1Xj"
+ },
+ "outputs": [],
+ "source": [
+ "X_train_feature, y_train, X_test_feature, y_test = resnet_feature_extractor.get_resnet_features()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "oftbHXcdLjRM"
+ },
+ "outputs": [],
+ "source": [
+ "import torch.nn as nn\n",
+ "\n",
+ "class LogisticRegression(nn.Module):\n",
+ " \n",
+ " def __init__(self, n_features, n_classes):\n",
+ " super(LogisticRegression, self).__init__()\n",
+ " self.model = nn.Linear(n_features, n_classes)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return self.model(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Ks73ePLtNWeV"
+ },
+ "outputs": [],
+ "source": [
+ "class LogiticRegressionEvaluator(object):\n",
+ " def __init__(self, n_features, n_classes):\n",
+ " self.log_regression = LogisticRegression(n_features, n_classes).to(device)\n",
+ " self.scaler = preprocessing.StandardScaler()\n",
+ "\n",
+ " def _normalize_dataset(self, X_train, X_test):\n",
+ " print(\"Standard Scaling Normalizer\")\n",
+ " self.scaler.fit(X_train)\n",
+ " X_train = self.scaler.transform(X_train)\n",
+ " X_test = self.scaler.transform(X_test)\n",
+ " return X_train, X_test\n",
+ "\n",
+ " @staticmethod\n",
+ " def _sample_weight_decay():\n",
+ " # We selected the l2 regularization parameter from a range of 45 logarithmically spaced values between 10−6 and 105\n",
+ " weight_decay = np.logspace(-6, 5, num=45, base=10.0)\n",
+ " weight_decay = np.random.choice(weight_decay)\n",
+ " print(\"Sampled weight decay:\", weight_decay)\n",
+ " return weight_decay\n",
+ "\n",
+ " def eval(self, test_loader):\n",
+ " correct = 0\n",
+ " total = 0\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " self.log_regression.eval()\n",
+ " for batch_x, batch_y in test_loader:\n",
+ " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
+ " logits = self.log_regression(batch_x)\n",
+ "\n",
+ " predicted = torch.argmax(logits, dim=1)\n",
+ " total += batch_y.size(0)\n",
+ " correct += (predicted == batch_y).sum().item()\n",
+ "\n",
+ " final_acc = 100 * correct / total\n",
+ " self.log_regression.train()\n",
+ " return final_acc\n",
+ "\n",
+ "\n",
+ " def create_data_loaders_from_arrays(self, X_train, y_train, X_test, y_test):\n",
+ " X_train, X_test = self._normalize_dataset(X_train, X_test)\n",
+ "\n",
+ " train = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train).type(torch.long))\n",
+ " train_loader = torch.utils.data.DataLoader(train, batch_size=396, shuffle=False)\n",
+ "\n",
+ " test = torch.utils.data.TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test).type(torch.long))\n",
+ " test_loader = torch.utils.data.DataLoader(test, batch_size=512, shuffle=False)\n",
+ " return train_loader, test_loader\n",
+ "\n",
+ " def train(self, X_train, y_train, X_test, y_test):\n",
+ " \n",
+ " train_loader, test_loader = self.create_data_loaders_from_arrays(X_train, y_train, X_test, y_test)\n",
+ "\n",
+ " weight_decay = self._sample_weight_decay()\n",
+ "\n",
+ " optimizer = torch.optim.Adam(self.log_regression.parameters(), 3e-4, weight_decay=weight_decay)\n",
+ " criterion = torch.nn.CrossEntropyLoss()\n",
+ "\n",
+ " best_accuracy = 0\n",
+ "\n",
+ " for e in range(200):\n",
+ " \n",
+ " for batch_x, batch_y in train_loader:\n",
+ "\n",
+ " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
+ "\n",
+ " optimizer.zero_grad()\n",
+ "\n",
+ " logits = self.log_regression(batch_x)\n",
+ "\n",
+ " loss = criterion(logits, batch_y)\n",
+ "\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ " epoch_acc = self.eval(test_loader)\n",
+ " \n",
+ " if epoch_acc > best_accuracy:\n",
+ " #print(\"Saving new model with accuracy {}\".format(epoch_acc))\n",
+ " best_accuracy = epoch_acc\n",
+ " torch.save(self.log_regression.state_dict(), 'log_regression.pth')\n",
+ "\n",
+ " print(\"--------------\")\n",
+ " print(\"Done training\")\n",
+ " print(\"Best accuracy:\", best_accuracy)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "NE716m7SOkaK"
+ },
+ "outputs": [],
+ "source": [
+ "log_regressor_evaluator = LogiticRegressionEvaluator(n_features=X_train_feature.shape[1], n_classes=10)\n",
+ "\n",
+ "log_regressor_evaluator.train(X_train_feature, y_train, X_test_feature, y_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "_GC0a14uWRr6"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "name": "mini-batch-logistic-regression-evaluator.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python [conda env:image]",
+ "language": "python",
+ "name": "conda-env-image-py"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/simclr.py b/simclr.py
index a64c8cb..a41bc88 100644
--- a/simclr.py
+++ b/simclr.py
@@ -34,6 +34,23 @@ def _save_config_file(model_checkpoints_folder, args):
yaml.dump(args, outfile, default_flow_style=False)
+def accuracy(output, target, topk=(1,)):
+ """Computes the accuracy over the k top predictions for the specified values of k"""
+ with torch.no_grad():
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
class SimCLR(object):
def __init__(self, *args, **kwargs):
@@ -97,10 +114,10 @@ class SimCLR(object):
self.optimizer.step()
if n_iter % self.args.log_every_n_steps == 0:
- predictions = torch.argmax(logits, dim=1)
- acc = 100 * (predictions == labels).float().mean()
+ top1, top5 = accuracy(logits, labels, topk=(1,5))
self.writer.add_scalar('loss', loss, global_step=n_iter)
- self.writer.add_scalar('acc/top1', acc, global_step=n_iter)
+ self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
+ self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
self.writer.add_scalar('learning_rate', self.scheduler.get_lr()[0], global_step=n_iter)
n_iter += 1
@@ -108,7 +125,7 @@ class SimCLR(object):
# warmup for the first 10 epochs
if epoch_counter >= 10:
self.scheduler.step()
- logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {acc}")
+ logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")
logging.info("Training has finished.")
# save model checkpoints