SimCLR/feature_eval/linear_feature_eval.ipynb
2020-04-25 17:10:58 -03:00

480 lines
15 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
},
"colab": {
"name": "linear_feature_eval.ipynb",
"provenance": [],
"include_colab_link": true
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/master/feature_eval/linear_feature_eval.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "WSgRE1CcLqdS",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 163
},
"outputId": "60a2169b-a652-4d4b-e441-18a1a77ba53a"
},
"source": [
"# gdown is used by the next cell to fetch the pretrained checkpoint from Google Drive\n",
"!pip install gdown"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: gdown in /usr/local/lib/python3.6/dist-packages (3.6.4)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gdown) (2.21.0)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gdown) (4.38.0)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from gdown) (1.12.0)\n",
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (1.24.3)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2.8)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2020.4.5.1)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (3.0.4)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "G7YMxsvEZMrX",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 219
},
"outputId": "26a6b27f-320c-4562-9b7c-f8b9262f6033"
},
"source": [
"folder_name = 'Mar14_05-52-52_thallessilva'\n",
"\n",
"# Google Drive file ids of the zipped pretrained runs; the archive is expected to\n",
"# extract to folder_name/, whose checkpoints/ subfolder the cells below read.\n",
"# gdown needs the direct-download URL https://drive.google.com/uc?id=<ID>: given the\n",
"# /file/d/<ID>/view share link it saves the HTML viewer page as 'view?usp=sharing'\n",
"# instead of the archive, and the unzip below then fails.\n",
"# NOTE(review): the 40-epoch archive may extract to a different folder -- check\n",
"# folder_name when switching checkpoints.\n",
"# file_id = '1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk'  # ResNet 18 --> 40 epochs trained\n",
"file_id = '1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT'  # ResNet 18 --> 80 epochs trained\n",
"\n",
"!gdown \"https://drive.google.com/uc?id={file_id}\" -O {folder_name}.zip\n",
"# -o overwrites existing files without prompting, so the cell can be re-run\n",
"!unzip -o {folder_name}.zip\n",
"!ls"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "vEoblAn6RsO7",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 174
},
"outputId": "f32e47cc-6195-4b92-a1aa-ad4328e39269"
},
"source": [
"# download and extract STL-10 (binary release, ~2.5 GB). Skipped when stl10_binary/\n",
"# already exists, so Restart & Run All does not fetch it again; wget -c resumes an\n",
"# interrupted download instead of saving a second copy as stl10_binary.tar.gz.1\n",
"!if [ ! -d stl10_binary ]; then wget -c http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz && tar -zxvf stl10_binary.tar.gz; fi\n",
"!ls"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"--2020-04-25 20:09:26-- http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz\n",
"Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10\n",
"Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:80... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 2640397119 (2.5G) [application/x-gzip]\n",
"Saving to: stl10_binary.tar.gz\n",
"\n",
"stl10_binary.tar.gz 62%[===========> ] 1.54G 19.3MB/s eta 51s "
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aFnFqIFLLjQZ",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"import sys\n",
"import numpy as np\n",
"import os\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"import yaml\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn import preprocessing\n",
"import importlib.util"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lDfbL3w_Z0Od",
"colab_type": "code",
"colab": {}
},
"source": [
"# Prefer the GPU for feature extraction whenever one is available.\n",
"if torch.cuda.is_available():\n",
"    device = 'cuda'\n",
"else:\n",
"    device = 'cpu'\n",
"print(\"Using device:\", device)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "IQMIryc6LjQd",
"colab_type": "code",
"colab": {}
},
"source": [
"checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
"\n",
"# Training configuration saved with the run; model and batch settings are read below.\n",
"# yaml.safe_load: yaml.load without an explicit Loader is deprecated since PyYAML 5.1\n",
"# and raises TypeError in PyYAML 6. The with-block also closes the file.\n",
"# NOTE(review): assumes config.yaml holds only plain YAML types (no python/ tags).\n",
"with open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\") as config_file:\n",
"    config = yaml.safe_load(config_file)\n",
"config"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "GxuiXvAKLjQm",
"colab_type": "code",
"colab": {}
},
"source": [
"def _load_stl10(prefix=\"train\", data_dir='./stl10_binary'):\n",
"    \"\"\"Load one split of the STL-10 binary release as numpy arrays.\n",
"\n",
"    Args:\n",
"        prefix: split name used as the file prefix, e.g. \"train\" or \"test\";\n",
"            reads <prefix>_X.bin and <prefix>_y.bin.\n",
"        data_dir: folder containing the extracted .bin files.\n",
"\n",
"    Returns:\n",
"        (images, labels): uint8 images of shape (N, 3, 96, 96) in CHW order, and\n",
"        uint8 labels shifted down by one (presumably the files store them 1-based).\n",
"    \"\"\"\n",
"    images = np.fromfile(os.path.join(data_dir, prefix + '_X.bin'), dtype=np.uint8)\n",
"    labels = np.fromfile(os.path.join(data_dir, prefix + '_y.bin'), dtype=np.uint8)\n",
"\n",
"    # A plain reshape yields (N, C, W, H); swapping the last two axes gives (N, C, H, W).\n",
"    images = np.reshape(images, (-1, 3, 96, 96))  # CWH\n",
"    images = np.transpose(images, (0, 1, 3, 2))  # CHW\n",
"\n",
"    print(\"{} images\".format(prefix))\n",
"    print(images.shape)\n",
"    print(labels.shape)\n",
"    return images, labels - 1"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Xn0xslbELjQq",
"colab_type": "code",
"colab": {}
},
"source": [
"# Labelled STL-10 training split (images + 0-based labels).\n",
"X_train, y_train = _load_stl10(prefix=\"train\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "7shAS6fvXtPG",
"colab_type": "code",
"colab": {}
},
"source": [
"# Sanity-check the loader: first 12 training images, titled with their 0-based class\n",
"# index so image/label alignment can be eyeballed. imshow expects HWC, arrays are CHW.\n",
"fig, axs = plt.subplots(nrows=2, ncols=6, constrained_layout=False, figsize=(12,4))\n",
"\n",
"for i, ax in enumerate(axs.flat):\n",
"    ax.imshow(X_train[i].transpose(1,2,0))\n",
"    ax.set_title('class {}'.format(y_train[i]))\n",
"    ax.axis('off')\n",
"plt.show()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "YUJ3_xoPLjQv",
"colab_type": "code",
"colab": {}
},
"source": [
"# Held-out STL-10 test split used for all reported test scores.\n",
"X_test, y_test = _load_stl10(prefix=\"test\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "QE8sEe_qLjQz",
"colab_type": "text"
},
"source": [
"## Protocol #1: Logistic Regression and KNN on PCA features (baseline)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "WFmUZzKoLjQ4",
"colab_type": "code",
"colab": {}
},
"source": [
"# Protocol #1 baseline: standardize raw pixels (statistics fit on the training split)\n",
"# and keep as many principal components as the SimCLR model's out_dim.\n",
"scaler = preprocessing.StandardScaler()\n",
"scaler.fit(X_train.reshape((X_train.shape[0],-1)))\n",
"\n",
"# random_state pins the randomized SVD that PCA's default 'auto' solver may choose for\n",
"# inputs this large; without it the baseline scores can vary between runs.\n",
"pca = PCA(n_components=config['model']['out_dim'], random_state=0)\n",
"\n",
"X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))\n",
"X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))\n",
"\n",
"print(\"PCA features\")\n",
"print(X_train_pca.shape)\n",
"print(X_test_pca.shape)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Yq2N_FpVLjQ8",
"colab_type": "code",
"colab": {}
},
"source": [
"def linear_model_eval(X_train, y_train, X_test, y_test):\n",
"    \"\"\"Fit logistic-regression and 10-NN classifiers on fixed features and report accuracy.\n",
"\n",
"    Both classifiers are fit on (X_train, y_train); the train and test score of each\n",
"    is printed.\n",
"\n",
"    Returns:\n",
"        dict with keys 'logreg_train', 'logreg_test', 'knn_train', 'knn_test' holding\n",
"        the printed scores, so results can be logged or compared programmatically.\n",
"    \"\"\"\n",
"    clf = LogisticRegression(random_state=0, max_iter=1200, solver='lbfgs', C=1.0)\n",
"    clf.fit(X_train, y_train)\n",
"    print(\"Logistic Regression feature eval\")\n",
"    logreg_train = clf.score(X_train, y_train)\n",
"    print(\"Train score:\", logreg_train)\n",
"    logreg_test = clf.score(X_test, y_test)\n",
"    print(\"Test score:\", logreg_test)\n",
"\n",
"    print(\"-------------------------------\")\n",
"    neigh = KNeighborsClassifier(n_neighbors=10)\n",
"    neigh.fit(X_train, y_train)\n",
"    print(\"KNN feature eval\")\n",
"    knn_train = neigh.score(X_train, y_train)\n",
"    print(\"Train score:\", knn_train)\n",
"    knn_test = neigh.score(X_test, y_test)\n",
"    print(\"Test score:\", knn_test)\n",
"\n",
"    return {'logreg_train': logreg_train, 'logreg_test': logreg_test,\n",
"            'knn_train': knn_train, 'knn_test': knn_test}"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "6VTolghbLjRA",
"colab_type": "code",
"colab": {}
},
"source": [
"linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)\n",
"\n",
"# Release the PCA projections; protocol #2 does not use them.\n",
"del X_train_pca, X_test_pca"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "5nf4rDtWLjRE",
"colab_type": "text"
},
"source": [
"## Protocol #2: Logistic Regression and KNN on SimCLR features"
]
},
{
"cell_type": "code",
"metadata": {
"id": "fYezlvoNVpeT",
"colab_type": "code",
"colab": {}
},
"source": [
"# Import resnet_simclr.py straight from the checkpoint folder (by file path, as a\n",
"# module named \"model\") to get the ResNetSimCLR class used for the saved weights.\n",
"spec = importlib.util.spec_from_file_location(\n",
"    \"model\", os.path.join(checkpoints_folder, 'resnet_simclr.py'))\n",
"resnet_module = importlib.util.module_from_spec(spec)\n",
"spec.loader.exec_module(resnet_module)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "AxhfD0c7LjRF",
"colab_type": "code",
"colab": {}
},
"source": [
"# Rebuild the network from the saved config and restore the trained weights\n",
"# (mapped to CPU first so the checkpoint loads regardless of where it was saved).\n",
"# NOTE(review): torch.load unpickles model.pth -- only load trusted checkpoints.\n",
"model = resnet_module.ResNetSimCLR(**config['model'])\n",
"state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'),\n",
"                        map_location=torch.device('cpu'))\n",
"model.load_state_dict(state_dict)\n",
"\n",
"# Move to the selected device and keep the model in inference (eval) mode.\n",
"model = model.to(device).eval()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ro6yG6ngLjRI",
"colab_type": "code",
"colab": {}
},
"source": [
"def next_batch(X, y, batch_size):\n",
"    \"\"\"Yield consecutive (images, labels) minibatches as tensors on `device`.\n",
"\n",
"    Images are divided by 255. (uint8 pixels -> floats in [0, 1]); the final batch\n",
"    may be smaller than batch_size.\n",
"    \"\"\"\n",
"    n_samples = X.shape[0]\n",
"    for start in range(0, n_samples, batch_size):\n",
"        stop = start + batch_size\n",
"        images = torch.tensor(X[start:stop]) / 255.\n",
"        labels = torch.tensor(y[start:stop])\n",
"        yield images.to(device), labels.to(device)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "oftbHXcdLjRM",
"colab_type": "code",
"colab": {}
},
"source": [
"# Encode the training images with the pretrained network (its first output is the feature).\n",
"# torch.no_grad(): nothing is back-propagated here, so recording the autograd graph\n",
"# for every forward pass would only cost memory and time.\n",
"with torch.no_grad():\n",
"    X_train_feature = np.concatenate([\n",
"        model(batch_x)[0].cpu().numpy()\n",
"        for batch_x, _ in next_batch(X_train, y_train, batch_size=config['batch_size'])\n",
"    ])\n",
"\n",
"print(\"Train features\")\n",
"print(X_train_feature.shape)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sverVlKPLjRP",
"colab_type": "code",
"colab": {}
},
"source": [
"# Encode the test images with the pretrained network (its first output is the feature).\n",
"# torch.no_grad(): nothing is back-propagated here, so recording the autograd graph\n",
"# for every forward pass would only cost memory and time.\n",
"with torch.no_grad():\n",
"    X_test_feature = np.concatenate([\n",
"        model(batch_x)[0].cpu().numpy()\n",
"        for batch_x, _ in next_batch(X_test, y_test, batch_size=config['batch_size'])\n",
"    ])\n",
"\n",
"print(\"Test features\")\n",
"print(X_test_feature.shape)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "91jHpRQyLjRT",
"colab_type": "code",
"colab": {}
},
"source": [
"# Standardize the SimCLR features with statistics fit on the training split only,\n",
"# then run the same logistic-regression / KNN evaluation as protocol #1.\n",
"scaler = preprocessing.StandardScaler().fit(X_train_feature)\n",
"linear_model_eval(scaler.transform(X_train_feature), y_train,\n",
"                  scaler.transform(X_test_feature), y_test)\n",
"\n",
"del X_train_feature, X_test_feature"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "fXy_YX8_b7gL",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}