{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "pytorch",
      "language": "python",
      "name": "pytorch"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.6"
    },
    "colab": {
      "name": "linear_feature_eval.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/master/feature_eval/linear_feature_eval.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WSgRE1CcLqdS",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install gdown"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "G7YMxsvEZMrX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "folder_name = 'Mar14_05-52-52_thallessilva'\n",
        "\n",
        "# !gdown https://drive.google.com/uc?id=12kKgvo4h41G9qnDdhDnZXFlR5_aqvaVR # ResNet 18 --> 40 epochs trained\n",
        "!gdown https://drive.google.com/uc?id=1LjuZ1RmhotrnugprRQc2Exk0EbQHMJhL # ResNet 18 --> 80 epochs trained\n",
        "!unzip Mar14_05-52-52_thallessilva\n",
        "!ls"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vEoblAn6RsO7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# download and extract stl10\n",
        "!wget http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz\n",
        "!tar -zxvf stl10_binary.tar.gz\n",
        "!ls"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aFnFqIFLLjQZ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import sys\n",
        "import numpy as np\n",
        "import os\n",
        "from sklearn.neighbors import KNeighborsClassifier\n",
        "import yaml\n",
        "import matplotlib.pyplot as plt\n",
        "from sklearn.decomposition import PCA\n",
        "from sklearn.linear_model import LogisticRegression\n",
        "from sklearn import preprocessing\n",
        "import importlib.util"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lDfbL3w_Z0Od",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "print(\"Using device:\", device)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IQMIryc6LjQd",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
        "config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n",
        "config"
      ],
      "execution_count": 0,
      "outputs": []
    },
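    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The exact contents of `config.yaml` depend on the training run that produced the checkpoint. The optional check below only prints the entries this notebook relies on later (`config['batch_size']` and `config['model']`, including `out_dim`); adjust the keys if your checkpoint stores them differently."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Optional sanity check: the cells below only rely on these config entries.\n",
        "# Key names assume the config.yaml written by the SimCLR training run.\n",
        "print(\"batch_size:\", config['batch_size'])\n",
        "print(\"model args:\", config['model'])\n",
        "print(\"out_dim:\", config['model']['out_dim'])"
      ],
      "execution_count": 0,
      "outputs": []
    },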
    {
      "cell_type": "code",
      "metadata": {
        "id": "GxuiXvAKLjQm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def _load_stl10(prefix=\"train\"):\n",
        "    X_train = np.fromfile('./stl10_binary/' + prefix + '_X.bin', dtype=np.uint8)\n",
        "    y_train = np.fromfile('./stl10_binary/' + prefix + '_y.bin', dtype=np.uint8)\n",
        "\n",
        "    X_train = np.reshape(X_train, (-1, 3, 96, 96)) # CWH\n",
        "    X_train = np.transpose(X_train, (0, 1, 3, 2)) # CHW\n",
        "\n",
        "    print(\"{} images\".format(prefix))\n",
        "    print(X_train.shape)\n",
        "    print(y_train.shape)\n",
        "    return X_train, y_train - 1"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Xn0xslbELjQq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# load STL-10 train data\n",
        "X_train, y_train = _load_stl10(\"train\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
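    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick sanity check on the labels: the labeled STL-10 train split has 10 classes with 500 images each, and `_load_stl10` shifts the stored 1-indexed labels down to the 0..9 range used below."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Labels should span 0..9 after the shift, with a balanced class distribution.\n",
        "print(\"label range:\", y_train.min(), \"to\", y_train.max())\n",
        "print(\"images per class:\", np.bincount(y_train))"
      ],
      "execution_count": 0,
      "outputs": []
    },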
    {
      "cell_type": "code",
      "metadata": {
        "id": "7shAS6fvXtPG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "fig, axs = plt.subplots(nrows=2, ncols=6, constrained_layout=False, figsize=(12,4))\n",
        "\n",
        "for i, ax in enumerate(axs.flat):\n",
        "  ax.imshow(X_train[i].transpose(1,2,0))\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YUJ3_xoPLjQv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# load STL-10 test data\n",
        "X_test, y_test = _load_stl10(\"test\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QE8sEe_qLjQz",
        "colab_type": "text"
      },
      "source": [
        "## Test protocol #1 PCA features"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WFmUZzKoLjQ4",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "scaler = preprocessing.StandardScaler()\n",
        "scaler.fit(X_train.reshape((X_train.shape[0],-1)))\n",
        "\n",
        "pca = PCA(n_components=config['model']['out_dim'])\n",
        "\n",
        "X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))\n",
        "X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))\n",
        "\n",
        "print(\"PCA features\")\n",
        "print(X_train_pca.shape)\n",
        "print(X_test_pca.shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Yq2N_FpVLjQ8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def linear_model_eval(X_train, y_train, X_test, y_test):\n",
        "    \n",
        "    clf = LogisticRegression(random_state=0, max_iter=1200, solver='lbfgs', C=1.0)\n",
        "    clf.fit(X_train, y_train)\n",
        "    print(\"Logistic Regression feature eval\")\n",
        "    print(\"Train score:\", clf.score(X_train, y_train))\n",
        "    print(\"Test score:\", clf.score(X_test, y_test))\n",
        "    \n",
        "    print(\"-------------------------------\")\n",
        "    neigh = KNeighborsClassifier(n_neighbors=10)\n",
        "    neigh.fit(X_train, y_train)\n",
        "    print(\"KNN feature eval\")\n",
        "    print(\"Train score:\", neigh.score(X_train, y_train))\n",
        "    print(\"Test score:\", neigh.score(X_test, y_test))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6VTolghbLjRA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)\n",
        "\n",
        "## clean up resources\n",
        "del X_train_pca\n",
        "del X_test_pca"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5nf4rDtWLjRE",
        "colab_type": "text"
      },
      "source": [
        "## Protocol #2 Logisitc Regression"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fYezlvoNVpeT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Load the neural net module\n",
        "spec = importlib.util.spec_from_file_location(\"model\", os.path.join(checkpoints_folder, 'resnet_simclr.py'))\n",
        "resnet_module = importlib.util.module_from_spec(spec)\n",
        "spec.loader.exec_module(resnet_module)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AxhfD0c7LjRF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model = resnet_module.ResNetSimCLR(**config['model'])\n",
        "model.eval()\n",
        "\n",
        "state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=torch.device('cpu'))\n",
        "model.load_state_dict(state_dict)\n",
        "model = model.to(device)"
      ],
      "execution_count": 0,
      "outputs": []
    },
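    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optional smoke test before running the full dataset through the encoder: the feature-extraction cells below unpack the model output as a pair, on the assumption that `ResNetSimCLR` returns the backbone features first and the projection-head output second. A single dummy batch confirms the two output shapes."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Optional smoke test: one dummy STL-10-sized batch through the encoder.\n",
        "# The unpacking mirrors the feature-extraction cells below, which assume the\n",
        "# model returns (backbone_features, projection_head_output).\n",
        "with torch.no_grad():\n",
        "    dummy = torch.zeros(2, 3, 96, 96, device=device)\n",
        "    feats, proj = model(dummy)\n",
        "print(\"backbone feature shape:\", tuple(feats.shape))\n",
        "print(\"projection head output shape:\", tuple(proj.shape))"
      ],
      "execution_count": 0,
      "outputs": []
    },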
    {
      "cell_type": "code",
      "metadata": {
        "id": "ro6yG6ngLjRI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def next_batch(X, y, batch_size):\n",
        "    for i in range(0, X.shape[0], batch_size):\n",
        "        X_batch = torch.tensor(X[i: i+batch_size]) / 255.\n",
        "        y_batch = torch.tensor(y[i: i+batch_size])\n",
        "        yield X_batch.to(device), y_batch.to(device)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oftbHXcdLjRM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "X_train_feature = []\n",
        "\n",
        "for batch_x, batch_y in next_batch(X_train, y_train, batch_size=config['batch_size']):\n",
        "    features, _ = model(batch_x)\n",
        "    X_train_feature.extend(features.cpu().detach().numpy())\n",
        "    \n",
        "X_train_feature = np.array(X_train_feature)\n",
        "\n",
        "print(\"Train features\")\n",
        "print(X_train_feature.shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sverVlKPLjRP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "X_test_feature = []\n",
        "\n",
        "for batch_x, batch_y in next_batch(X_test, y_test, batch_size=config['batch_size']):\n",
        "    features, _ = model(batch_x)\n",
        "    X_test_feature.extend(features.cpu().detach().numpy())\n",
        "    \n",
        "X_test_feature = np.array(X_test_feature)\n",
        "\n",
        "print(\"Test features\")\n",
        "print(X_test_feature.shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "91jHpRQyLjRT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "scaler = preprocessing.StandardScaler()\n",
        "scaler.fit(X_train_feature)\n",
        "\n",
        "linear_model_eval(scaler.transform(X_train_feature), y_train, scaler.transform(X_test_feature), y_test)\n",
        "\n",
        "del X_train_feature\n",
        "del X_test_feature"
      ],
      "execution_count": 0,
      "outputs": []
    },
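    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optional cleanup, mirroring the resource clean-up after the PCA protocol: release the encoder and any cached GPU memory once the evaluation is finished."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# free the encoder and cached GPU memory now that evaluation is done\n",
        "del model\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.empty_cache()"
      ],
      "execution_count": 0,
      "outputs": []
    },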
    {
      "cell_type": "code",
      "metadata": {
        "id": "fXy_YX8_b7gL",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}