From d0112ed55bbcf7a56c7d57e04dcd982092b66f01 Mon Sep 17 00:00:00 2001 From: Thalles Silva Date: Sun, 17 Jan 2021 20:07:59 -0300 Subject: [PATCH] Major refactor, small fixes --- data_aug/contrastive_learning_dataset.py | 9 +- exceptions/exceptions.py | 6 +- ..._batch_logistic_regression_evaluator.ipynb | 882 +++++++++--------- simclr.py | 25 +- 4 files changed, 473 insertions(+), 449 deletions(-) diff --git a/data_aug/contrastive_learning_dataset.py b/data_aug/contrastive_learning_dataset.py index 7875fcb..e1777e4 100644 --- a/data_aug/contrastive_learning_dataset.py +++ b/data_aug/contrastive_learning_dataset.py @@ -2,6 +2,7 @@ from torchvision.transforms import transforms from data_aug.gaussian_blur import GaussianBlur from torchvision import transforms, datasets from data_aug.view_generator import ContrastiveLearningViewGenerator +from exceptions.exceptions import InvalidDatasetSelection class ContrastiveLearningDataset: @@ -33,5 +34,9 @@ class ContrastiveLearningDataset: n_views), download=True)} - dataset = valid_datasets.get(name, 'Invalid dataset option.')() - return dataset + try: + dataset_fn = valid_datasets[name] + except KeyError: + raise InvalidDatasetSelection() + else: + return dataset_fn() diff --git a/exceptions/exceptions.py b/exceptions/exceptions.py index ae45945..a737084 100644 --- a/exceptions/exceptions.py +++ b/exceptions/exceptions.py @@ -3,4 +3,8 @@ class BaseSimCLRException(Exception): class InvalidBackboneError(BaseSimCLRException): - """Raised when the choice of backbone Convnet is invalid.""" \ No newline at end of file + """Raised when the choice of backbone Convnet is invalid.""" + + +class InvalidDatasetSelection(BaseSimCLRException): + """Raised when the choice of dataset is invalid.""" diff --git a/feature_eval/mini_batch_logistic_regression_evaluator.ipynb b/feature_eval/mini_batch_logistic_regression_evaluator.ipynb index 481f6fa..b3c1f1c 100644 --- a/feature_eval/mini_batch_logistic_regression_evaluator.ipynb +++ b/feature_eval/mini_batch_logistic_regression_evaluator.ipynb @@ -1,444 +1,442 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "kernelspec": { - "display_name": "pytorch", - "language": "python", - "name": "pytorch" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.6" - }, - "colab": { - "name": "mini-batch-logistic-regression-evaluator.ipynb", - "provenance": [], - "include_colab_link": true - }, - "accelerator": "GPU" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "YUemQib7ZE4D", - "colab_type": "code", - "colab": {} - }, - "source": [ - "import torch\n", - "import sys\n", - "import numpy as np\n", - "import os\n", - "from sklearn.neighbors import KNeighborsClassifier\n", - "import yaml\n", - "import matplotlib.pyplot as plt\n", - "from sklearn.decomposition import PCA\n", - "from sklearn.linear_model import LogisticRegression\n", - "from sklearn import preprocessing\n", - "import importlib.util" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "WSgRE1CcLqdS", - "colab_type": "code", - 
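The contrastive_learning_dataset.py hunk above trades a silent failure for an explicit one: an unknown dataset name now raises InvalidDatasetSelection at lookup time, whereas the old dict.get() fallback returned the string 'Invalid dataset option.' and only crashed later with an opaque "'str' object is not callable" TypeError when that string was invoked. A minimal caller-side sketch of the new behavior; the get_dataset name and signature, the root_folder argument, and the 'cifar10'/'stl10' names are assumptions about parts of the class this hunk does not show:

    from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset
    from exceptions.exceptions import InvalidDatasetSelection

    dataset = ContrastiveLearningDataset(root_folder='./datasets')  # assumed ctor arg
    try:
        # get_dataset (assumed name) invokes the matching entry in valid_datasets
        train_dataset = dataset.get_dataset('cifar10', n_views=2)
    except InvalidDatasetSelection:
        # unsupported names now fail here, at selection time, not mid-training
        raise SystemExit("unsupported dataset; expected 'cifar10' or 'stl10'")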
"colab": {} - }, - "source": [ - "!pip install gdown" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "NOIJEui1ZziV", - "colab_type": "code", - "colab": {} - }, - "source": [ - "def get_file_id_by_model(folder_name):\n", - " file_id = {'resnet-18_40-epochs': '1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk',\n", - " 'resnet-18_80-epochs': '1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT',\n", - " 'resnet-50_40-epochs': '1TZqBNTFCsO-mxAiR-zJeyupY-J2gA27Q',\n", - " 'resnet-50_80-epochs': '1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35',\n", - " 'resnet-18_100-epochs':'1aZ12TITXnajZ6QWmS_SDm8Sp8gXNbeCQ'}\n", - " return file_id.get(folder_name, \"Model not found.\")" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "G7YMxsvEZMrX", - "colab_type": "code", - "colab": {} - }, - "source": [ - "folder_name = 'resnet-50_40-epochs'\n", - "file_id = get_file_id_by_model(folder_name)\n", - "print(folder_name, file_id)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "PWZ8fet_YoJm", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# download and extract model files\n", - "os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n", - "os.system('unzip {}'.format(folder_name))\n", - "!ls" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "3_nypQVEv-hn", - "colab_type": "code", - "colab": {} - }, - "source": [ - "from torch.utils.data import DataLoader\n", - "import torchvision.transforms as transforms\n", - "from torchvision import datasets" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "lDfbL3w_Z0Od", - "colab_type": "code", - "colab": {} - }, - "source": [ - "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", - "print(\"Using device:\", device)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "IQMIryc6LjQd", - "colab_type": "code", - "colab": {} - }, - "source": [ - "checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n", - "config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n", - "config" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "BfIPl0G6_RrT", - "colab_type": "code", - "colab": {} - }, - "source": [ - "def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n", - " train_dataset = datasets.STL10('./data', split='train', download=download,\n", - " transform=transforms.ToTensor())\n", - "\n", - " train_loader = DataLoader(train_dataset, batch_size=batch_size,\n", - " num_workers=0, drop_last=False, shuffle=shuffle)\n", - " \n", - " test_dataset = datasets.STL10('./data', split='test', download=download,\n", - " transform=transforms.ToTensor())\n", - "\n", - " test_loader = DataLoader(test_dataset, batch_size=batch_size,\n", - " num_workers=0, drop_last=False, shuffle=shuffle)\n", - " return train_loader, test_loader" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "a18lPD-tIle6", - "colab_type": "code", - "colab": {} - }, - "source": [ - "def _load_resnet_model(checkpoints_folder):\n", - " # Load the neural net module\n", - " spec = importlib.util.spec_from_file_location(\"model\", os.path.join(checkpoints_folder, 'resnet_simclr.py'))\n", - " resnet_module = importlib.util.module_from_spec(spec)\n", - " 
spec.loader.exec_module(resnet_module)\n", - "\n", - " model = resnet_module.ResNetSimCLR(**config['model'])\n", - " model.eval()\n", - "\n", - " state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=torch.device('cpu'))\n", - " model.load_state_dict(state_dict)\n", - " model = model.to(device)\n", - " return model" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5nf4rDtWLjRE", - "colab_type": "text" - }, - "source": [ - "## Protocol #2 Logisitc Regression" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "7jjSxmDnHNQz", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class ResNetFeatureExtractor(object):\n", - " def __init__(self, checkpoints_folder):\n", - " self.checkpoints_folder = checkpoints_folder\n", - " self.model = _load_resnet_model(checkpoints_folder)\n", - "\n", - " def _inference(self, loader):\n", - " feature_vector = []\n", - " labels_vector = []\n", - " for batch_x, batch_y in loader:\n", - "\n", - " batch_x = batch_x.to(device)\n", - " labels_vector.extend(batch_y)\n", - "\n", - " features, _ = self.model(batch_x)\n", - " feature_vector.extend(features.cpu().detach().numpy())\n", - "\n", - " feature_vector = np.array(feature_vector)\n", - " labels_vector = np.array(labels_vector)\n", - "\n", - " print(\"Features shape {}\".format(feature_vector.shape))\n", - " return feature_vector, labels_vector\n", - "\n", - " def get_resnet_features(self):\n", - " train_loader, test_loader = get_stl10_data_loaders(download=True)\n", - " X_train_feature, y_train = self._inference(train_loader)\n", - " X_test_feature, y_test = self._inference(test_loader)\n", - "\n", - " return X_train_feature, y_train, X_test_feature, y_test" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "kghx1govJq5_", - "colab_type": "code", - "colab": {} - }, - "source": [ - "resnet_feature_extractor = ResNetFeatureExtractor(checkpoints_folder)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "S_JcznxVJ1Xj", - "colab_type": "code", - "colab": {} - }, - "source": [ - "X_train_feature, y_train, X_test_feature, y_test = resnet_feature_extractor.get_resnet_features()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "oftbHXcdLjRM", - "colab_type": "code", - "colab": {} - }, - "source": [ - "import torch.nn as nn\n", - "\n", - "class LogisticRegression(nn.Module):\n", - " \n", - " def __init__(self, n_features, n_classes):\n", - " super(LogisticRegression, self).__init__()\n", - " self.model = nn.Linear(n_features, n_classes)\n", - "\n", - " def forward(self, x):\n", - " return self.model(x)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "Ks73ePLtNWeV", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class LogiticRegressionEvaluator(object):\n", - " def __init__(self, n_features, n_classes):\n", - " self.log_regression = LogisticRegression(n_features, n_classes).to(device)\n", - " self.scaler = preprocessing.StandardScaler()\n", - "\n", - " def _normalize_dataset(self, X_train, X_test):\n", - " print(\"Standard Scaling Normalizer\")\n", - " self.scaler.fit(X_train)\n", - " X_train = self.scaler.transform(X_train)\n", - " X_test = self.scaler.transform(X_test)\n", - " return X_train, X_test\n", - "\n", - " @staticmethod\n", - " def _sample_weight_decay():\n", - " # We selected the l2 
regularization parameter from a range of 45 logarithmically spaced values between 10−6 and 105\n", - " weight_decay = np.logspace(-6, 5, num=45, base=10.0)\n", - " weight_decay = np.random.choice(weight_decay)\n", - " print(\"Sampled weight decay:\", weight_decay)\n", - " return weight_decay\n", - "\n", - " def eval(self, test_loader):\n", - " correct = 0\n", - " total = 0\n", - "\n", - " with torch.no_grad():\n", - " self.log_regression.eval()\n", - " for batch_x, batch_y in test_loader:\n", - " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n", - " logits = self.log_regression(batch_x)\n", - "\n", - " predicted = torch.argmax(logits, dim=1)\n", - " total += batch_y.size(0)\n", - " correct += (predicted == batch_y).sum().item()\n", - "\n", - " final_acc = 100 * correct / total\n", - " self.log_regression.train()\n", - " return final_acc\n", - "\n", - "\n", - " def create_data_loaders_from_arrays(self, X_train, y_train, X_test, y_test):\n", - " X_train, X_test = self._normalize_dataset(X_train, X_test)\n", - "\n", - " train = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train).type(torch.long))\n", - " train_loader = torch.utils.data.DataLoader(train, batch_size=396, shuffle=False)\n", - "\n", - " test = torch.utils.data.TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test).type(torch.long))\n", - " test_loader = torch.utils.data.DataLoader(test, batch_size=512, shuffle=False)\n", - " return train_loader, test_loader\n", - "\n", - " def train(self, X_train, y_train, X_test, y_test):\n", - " \n", - " train_loader, test_loader = self.create_data_loaders_from_arrays(X_train, y_train, X_test, y_test)\n", - "\n", - " weight_decay = self._sample_weight_decay()\n", - "\n", - " optimizer = torch.optim.Adam(self.log_regression.parameters(), 3e-4, weight_decay=weight_decay)\n", - " criterion = torch.nn.CrossEntropyLoss()\n", - "\n", - " best_accuracy = 0\n", - "\n", - " for e in range(200):\n", - " \n", - " for batch_x, batch_y in train_loader:\n", - "\n", - " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n", - "\n", - " optimizer.zero_grad()\n", - "\n", - " logits = self.log_regression(batch_x)\n", - "\n", - " loss = criterion(logits, batch_y)\n", - "\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " epoch_acc = self.eval(test_loader)\n", - " \n", - " if epoch_acc > best_accuracy:\n", - " #print(\"Saving new model with accuracy {}\".format(epoch_acc))\n", - " best_accuracy = epoch_acc\n", - " torch.save(self.log_regression.state_dict(), 'log_regression.pth')\n", - "\n", - " print(\"--------------\")\n", - " print(\"Done training\")\n", - " print(\"Best accuracy:\", best_accuracy)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "NE716m7SOkaK", - "colab_type": "code", - "colab": {} - }, - "source": [ - "log_regressor_evaluator = LogiticRegressionEvaluator(n_features=X_train_feature.shape[1], n_classes=10)\n", - "\n", - "log_regressor_evaluator.train(X_train_feature, y_train, X_test_feature, y_test)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "_GC0a14uWRr6", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": 0, - "outputs": [] - } - ] -} \ No newline at end of file + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "YUemQib7ZE4D" + }, + "outputs": [], + "source": [ + "import torch\n", + "import sys\n", + "import 
numpy as np\n", + "import os\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "import yaml\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.decomposition import PCA\n", + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn import preprocessing\n", + "import importlib.util" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WSgRE1CcLqdS" + }, + "outputs": [], + "source": [ + "!pip install gdown" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "NOIJEui1ZziV" + }, + "outputs": [], + "source": [ + "def get_file_id_by_model(folder_name):\n", + " file_id = {'resnet-18_40-epochs': '1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk',\n", + " 'resnet-18_80-epochs': '1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT',\n", + " 'resnet-50_40-epochs': '1TZqBNTFCsO-mxAiR-zJeyupY-J2gA27Q',\n", + " 'resnet-50_80-epochs': '1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35',\n", + " 'resnet-18_100-epochs':'1aZ12TITXnajZ6QWmS_SDm8Sp8gXNbeCQ'}\n", + " return file_id.get(folder_name, \"Model not found.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "G7YMxsvEZMrX" + }, + "outputs": [], + "source": [ + "folder_name = 'resnet-50_40-epochs'\n", + "file_id = get_file_id_by_model(folder_name)\n", + "print(folder_name, file_id)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "PWZ8fet_YoJm" + }, + "outputs": [], + "source": [ + "# download and extract model files\n", + "os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n", + "os.system('unzip {}'.format(folder_name))\n", + "!ls" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3_nypQVEv-hn" + }, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "import torchvision.transforms as transforms\n", + "from torchvision import datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "lDfbL3w_Z0Od" + }, + "outputs": [], + "source": [ + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "print(\"Using device:\", device)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "IQMIryc6LjQd" + }, + "outputs": [], + "source": [ + "checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n", + "config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n", + "config" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "BfIPl0G6_RrT" + }, + "outputs": [], + "source": [ + "def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n", + " train_dataset = datasets.STL10('./data', split='train', download=download,\n", + " transform=transforms.ToTensor())\n", + "\n", + " train_loader = DataLoader(train_dataset, batch_size=batch_size,\n", + " num_workers=0, drop_last=False, shuffle=shuffle)\n", + " \n", + " test_dataset = datasets.STL10('./data', split='test', download=download,\n", + " transform=transforms.ToTensor())\n", + "\n", + " test_loader = DataLoader(test_dataset, batch_size=batch_size,\n", + " num_workers=0, drop_last=False, shuffle=shuffle)\n", + " return train_loader, test_loader" + ] + }, + { + 
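One caveat in the config cell above: yaml.load(open(...)) relies on PyYAML's implicit default Loader, which PyYAML 5.1+ deprecates with a warning, and it never closes the file handle. An equivalent, safer version of that cell, assuming config.yaml holds only plain mappings and scalars (its contents are consumed as config['model'] below):

    import os
    import yaml

    # safe_load forbids arbitrary object construction and silences the
    # PyYAML >= 5.1 "no Loader specified" deprecation warning;
    # checkpoints_folder is the variable defined earlier in the notebook.
    with open(os.path.join(checkpoints_folder, "config.yaml")) as f:
        config = yaml.safe_load(f)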
"cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "a18lPD-tIle6" + }, + "outputs": [], + "source": [ + "def _load_resnet_model(checkpoints_folder):\n", + " # Load the neural net module\n", + " spec = importlib.util.spec_from_file_location(\"model\", os.path.join(checkpoints_folder, 'resnet_simclr.py'))\n", + " resnet_module = importlib.util.module_from_spec(spec)\n", + " spec.loader.exec_module(resnet_module)\n", + "\n", + " model = resnet_module.ResNetSimCLR(**config['model'])\n", + " model.eval()\n", + "\n", + " state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=torch.device('cpu'))\n", + " model.load_state_dict(state_dict)\n", + " model = model.to(device)\n", + " return model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "5nf4rDtWLjRE" + }, + "source": [ + "## Protocol #2 Logisitc Regression" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "7jjSxmDnHNQz" + }, + "outputs": [], + "source": [ + "class ResNetFeatureExtractor(object):\n", + " def __init__(self, checkpoints_folder):\n", + " self.checkpoints_folder = checkpoints_folder\n", + " self.model = _load_resnet_model(checkpoints_folder)\n", + "\n", + " def _inference(self, loader):\n", + " feature_vector = []\n", + " labels_vector = []\n", + " for batch_x, batch_y in loader:\n", + "\n", + " batch_x = batch_x.to(device)\n", + " labels_vector.extend(batch_y)\n", + "\n", + " features, _ = self.model(batch_x)\n", + " feature_vector.extend(features.cpu().detach().numpy())\n", + "\n", + " feature_vector = np.array(feature_vector)\n", + " labels_vector = np.array(labels_vector)\n", + "\n", + " print(\"Features shape {}\".format(feature_vector.shape))\n", + " return feature_vector, labels_vector\n", + "\n", + " def get_resnet_features(self):\n", + " train_loader, test_loader = get_stl10_data_loaders(download=True)\n", + " X_train_feature, y_train = self._inference(train_loader)\n", + " X_test_feature, y_test = self._inference(test_loader)\n", + "\n", + " return X_train_feature, y_train, X_test_feature, y_test" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "kghx1govJq5_" + }, + "outputs": [], + "source": [ + "resnet_feature_extractor = ResNetFeatureExtractor(checkpoints_folder)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "S_JcznxVJ1Xj" + }, + "outputs": [], + "source": [ + "X_train_feature, y_train, X_test_feature, y_test = resnet_feature_extractor.get_resnet_features()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "oftbHXcdLjRM" + }, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "\n", + "class LogisticRegression(nn.Module):\n", + " \n", + " def __init__(self, n_features, n_classes):\n", + " super(LogisticRegression, self).__init__()\n", + " self.model = nn.Linear(n_features, n_classes)\n", + "\n", + " def forward(self, x):\n", + " return self.model(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ks73ePLtNWeV" + }, + "outputs": [], + "source": [ + "class LogiticRegressionEvaluator(object):\n", + " def __init__(self, n_features, n_classes):\n", + " self.log_regression = LogisticRegression(n_features, 
n_classes).to(device)\n", + " self.scaler = preprocessing.StandardScaler()\n", + "\n", + " def _normalize_dataset(self, X_train, X_test):\n", + " print(\"Standard Scaling Normalizer\")\n", + " self.scaler.fit(X_train)\n", + " X_train = self.scaler.transform(X_train)\n", + " X_test = self.scaler.transform(X_test)\n", + " return X_train, X_test\n", + "\n", + " @staticmethod\n", + " def _sample_weight_decay():\n", + " # We selected the l2 regularization parameter from a range of 45 logarithmically spaced values between 10−6 and 105\n", + " weight_decay = np.logspace(-6, 5, num=45, base=10.0)\n", + " weight_decay = np.random.choice(weight_decay)\n", + " print(\"Sampled weight decay:\", weight_decay)\n", + " return weight_decay\n", + "\n", + " def eval(self, test_loader):\n", + " correct = 0\n", + " total = 0\n", + "\n", + " with torch.no_grad():\n", + " self.log_regression.eval()\n", + " for batch_x, batch_y in test_loader:\n", + " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n", + " logits = self.log_regression(batch_x)\n", + "\n", + " predicted = torch.argmax(logits, dim=1)\n", + " total += batch_y.size(0)\n", + " correct += (predicted == batch_y).sum().item()\n", + "\n", + " final_acc = 100 * correct / total\n", + " self.log_regression.train()\n", + " return final_acc\n", + "\n", + "\n", + " def create_data_loaders_from_arrays(self, X_train, y_train, X_test, y_test):\n", + " X_train, X_test = self._normalize_dataset(X_train, X_test)\n", + "\n", + " train = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train).type(torch.long))\n", + " train_loader = torch.utils.data.DataLoader(train, batch_size=396, shuffle=False)\n", + "\n", + " test = torch.utils.data.TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test).type(torch.long))\n", + " test_loader = torch.utils.data.DataLoader(test, batch_size=512, shuffle=False)\n", + " return train_loader, test_loader\n", + "\n", + " def train(self, X_train, y_train, X_test, y_test):\n", + " \n", + " train_loader, test_loader = self.create_data_loaders_from_arrays(X_train, y_train, X_test, y_test)\n", + "\n", + " weight_decay = self._sample_weight_decay()\n", + "\n", + " optimizer = torch.optim.Adam(self.log_regression.parameters(), 3e-4, weight_decay=weight_decay)\n", + " criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + " best_accuracy = 0\n", + "\n", + " for e in range(200):\n", + " \n", + " for batch_x, batch_y in train_loader:\n", + "\n", + " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + "\n", + " logits = self.log_regression(batch_x)\n", + "\n", + " loss = criterion(logits, batch_y)\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " epoch_acc = self.eval(test_loader)\n", + " \n", + " if epoch_acc > best_accuracy:\n", + " #print(\"Saving new model with accuracy {}\".format(epoch_acc))\n", + " best_accuracy = epoch_acc\n", + " torch.save(self.log_regression.state_dict(), 'log_regression.pth')\n", + "\n", + " print(\"--------------\")\n", + " print(\"Done training\")\n", + " print(\"Best accuracy:\", best_accuracy)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "NE716m7SOkaK" + }, + "outputs": [], + "source": [ + "log_regressor_evaluator = LogiticRegressionEvaluator(n_features=X_train_feature.shape[1], n_classes=10)\n", + "\n", + "log_regressor_evaluator.train(X_train_feature, y_train, X_test_feature, y_test)" + ] + }, + { + "cell_type": "code", + 
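The weight-decay comment in the evaluator quotes the SimCLR paper's protocol: the l2 parameter is drawn from 45 logarithmically spaced values between 1e-6 and 1e5, which is exactly the np.logspace(-6, 5, num=45) grid in _sample_weight_decay. Because each train() call draws a single value, a small random-search driver is the natural way to use the cell above; a sketch, assuming train() is modified to return its best_accuracy instead of only printing it:

    # Hypothetical driver; the trial budget of 5 is arbitrary.
    best_acc = 0.0
    for _ in range(5):
        # fresh evaluator per trial: new linear head, new sampled weight decay
        evaluator = LogiticRegressionEvaluator(
            n_features=X_train_feature.shape[1], n_classes=10)
        acc = evaluator.train(X_train_feature, y_train,
                              X_test_feature, y_test)  # assumed return value
        best_acc = max(best_acc, acc)
    print("Best linear-eval accuracy across trials:", best_acc)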
"execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "_GC0a14uWRr6" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "include_colab_link": true, + "name": "mini-batch-logistic-regression-evaluator.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python [conda env:image]", + "language": "python", + "name": "conda-env-image-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.12" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/simclr.py b/simclr.py index a64c8cb..a41bc88 100644 --- a/simclr.py +++ b/simclr.py @@ -34,6 +34,23 @@ def _save_config_file(model_checkpoints_folder, args): yaml.dump(args, outfile, default_flow_style=False) +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + class SimCLR(object): def __init__(self, *args, **kwargs): @@ -97,10 +114,10 @@ class SimCLR(object): self.optimizer.step() if n_iter % self.args.log_every_n_steps == 0: - predictions = torch.argmax(logits, dim=1) - acc = 100 * (predictions == labels).float().mean() + top1, top5 = accuracy(logits, labels, topk=(1,5)) self.writer.add_scalar('loss', loss, global_step=n_iter) - self.writer.add_scalar('acc/top1', acc, global_step=n_iter) + self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter) + self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter) self.writer.add_scalar('learning_rate', self.scheduler.get_lr()[0], global_step=n_iter) n_iter += 1 @@ -108,7 +125,7 @@ class SimCLR(object): # warmup for the first 10 epochs if epoch_counter >= 10: self.scheduler.step() - logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {acc}") + logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}") logging.info("Training has finished.") # save model checkpoints