mirror of https://github.com/sthalles/SimCLR.git
444 lines
14 KiB
Plaintext
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
},
"colab": {
"name": "mini-batch-logistic-regression-evaluator.ipynb",
"provenance": [],
"include_colab_link": true
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/master/feature_eval/mini_batch_logistic_regression_evaluator.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "YUemQib7ZE4D",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"import sys\n",
"import numpy as np\n",
"import os\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"import yaml\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn import preprocessing\n",
"import importlib.util"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WSgRE1CcLqdS",
"colab_type": "code",
"colab": {}
},
"source": [
"!pip install gdown"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "NOIJEui1ZziV",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_file_id_by_model(folder_name):\n",
"    file_id = {'resnet-18_40-epochs': '1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk',\n",
"               'resnet-18_80-epochs': '1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT',\n",
"               'resnet-50_40-epochs': '1TZqBNTFCsO-mxAiR-zJeyupY-J2gA27Q',\n",
"               'resnet-50_80-epochs': '1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35',\n",
"               'resnet-18_100-epochs': '1aZ12TITXnajZ6QWmS_SDm8Sp8gXNbeCQ'}\n",
"    return file_id.get(folder_name, \"Model not found.\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "G7YMxsvEZMrX",
"colab_type": "code",
"colab": {}
},
"source": [
"folder_name = 'resnet-50_40-epochs'\n",
"file_id = get_file_id_by_model(folder_name)\n",
"print(folder_name, file_id)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "PWZ8fet_YoJm",
"colab_type": "code",
"colab": {}
},
"source": [
"# download and extract model files\n",
"os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n",
"os.system('unzip {}'.format(folder_name))\n",
"!ls"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3_nypQVEv-hn",
"colab_type": "code",
"colab": {}
},
"source": [
"from torch.utils.data import DataLoader\n",
"import torchvision.transforms as transforms\n",
"from torchvision import datasets"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lDfbL3w_Z0Od",
"colab_type": "code",
"colab": {}
},
"source": [
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"print(\"Using device:\", device)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "IQMIryc6LjQd",
"colab_type": "code",
"colab": {}
},
"source": [
"checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
"config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"), Loader=yaml.FullLoader)\n",
"config"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BfIPl0G6_RrT",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n",
"    train_dataset = datasets.STL10('./data', split='train', download=download,\n",
"                                   transform=transforms.ToTensor())\n",
"\n",
"    train_loader = DataLoader(train_dataset, batch_size=batch_size,\n",
"                              num_workers=0, drop_last=False, shuffle=shuffle)\n",
"\n",
"    test_dataset = datasets.STL10('./data', split='test', download=download,\n",
"                                  transform=transforms.ToTensor())\n",
"\n",
"    test_loader = DataLoader(test_dataset, batch_size=batch_size,\n",
"                             num_workers=0, drop_last=False, shuffle=shuffle)\n",
"    return train_loader, test_loader"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "a18lPD-tIle6",
"colab_type": "code",
"colab": {}
},
"source": [
"def _load_resnet_model(checkpoints_folder):\n",
"    # Load the neural net module\n",
"    spec = importlib.util.spec_from_file_location(\"model\", os.path.join(checkpoints_folder, 'resnet_simclr.py'))\n",
"    resnet_module = importlib.util.module_from_spec(spec)\n",
"    spec.loader.exec_module(resnet_module)\n",
"\n",
"    model = resnet_module.ResNetSimCLR(**config['model'])\n",
"    model.eval()\n",
"\n",
"    state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=torch.device('cpu'))\n",
"    model.load_state_dict(state_dict)\n",
"    model = model.to(device)\n",
"    return model"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "5nf4rDtWLjRE",
"colab_type": "text"
},
"source": [
"## Protocol #2: Logistic Regression"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7jjSxmDnHNQz",
"colab_type": "code",
"colab": {}
},
"source": [
"class ResNetFeatureExtractor(object):\n",
"    def __init__(self, checkpoints_folder):\n",
"        self.checkpoints_folder = checkpoints_folder\n",
"        self.model = _load_resnet_model(checkpoints_folder)\n",
"\n",
"    def _inference(self, loader):\n",
"        feature_vector = []\n",
"        labels_vector = []\n",
"        for batch_x, batch_y in loader:\n",
"\n",
"            batch_x = batch_x.to(device)\n",
"            labels_vector.extend(batch_y)\n",
"\n",
"            features, _ = self.model(batch_x)\n",
"            feature_vector.extend(features.cpu().detach().numpy())\n",
"\n",
"        feature_vector = np.array(feature_vector)\n",
"        labels_vector = np.array(labels_vector)\n",
"\n",
"        print(\"Features shape {}\".format(feature_vector.shape))\n",
"        return feature_vector, labels_vector\n",
"\n",
"    def get_resnet_features(self):\n",
"        train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
"        X_train_feature, y_train = self._inference(train_loader)\n",
"        X_test_feature, y_test = self._inference(test_loader)\n",
"\n",
"        return X_train_feature, y_train, X_test_feature, y_test"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kghx1govJq5_",
"colab_type": "code",
"colab": {}
},
"source": [
"resnet_feature_extractor = ResNetFeatureExtractor(checkpoints_folder)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "S_JcznxVJ1Xj",
"colab_type": "code",
"colab": {}
},
"source": [
"X_train_feature, y_train, X_test_feature, y_test = resnet_feature_extractor.get_resnet_features()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "oftbHXcdLjRM",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch.nn as nn\n",
"\n",
"class LogisticRegression(nn.Module):\n",
"\n",
"    def __init__(self, n_features, n_classes):\n",
"        super(LogisticRegression, self).__init__()\n",
"        self.model = nn.Linear(n_features, n_classes)\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Ks73ePLtNWeV",
"colab_type": "code",
"colab": {}
},
"source": [
"class LogisticRegressionEvaluator(object):\n",
"    def __init__(self, n_features, n_classes):\n",
"        self.log_regression = LogisticRegression(n_features, n_classes).to(device)\n",
"        self.scaler = preprocessing.StandardScaler()\n",
"\n",
"    def _normalize_dataset(self, X_train, X_test):\n",
"        print(\"Standard Scaling Normalizer\")\n",
"        self.scaler.fit(X_train)\n",
"        X_train = self.scaler.transform(X_train)\n",
"        X_test = self.scaler.transform(X_test)\n",
"        return X_train, X_test\n",
"\n",
"    @staticmethod\n",
"    def _sample_weight_decay():\n",
"        # select the L2 regularization parameter from 45 logarithmically spaced values between 10^-6 and 10^5\n",
"        weight_decay = np.logspace(-6, 5, num=45, base=10.0)\n",
"        weight_decay = np.random.choice(weight_decay)\n",
"        print(\"Sampled weight decay:\", weight_decay)\n",
"        return weight_decay\n",
"\n",
"    def eval(self, test_loader):\n",
"        correct = 0\n",
"        total = 0\n",
"\n",
"        with torch.no_grad():\n",
"            self.log_regression.eval()\n",
"            for batch_x, batch_y in test_loader:\n",
"                batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
"                logits = self.log_regression(batch_x)\n",
"\n",
"                predicted = torch.argmax(logits, dim=1)\n",
"                total += batch_y.size(0)\n",
"                correct += (predicted == batch_y).sum().item()\n",
"\n",
"            final_acc = 100 * correct / total\n",
"            self.log_regression.train()\n",
"            return final_acc\n",
"\n",
"    def create_data_loaders_from_arrays(self, X_train, y_train, X_test, y_test):\n",
"        X_train, X_test = self._normalize_dataset(X_train, X_test)\n",
"\n",
"        train = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train).type(torch.long))\n",
"        train_loader = torch.utils.data.DataLoader(train, batch_size=396, shuffle=False)\n",
"\n",
"        test = torch.utils.data.TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test).type(torch.long))\n",
"        test_loader = torch.utils.data.DataLoader(test, batch_size=512, shuffle=False)\n",
"        return train_loader, test_loader\n",
"\n",
"    def train(self, X_train, y_train, X_test, y_test):\n",
"\n",
"        train_loader, test_loader = self.create_data_loaders_from_arrays(X_train, y_train, X_test, y_test)\n",
"\n",
"        weight_decay = self._sample_weight_decay()\n",
"\n",
"        optimizer = torch.optim.Adam(self.log_regression.parameters(), 3e-4, weight_decay=weight_decay)\n",
"        criterion = torch.nn.CrossEntropyLoss()\n",
"\n",
"        best_accuracy = 0\n",
"\n",
"        for e in range(200):\n",
"\n",
"            for batch_x, batch_y in train_loader:\n",
"\n",
"                batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
"\n",
"                optimizer.zero_grad()\n",
"\n",
"                logits = self.log_regression(batch_x)\n",
"\n",
"                loss = criterion(logits, batch_y)\n",
"\n",
"                loss.backward()\n",
"                optimizer.step()\n",
"\n",
"            epoch_acc = self.eval(test_loader)\n",
"\n",
"            if epoch_acc > best_accuracy:\n",
"                # print(\"Saving new model with accuracy {}\".format(epoch_acc))\n",
"                best_accuracy = epoch_acc\n",
"                torch.save(self.log_regression.state_dict(), 'log_regression.pth')\n",
"\n",
"        print(\"--------------\")\n",
"        print(\"Done training\")\n",
"        print(\"Best accuracy:\", best_accuracy)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "NE716m7SOkaK",
"colab_type": "code",
"colab": {}
},
"source": [
"log_regressor_evaluator = LogisticRegressionEvaluator(n_features=X_train_feature.shape[1], n_classes=10)\n",
"\n",
"log_regressor_evaluator.train(X_train_feature, y_train, X_test_feature, y_test)"
],
"execution_count": 0,
"outputs": []
}
]
} |