SimCLR/feature_eval/mini_batch_logistic_regression_evaluator.ipynb

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
},
"colab": {
"name": "mini-batch-logistic-regression-evaluator.ipynb",
"provenance": [],
"include_colab_link": true
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/master/feature_eval/mini_batch_logistic_regression_evaluator.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "YUemQib7ZE4D",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"import sys\n",
"import numpy as np\n",
"import os\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"import yaml\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn import preprocessing\n",
"import importlib.util"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WSgRE1CcLqdS",
"colab_type": "code",
"outputId": "3bd80a41-005c-416d-9476-a0dc14921ab0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 163
}
},
"source": [
"!pip install gdown"
],
"execution_count": 25,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: gdown in /usr/local/lib/python3.6/dist-packages (3.6.4)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gdown) (4.28.1)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gdown) (2.21.0)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from gdown) (1.12.0)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2019.11.28)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2.8)\n",
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (1.24.3)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "NOIJEui1ZziV",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_file_id_by_model(folder_name):\n",
" file_id = {'resnet-18_40-epochs': '1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk',\n",
" 'resnet-18_80-epochs': '1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT',\n",
" 'resnet-50_40-epochs': '1TZqBNTFCsO-mxAiR-zJeyupY-J2gA27Q',\n",
" 'resnet-50_80-epochs': '1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35' }\n",
" return file_id.get(folder_name, \"Model not found.\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "G7YMxsvEZMrX",
"colab_type": "code",
"outputId": "430bc8d7-6e3c-44c5-eb8f-7f3ca24c4172",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"folder_name = 'resnet-50_40-epochs'\n",
"file_id = get_file_id_by_model(folder_name)\n",
"print(folder_name, file_id)"
],
"execution_count": 27,
"outputs": [
{
"output_type": "stream",
"text": [
"resnet-50_40-epochs 1TZqBNTFCsO-mxAiR-zJeyupY-J2gA27Q\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "PWZ8fet_YoJm",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 72
},
"outputId": "2871e598-a429-4cfa-cd96-40850003e638"
},
"source": [
"# download and extract model files\n",
"os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n",
"os.system('unzip {}'.format(folder_name))\n",
"!ls"
],
"execution_count": 28,
"outputs": [
{
"output_type": "stream",
"text": [
"data\t\t resnet-50_40-epochs.zip sample_data\n",
"log_regression.pth resnet-50_80-epochs\n",
"resnet-50_40-epochs resnet-50_80-epochs.zip\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "3_nypQVEv-hn",
"colab_type": "code",
"colab": {}
},
"source": [
"from torch.utils.data import DataLoader\n",
"import torchvision.transforms as transforms\n",
"from torchvision import datasets"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lDfbL3w_Z0Od",
"colab_type": "code",
"outputId": "d148eb48-8e56-4af5-8c2b-c7821e2c7149",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"print(\"Using device:\", device)"
],
"execution_count": 30,
"outputs": [
{
"output_type": "stream",
"text": [
"Using device: cuda\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "IQMIryc6LjQd",
"colab_type": "code",
"outputId": "9020c91b-d9ad-4d46-c181-cd394061df0d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 217
}
},
"source": [
"checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
"config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n",
"config"
],
"execution_count": 31,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'batch_size': 256,\n",
" 'dataset': {'input_shape': '(96,96,3)',\n",
" 'num_workers': 0,\n",
" 's': 1,\n",
" 'valid_size': 0.05},\n",
" 'epochs': 40,\n",
" 'eval_every_n_epochs': 1,\n",
" 'fine_tune_from': 'None',\n",
" 'log_every_n_steps': 50,\n",
" 'loss': {'temperature': 0.5, 'use_cosine_similarity': True},\n",
" 'model': {'base_model': 'resnet50', 'out_dim': 128}}"
]
},
"metadata": {
"tags": []
},
"execution_count": 31
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "BfIPl0G6_RrT",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n",
" train_dataset = datasets.STL10('./data', split='train', download=download,\n",
" transform=transforms.ToTensor())\n",
"\n",
" train_loader = DataLoader(train_dataset, batch_size=batch_size,\n",
" num_workers=0, drop_last=False, shuffle=shuffle)\n",
" \n",
" test_dataset = datasets.STL10('./data', split='test', download=download,\n",
" transform=transforms.ToTensor())\n",
"\n",
" test_loader = DataLoader(test_dataset, batch_size=batch_size,\n",
" num_workers=0, drop_last=False, shuffle=shuffle)\n",
" return train_loader, test_loader"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "a18lPD-tIle6",
"colab_type": "code",
"colab": {}
},
"source": [
"def _load_resnet_model(checkpoints_folder):\n",
" # Load the neural net module\n",
" spec = importlib.util.spec_from_file_location(\"model\", os.path.join(checkpoints_folder, 'resnet_simclr.py'))\n",
" resnet_module = importlib.util.module_from_spec(spec)\n",
" spec.loader.exec_module(resnet_module)\n",
"\n",
" model = resnet_module.ResNetSimCLR(**config['model'])\n",
" model.eval()\n",
"\n",
" state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=torch.device('cpu'))\n",
" model.load_state_dict(state_dict)\n",
" model = model.to(device)\n",
" return model"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "5nf4rDtWLjRE",
"colab_type": "text"
},
"source": [
"## Protocol #2 Logisitc Regression"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7jjSxmDnHNQz",
"colab_type": "code",
"colab": {}
},
"source": [
"class ResNetFeatureExtractor(object):\n",
" def __init__(self, checkpoints_folder):\n",
" self.checkpoints_folder = checkpoints_folder\n",
" self.model = _load_resnet_model(checkpoints_folder)\n",
"\n",
" def _inference(self, loader):\n",
" feature_vector = []\n",
" labels_vector = []\n",
" for batch_x, batch_y in loader:\n",
"\n",
" batch_x = batch_x.to(device)\n",
" labels_vector.extend(batch_y)\n",
"\n",
" features, _ = self.model(batch_x)\n",
" feature_vector.extend(features.cpu().detach().numpy())\n",
"\n",
" feature_vector = np.array(feature_vector)\n",
" labels_vector = np.array(labels_vector)\n",
"\n",
" print(\"Features shape {}\".format(feature_vector.shape))\n",
" return feature_vector, labels_vector\n",
"\n",
" def get_resnet_features(self):\n",
" train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
" X_train_feature, y_train = self._inference(train_loader)\n",
" X_test_feature, y_test = self._inference(test_loader)\n",
"\n",
" return X_train_feature, y_train, X_test_feature, y_test"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kghx1govJq5_",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "36040306-9730-4781-eaef-dc9018e75176"
},
"source": [
"resnet_feature_extractor = ResNetFeatureExtractor(checkpoints_folder)"
],
"execution_count": 35,
"outputs": [
{
"output_type": "stream",
"text": [
"Feature extractor: resnet50\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "S_JcznxVJ1Xj",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 90
},
"outputId": "aea5aa1c-d78c-4df2-86da-97efa484e093"
},
"source": [
"X_train_feature, y_train, X_test_feature, y_test = resnet_feature_extractor.get_resnet_features()"
],
"execution_count": 36,
"outputs": [
{
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n",
"Features shape (5000, 2048)\n",
"Features shape (8000, 2048)\n"
],
"name": "stdout"
}
]
},
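{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a cross-check (a sketch, not part of the original run), the same frozen features can be scored with scikit-learn's `LogisticRegression`, re-imported here under an alias to avoid clashing with the torch `LogisticRegression` module defined below. This mirrors the linear evaluation protocol without the mini-batch training loop; `max_iter=1000` is an assumed setting to ensure convergence on the 2048-d features."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Sketch: scikit-learn linear evaluation on the same frozen features.\n",
"from sklearn.linear_model import LogisticRegression as SkLogisticRegression\n",
"\n",
"sk_scaler = preprocessing.StandardScaler().fit(X_train_feature)\n",
"sk_clf = SkLogisticRegression(max_iter=1000)\n",
"sk_clf.fit(sk_scaler.transform(X_train_feature), y_train)\n",
"print(\"sklearn test accuracy:\", sk_clf.score(sk_scaler.transform(X_test_feature), y_test))"
],
"execution_count": 0,
"outputs": []
},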
{
"cell_type": "code",
"metadata": {
"id": "oftbHXcdLjRM",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch.nn as nn\n",
"\n",
"class LogisticRegression(nn.Module):\n",
" \n",
" def __init__(self, n_features, n_classes):\n",
" super(LogisticRegression, self).__init__()\n",
" self.model = nn.Linear(n_features, n_classes)\n",
"\n",
" def forward(self, x):\n",
" return self.model(x)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Ks73ePLtNWeV",
"colab_type": "code",
"colab": {}
},
"source": [
"class LogiticRegressionEvaluator(object):\n",
" def __init__(self, n_features, n_classes):\n",
" self.log_regression = LogisticRegression(n_features, n_classes).to(device)\n",
" self.scaler = preprocessing.StandardScaler()\n",
"\n",
" def _normalize_dataset(self, X_train, X_test):\n",
" print(\"Standard Scaling Normalizer\")\n",
" self.scaler.fit(X_train)\n",
" X_train = self.scaler.transform(X_train)\n",
" X_test = self.scaler.transform(X_test)\n",
" return X_train, X_test\n",
"\n",
" @staticmethod\n",
" def _sample_weight_decay():\n",
" # We selected the l2 regularization parameter from a range of 45 logarithmically spaced values between 106 and 105\n",
" weight_decay = np.logspace(-6, 5, num=45, base=10.0)\n",
" weight_decay = np.random.choice(weight_decay)\n",
" print(\"Sampled weight decay:\", weight_decay)\n",
" return weight_decay\n",
"\n",
" def eval(self, test_loader):\n",
" correct = 0\n",
" total = 0\n",
"\n",
" with torch.no_grad():\n",
" self.log_regression.eval()\n",
" for batch_x, batch_y in test_loader:\n",
" batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
" logits = self.log_regression(batch_x)\n",
"\n",
" predicted = torch.argmax(logits, dim=1)\n",
" total += batch_y.size(0)\n",
" correct += (predicted == batch_y).sum().item()\n",
"\n",
" final_acc = 100 * correct / total\n",
" self.log_regression.train()\n",
" return final_acc\n",
"\n",
"\n",
" def create_data_loaders_from_arrays(self, X_train, y_train, X_test, y_test):\n",
" X_train, X_test = self._normalize_dataset(X_train, X_test)\n",
"\n",
" train = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train).type(torch.long))\n",
" train_loader = torch.utils.data.DataLoader(train, batch_size=396, shuffle=False)\n",
"\n",
" test = torch.utils.data.TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test).type(torch.long))\n",
" test_loader = torch.utils.data.DataLoader(test, batch_size=512, shuffle=False)\n",
" return train_loader, test_loader\n",
"\n",
" def train(self, X_train, y_train, X_test, y_test):\n",
" \n",
" train_loader, test_loader = self.create_data_loaders_from_arrays(X_train, y_train, X_test, y_test)\n",
"\n",
" weight_decay = self._sample_weight_decay()\n",
"\n",
" optimizer = torch.optim.Adam(self.log_regression.parameters(), 3e-4, weight_decay=weight_decay)\n",
" criterion = torch.nn.CrossEntropyLoss()\n",
"\n",
" best_accuracy = 0\n",
"\n",
" for e in range(200):\n",
" \n",
" for batch_x, batch_y in train_loader:\n",
"\n",
" batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
"\n",
" optimizer.zero_grad()\n",
"\n",
" logits = self.log_regression(batch_x)\n",
"\n",
" loss = criterion(logits, batch_y)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" epoch_acc = self.eval(test_loader)\n",
" \n",
" if epoch_acc > best_accuracy:\n",
" #print(\"Saving new model with accuracy {}\".format(epoch_acc))\n",
" best_accuracy = epoch_acc\n",
" torch.save(self.log_regression.state_dict(), 'log_regression.pth')\n",
"\n",
" print(\"--------------\")\n",
" print(\"Done training\")\n",
" print(\"Best accuracy:\", best_accuracy)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "NE716m7SOkaK",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 108
},
"outputId": "87de8f71-4312-4d76-a1f3-7af5bbe0e9ba"
},
"source": [
"log_regressor_evaluator = LogiticRegressionEvaluator(n_features=X_train_feature.shape[1], n_classes=10)\n",
"\n",
"log_regressor_evaluator.train(X_train_feature, y_train, X_test_feature, y_test)"
],
"execution_count": 41,
"outputs": [
{
"output_type": "stream",
"text": [
"Standard Scaling Normalizer\n",
"Sampled weight decay: 0.00017782794100389227\n",
"--------------\n",
"Done training\n",
"Best accuracy: 73.6625\n"
],
"name": "stdout"
}
]
},
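{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch (not in the original notebook) showing how the best checkpoint saved above as `log_regression.pth` can be reloaded for scoring. It reuses the torch `LogisticRegression` module and the `StandardScaler` already fitted inside `log_regressor_evaluator`; the helper names (`best_clf`, `X_test_scaled`) are new, and the `.float()` cast is needed because the scaler returns float64 arrays."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Sketch: reload the best linear head saved during training and re-score the test set.\n",
"best_clf = LogisticRegression(n_features=X_train_feature.shape[1], n_classes=10).to(device)\n",
"best_clf.load_state_dict(torch.load('log_regression.pth', map_location=device))\n",
"best_clf.eval()\n",
"\n",
"with torch.no_grad():\n",
"    X_test_scaled = log_regressor_evaluator.scaler.transform(X_test_feature)\n",
"    preds = best_clf(torch.from_numpy(X_test_scaled).float().to(device)).argmax(dim=1)\n",
"print(\"Reloaded-model accuracy: {:.2f}%\".format(100 * (preds.cpu().numpy() == y_test).mean()))"
],
"execution_count": 0,
"outputs": []
},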
{
"cell_type": "code",
"metadata": {
"id": "_GC0a14uWRr6",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}