2020-02-17 23:17:10 -03:00
|
|
|
|
{
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"nbformat": 4,
|
|
|
|
|
"nbformat_minor": 0,
|
|
|
|
|
"metadata": {
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
"display_name": "pytorch",
|
|
|
|
|
"language": "python",
|
|
|
|
|
"name": "pytorch"
|
|
|
|
|
},
|
|
|
|
|
"language_info": {
|
|
|
|
|
"codemirror_mode": {
|
|
|
|
|
"name": "ipython",
|
|
|
|
|
"version": 3
|
|
|
|
|
},
|
|
|
|
|
"file_extension": ".py",
|
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
|
"name": "python",
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
|
"version": "3.6.6"
|
|
|
|
|
},
|
|
|
|
|
"colab": {
|
|
|
|
|
"name": "linear_feature_eval.ipynb",
|
|
|
|
|
"provenance": [],
|
|
|
|
|
"include_colab_link": true
|
|
|
|
|
},
|
|
|
|
|
"accelerator": "GPU"
|
2020-02-17 23:17:10 -03:00
|
|
|
|
},
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"cells": [
|
2020-03-10 08:52:30 -03:00
|
|
|
|
{
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "view-in-github",
|
|
|
|
|
"colab_type": "text"
|
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/master/feature_eval/linear_feature_eval.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
2020-03-10 08:52:30 -03:00
|
|
|
|
]
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
2020-02-17 23:17:10 -03:00
|
|
|
|
{
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "WSgRE1CcLqdS",
|
|
|
|
|
"colab_type": "code",
|
2020-04-25 17:10:58 -03:00
|
|
|
|
"colab": {
|
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
|
"height": 163
|
|
|
|
|
},
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"outputId": "855d3d81-1171-42c9-b4a6-957c6528fcca"
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"!pip install gdown"
|
|
|
|
|
],
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"execution_count": 1,
|
2020-04-25 17:10:58 -03:00
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"Requirement already satisfied: gdown in /usr/local/lib/python3.6/dist-packages (3.6.4)\n",
|
|
|
|
|
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gdown) (4.38.0)\n",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gdown) (2.21.0)\n",
|
2020-04-25 17:10:58 -03:00
|
|
|
|
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from gdown) (1.12.0)\n",
|
|
|
|
|
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2020.4.5.1)\n",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (3.0.4)\n",
|
|
|
|
|
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (1.24.3)\n",
|
|
|
|
|
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2.8)\n"
|
2020-04-25 17:10:58 -03:00
|
|
|
|
],
|
|
|
|
|
"name": "stdout"
|
|
|
|
|
}
|
|
|
|
|
]
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
2020-02-17 23:17:10 -03:00
|
|
|
|
{
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "G7YMxsvEZMrX",
|
|
|
|
|
"colab_type": "code",
|
2020-04-25 17:10:58 -03:00
|
|
|
|
"colab": {
|
|
|
|
|
"base_uri": "https://localhost:8080/",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"height": 272
|
2020-04-25 17:10:58 -03:00
|
|
|
|
},
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"outputId": "869571ca-c1fa-40fd-e687-f0962aba76e5"
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"folder_name = 'resnet-18_80-epochs'\n",
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"\n",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"# !gdown https://drive.google.com/uc?id=1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk # ResNet 18 --> 40 epochs trained\n",
|
|
|
|
|
"!gdown https://drive.google.com/uc?id=1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT # ResNet 18 --> 80 epochs trained\n",
|
|
|
|
|
"!unzip resnet-18_80-epochs\n",
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"!ls"
|
|
|
|
|
],
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"execution_count": 2,
|
2020-04-25 17:10:58 -03:00
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"Downloading...\n",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"From: https://drive.google.com/uc?id=1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT\n",
|
|
|
|
|
"To: /content/resnet-18_80-epochs.zip\n",
|
|
|
|
|
"43.3MB [00:00, 93.1MB/s]\n",
|
|
|
|
|
"Archive: resnet-18_80-epochs.zip\n",
|
|
|
|
|
" creating: resnet-18_80-epochs/\n",
|
|
|
|
|
" creating: resnet-18_80-epochs/checkpoints/\n",
|
|
|
|
|
" inflating: resnet-18_80-epochs/checkpoints/config.yaml \n",
|
|
|
|
|
" inflating: resnet-18_80-epochs/checkpoints/model.pth \n",
|
|
|
|
|
" inflating: resnet-18_80-epochs/checkpoints/resnet_simclr.py \n",
|
|
|
|
|
" inflating: resnet-18_80-epochs/events.out.tfevents.1584175972.thallessilva.7272.0 \n",
|
|
|
|
|
" resnet-18_80-epochs\t stl10_binary\t\t 'view?usp=sharing'\n",
|
|
|
|
|
" resnet-18_80-epochs.zip stl10_binary.tar.gz\n",
|
|
|
|
|
" sample_data\t\t stl10_binary.tar.gz.1\n"
|
|
|
|
|
],
|
|
|
|
|
"name": "stdout"
|
|
|
|
|
}
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "Muj3TrwSNLEu",
|
|
|
|
|
"colab_type": "code",
|
|
|
|
|
"colab": {
|
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
|
"height": 72
|
|
|
|
|
},
|
|
|
|
|
"outputId": "4d67b96d-6c5d-4703-d2dc-1cb599d27db2"
|
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"!ls"
|
|
|
|
|
],
|
|
|
|
|
"execution_count": 3,
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
" resnet-18_80-epochs\t stl10_binary\t\t 'view?usp=sharing'\n",
|
|
|
|
|
" resnet-18_80-epochs.zip stl10_binary.tar.gz\n",
|
|
|
|
|
" sample_data\t\t stl10_binary.tar.gz.1\n"
|
2020-04-25 17:10:58 -03:00
|
|
|
|
],
|
|
|
|
|
"name": "stdout"
|
|
|
|
|
}
|
|
|
|
|
]
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
2020-02-17 23:17:10 -03:00
|
|
|
|
{
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "vEoblAn6RsO7",
|
|
|
|
|
"colab_type": "code",
|
2020-04-25 17:10:58 -03:00
|
|
|
|
"colab": {
|
|
|
|
|
"base_uri": "https://localhost:8080/",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"height": 417
|
2020-04-25 17:10:58 -03:00
|
|
|
|
},
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"outputId": "9906b2c4-bbfa-45dc-9e66-32464e0dbaa3"
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"# download and extract stl10\n",
|
|
|
|
|
"!wget http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz\n",
|
|
|
|
|
"!tar -zxvf stl10_binary.tar.gz\n",
|
|
|
|
|
"!ls"
|
|
|
|
|
],
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"execution_count": 4,
|
2020-04-25 17:10:58 -03:00
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"--2020-04-25 21:21:34-- http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz\n",
|
2020-04-25 17:10:58 -03:00
|
|
|
|
"Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10\n",
|
|
|
|
|
"Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:80... connected.\n",
|
|
|
|
|
"HTTP request sent, awaiting response... 200 OK\n",
|
|
|
|
|
"Length: 2640397119 (2.5G) [application/x-gzip]\n",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"Saving to: ‘stl10_binary.tar.gz.2’\n",
|
2020-04-25 17:10:58 -03:00
|
|
|
|
"\n",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"stl10_binary.tar.gz 100%[===================>] 2.46G 63.5MB/s in 34s \n",
|
|
|
|
|
"\n",
|
|
|
|
|
"2020-04-25 21:22:07 (75.1 MB/s) - ‘stl10_binary.tar.gz.2’ saved [2640397119/2640397119]\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"stl10_binary/\n",
|
|
|
|
|
"stl10_binary/test_X.bin\n",
|
|
|
|
|
"stl10_binary/test_y.bin\n",
|
|
|
|
|
"stl10_binary/train_X.bin\n",
|
|
|
|
|
"stl10_binary/train_y.bin\n",
|
|
|
|
|
"stl10_binary/unlabeled_X.bin\n",
|
|
|
|
|
"stl10_binary/class_names.txt\n",
|
|
|
|
|
"stl10_binary/fold_indices.txt\n",
|
|
|
|
|
" resnet-18_80-epochs\t stl10_binary\t\t stl10_binary.tar.gz.2\n",
|
|
|
|
|
" resnet-18_80-epochs.zip stl10_binary.tar.gz\t 'view?usp=sharing'\n",
|
|
|
|
|
" sample_data\t\t stl10_binary.tar.gz.1\n"
|
2020-04-25 17:10:58 -03:00
|
|
|
|
],
|
|
|
|
|
"name": "stdout"
|
|
|
|
|
}
|
|
|
|
|
]
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
2020-02-17 23:17:10 -03:00
|
|
|
|
{
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "aFnFqIFLLjQZ",
|
|
|
|
|
"colab_type": "code",
|
|
|
|
|
"colab": {}
|
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"import torch\n",
|
|
|
|
|
"import sys\n",
|
|
|
|
|
"import numpy as np\n",
|
|
|
|
|
"import os\n",
|
|
|
|
|
"from sklearn.neighbors import KNeighborsClassifier\n",
|
|
|
|
|
"import yaml\n",
|
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
|
"from sklearn.decomposition import PCA\n",
|
|
|
|
|
"from sklearn.linear_model import LogisticRegression\n",
|
|
|
|
|
"from sklearn import preprocessing\n",
|
|
|
|
|
"import importlib.util"
|
|
|
|
|
],
|
|
|
|
|
"execution_count": 0,
|
|
|
|
|
"outputs": []
|
|
|
|
|
},
|
2020-03-13 22:56:04 -03:00
|
|
|
|
{
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "lDfbL3w_Z0Od",
|
|
|
|
|
"colab_type": "code",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"colab": {
|
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
|
"height": 35
|
|
|
|
|
},
|
|
|
|
|
"outputId": "52656481-7f67-452a-d85a-c10707c29f43"
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
|
|
|
|
"print(\"Using device:\", device)"
|
|
|
|
|
],
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"execution_count": 6,
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"Using device: cuda\n"
|
|
|
|
|
],
|
|
|
|
|
"name": "stdout"
|
|
|
|
|
}
|
|
|
|
|
]
|
2020-03-13 22:56:04 -03:00
|
|
|
|
},
|
2020-02-17 23:17:10 -03:00
|
|
|
|
{
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "IQMIryc6LjQd",
|
|
|
|
|
"colab_type": "code",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"colab": {
|
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
|
"height": 217
|
|
|
|
|
},
|
|
|
|
|
"outputId": "5b02308c-a5ab-4bbc-f158-5e6129164017"
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
"# safe_load: yaml.load() without an explicit Loader is unsafe and fails on PyYAML >= 6\n",
"with open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\") as config_file:\n",
"    config = yaml.safe_load(config_file)\n",
"config"
|
|
|
|
|
],
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"execution_count": 7,
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"output_type": "execute_result",
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"{'batch_size': 512,\n",
|
|
|
|
|
" 'dataset': {'input_shape': '(96,96,3)',\n",
|
|
|
|
|
" 'num_workers': 0,\n",
|
|
|
|
|
" 's': 1,\n",
|
|
|
|
|
" 'valid_size': 0.05},\n",
|
|
|
|
|
" 'epochs': 40,\n",
|
|
|
|
|
" 'eval_every_n_epochs': 1,\n",
|
|
|
|
|
" 'fine_tune_from': 'Mar13_22-46-30_thallessilva',\n",
|
|
|
|
|
" 'log_every_n_steps': 50,\n",
|
|
|
|
|
" 'loss': {'temperature': 0.5, 'use_cosine_similarity': True},\n",
|
|
|
|
|
" 'model': {'base_model': 'resnet18', 'out_dim': 256}}"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {
|
|
|
|
|
"tags": []
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 7
|
|
|
|
|
}
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "udi8OnvzMUEt",
|
|
|
|
|
"colab_type": "code",
|
|
|
|
|
"colab": {
|
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
|
"height": 35
|
|
|
|
|
},
|
|
|
|
|
"outputId": "bc6cdde2-8f80-44a1-fd6d-0fe5f58a1eb1"
|
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"checkpoints_folder"
|
|
|
|
|
],
|
|
|
|
|
"execution_count": 8,
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"output_type": "execute_result",
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"'resnet-18_80-epochs/checkpoints'"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {
|
|
|
|
|
"tags": []
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 8
|
|
|
|
|
}
|
|
|
|
|
]
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
2020-02-17 23:17:10 -03:00
|
|
|
|
{
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "GxuiXvAKLjQm",
|
|
|
|
|
"colab_type": "code",
|
|
|
|
|
"colab": {}
|
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"def _load_stl10(prefix=\"train\"):\n",
|
|
|
|
|
" X_train = np.fromfile('./stl10_binary/' + prefix + '_X.bin', dtype=np.uint8)\n",
|
|
|
|
|
" y_train = np.fromfile('./stl10_binary/' + prefix + '_y.bin', dtype=np.uint8)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" X_train = np.reshape(X_train, (-1, 3, 96, 96)) # CWH\n",
|
|
|
|
|
" X_train = np.transpose(X_train, (0, 1, 3, 2)) # CHW\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" print(\"{} images\".format(prefix))\n",
|
|
|
|
|
" print(X_train.shape)\n",
|
|
|
|
|
" print(y_train.shape)\n",
|
|
|
|
|
" return X_train, y_train - 1"
|
|
|
|
|
],
|
|
|
|
|
"execution_count": 0,
|
|
|
|
|
"outputs": []
|
|
|
|
|
},
|
2020-02-17 23:17:10 -03:00
|
|
|
|
{
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "Xn0xslbELjQq",
|
|
|
|
|
"colab_type": "code",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"colab": {
|
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
|
"height": 72
|
|
|
|
|
},
|
|
|
|
|
"outputId": "f20616d0-3e7f-4f53-bcae-529bbe7f4d40"
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"# load STL-10 train data\n",
|
|
|
|
|
"X_train, y_train = _load_stl10(\"train\")"
|
|
|
|
|
],
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"execution_count": 10,
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"train images\n",
|
|
|
|
|
"(5000, 3, 96, 96)\n",
|
|
|
|
|
"(5000,)\n"
|
|
|
|
|
],
|
|
|
|
|
"name": "stdout"
|
|
|
|
|
}
|
|
|
|
|
]
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
2020-02-17 23:17:10 -03:00
|
|
|
|
{
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "7shAS6fvXtPG",
|
|
|
|
|
"colab_type": "code",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"colab": {
|
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
|
"height": 266
|
|
|
|
|
},
|
|
|
|
|
"outputId": "8217a3e4-472e-4dfc-eb75-dc6cd29072e5"
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"fig, axs = plt.subplots(nrows=2, ncols=6, constrained_layout=False, figsize=(12,4))\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"for i, ax in enumerate(axs.flat):\n",
|
|
|
|
|
" ax.imshow(X_train[i].transpose(1,2,0))\n",
|
|
|
|
|
"plt.show()"
|
|
|
|
|
],
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"execution_count": 11,
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"output_type": "display_data",
|
|
|
|
|
"data": {
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAr8AAAD5CAYAAAAuh122AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nOy9eZAdx33n+cnMut7Z7/XdjQYa3TgJgAQI3qQgihQl67AkSrZsSbblGXvCHm9s7O7ErDd2NmJmwrM7+4d3Iya8MeFdez2eXUljjy2NxpJlizoskqJ4gAAPACRuoNHd6Pt8d12ZuX+8bqBJUbIlYWQt9T6IAvBe1cvKyl9V1rd++ftlCWstHTp06NChQ4cOHTr8NCD/vivQoUOHDh06dOjQocOPi4747dChQ4cOHTp06PBTQ0f8dujQoUOHDh06dPipoSN+O3To0KFDhw4dOvzU0BG/HTp06NChQ4cOHX5q6IjfDh06dOjQoUOHDj81/EjiVwjxPiHEBSHEZSHE/3irKtXhx0/Hlm8fOrZ8+9Cx5duDjh3fPnRs+fZA/LDz/AohFHAReA9wHTgBfNJae/bWVa/Dj4OOLd8+dGz59qFjy7cHHTu+fejY8u3Dj+L5vRe4bK29aq2Ngf8IfOTWVKvDj5mOLd8+dGz59qFjy7cHHTu+fejY8m3CjyJ+twHTWz5f3/iuw///6Njy7UPHlm8fOrZ8e9Cx49uHji3fJjj/pXcghPgN4DcAPM+9a6C/vGXllg3t5l9bv7z5ncVsbNP+ZO3N39z46RsLe4tyvtc6wIqNXW+Uh0RJhTEWYzVgUQqUUhhjwAqsBXujTpslby1fbLTBmw92y9rNv+zGkW58FpubC7GxidgsDSHeWJYQor3XGwVuqcOW/05Mzi1ba/veugH+drba0vXUXb29uXa9NvdPu0GkUEgpbhyxBYy1WKuxFoQVsLFeiHb9pZBYKxAYbhZlERt/sBZtdHt/UoI1WGtBSBDihtVu7NC2zxGLwWAQbNTRWiwCi8IikVbiKIWSAq01iU5wHIHjCIy1GGuQQnKz9dvHorXG6HZdpZRIV23Ud6MCQrbtZCzGGKwxKKGQUoFtl2HQWGx7u4122LQnWDZDkpRQgNj4bFlda1Kvh999Qv2Qtgwyzl0DO3twcVhcWsbN+DjSodaoMdDbR6VWw/E8HOHQqNcwwoKwSCnIBAGtVgtrLFZITJJijSGbyyGVg041YRRijGnbTYAxhnwmg7AQJylWGLTRSNW+5nzXI6wlJLHBIkHaG4ux+sa14XkexhhSnaKkBARJnCClROsUKSSBF5BEhiTWCARSKoQUICwogTEWa+zGvgVCWISE1CRYDI7jYi1oY27aCos1BrlxHhqtUapdrpTyxnnX3p/EbNjbbpzaxur2GSggqluSlv6hbbnVjsBdQjkb5+rm9SBQEoS1jPanXFsEBbgKUiOJTbu+CHAEFIuGICeprGiaocJuFiXAF5asZ3AdkAisMeRzEs+RpAlcr1jCVG7s9WbPfaPv2+hitQVjN6//9nl9s++UgEXdPMIb/ZoEhLBoK7jZu2y0qUlRSqGUj7UGrMVYCxjatb1ZDbvRn2PtRt/T7m8FIIVAbt4Hbt5kbvwrBSgpcDa2Fxvt5irBxHp0y/pX33PuGt0+QKMRorXGcxWOclBO+xwPI42QEteTNOoNAt9vn6PGopRq969C4jgOaRyipMK2Lz68bJYwjPE9RZrGGAu+H1Cv1fF9n0ajgVQOSZqSyWTxXIckSYjj9pLL+lg00nGIwgQlIJvNolyXVrOFTlNcV2KNxXXb10+SpijloNMEKQWO47bb+YZN2jZLjaZeD/EzPoV8F3EUk0YNlOPgeRnWK+vkcrn2LdNqPM+lUW8A7TI91yVKYuI0wXUchJC0miGe74KFTCZLtbJOPp8lSVKCTAadalKdkOiUzRN+Yalxy2wJ3PXDltPhlvCWtv
xRxO8MsH3L55GN796AtfYPgT8E2LF9wP72P/3EjXWbN3tr2wLhLX7bFhQCjE2xWrfvWdZitcFogzHtZVMotLtRfeP3N/ZlDQKzpcxNMSSA9k0MLEIoLODnenBcSbW+Sm9fkSN3HmL/wW14fhNrLWliWV2u8+or55i+vErU0mAV1iYgEtpXp4cQsn2T3DjezeXm8W9+lljb7riEaN80hZQI1f69UqotsqRECu/G/8WGGDQSjGh33ptCGizG3GyDX/71fzV5q2y5bXvZ/sZ/cwyEQBtNK4pIrEEpha8cPM/Bk+07hJQSbQ1JnGAMSOXiKAdrQcmErO/huQEIgYfCMQIdpaSJQQkH33XxUoGOI0KlaMQxYVjH6Bg/l8f1fBKr8YIAm8SElTpRM0VZhXQlVrXbsdaE9Yoi1nmqDYfVCkQmRWBI45haZRlXOPR2BwwNZ+jqDShkIRdEmLRGIZsnyHjoNCWNLDpMiK1GeYpUaYw05HAxYUQDTSaTwUldqpU6USPEUR6e76HwSExC4kdEbkLWz1BQAVZbEmHQyhLaiDBs4aeS7mwXjnKxjsKQ8ru/+8T3MOMPZ8vxQ/3297/6z3ju6Zc4ffEa9xzez8yZM1yYusDtx26jmCtw7fXrtIyHm/FJkghtLVIauss55qevku8qsRoJBke2cX3iOvW5FfrKfUQW4iRCGEitJVUWgSErBP2FboaHtjN2YBdf+daXWFq/zrahLj75849z4fgqX/zT7yBMgWJ3Ee00SVWTWtLECItrBcoaPN/FOJZMMUOz3iCKNEp5mDhh7foyOeXRVRyitqaRsUvW70I4CuNbUhHhakPOLzDSt4M79x/mzoOHCKRiYnWab7z4LQa29bNYX+DM1VfQQUxKCMai4wghDBk/YHVtlUzGx1OaUqGA0RbPMXR1FXBkhlpLUzcJYdwixqNST2nVVhkfznH6T9Z/JFtutaNSjg2yJaRtSz5rBO97UPDfPi74rf9F82sfaPLv/yrk2qJgqGDp61F88h+X+aunUk6+kPLYXsud9zi4Oyz/9n+PmF/Nc9e9DkODPm7dEKzVKTl1+gOXxfWUBM2HHgtQymdtUvDbXxKsSh9FuvEU72BtvClbEUKQ8aAReqy0Wu2HSiFBCiQuWIlRYHVIDoFA4cgAKzVq4yHorpEqr84OkliDZwWJsKRxA20SCrkSXV07SJMGaZRgpCHRCVJIpFAoCZCQJk2sTTA6IhsEeEoROIpACTKeQ86FJA2JkxBjNdqkGJPie4pS4BIoS9mx9OY98o6g2xX0ZuHDn790y/rXA/tG7f/0j97PCydOML5vF0O9JTJBnsuXJ7GEjI4f4NLEVXp6hjh5/FkO7D/Aysoyz714hvc9+hip02KtVkWlCY++9wNUV5Z57pmvc89dD7O0tIwKHGbnZwhyksDPMzK8k2YzotjVxee/8CVGRndw4pWXGB/fT5ArUspmKJd9ZucXGBwqMNhX5utPPcPyQsyd+3bQVS4xszTPnXcdZfryNfaP7aQVhkjlsLC4xNzSEkPb9hA35tk5tI312gpxFOH5RWr1NRbmF+gd6ufYIw9y8eIE+/bt5qWXzrBjeIzV2YskrQTpBVydmmf3beOAZL06Q7GYZ2WpyeL8Gr/66U+wsrTCtZlppuan0LFhfNdOnvjrZ/iHv/ZJzp47w8r8KqVAIB2YX00wQrF/3x6mFy6RyWWJWy7rq8v8n585ectsKYT44RKrOtwq3tKWP4r4PQHsEUKM0Tb+J4BPff+fvPEcuCFY3yLpblOk3ngC37qNEG1Pk33j8ga2eNLMhtdGbtkmk8nQarWIY3CUBGEw1qCkSyboorurnyCnefd7DzMwmCGOq6ytXaRvIEc2pzA2plDyGRk9wpf+43EmLi0DAW2/it7waKobx7IpepW6+d2msG+vuymQbwjbjUVuXcR3R6oIQFmBfItLTHz3M8Vb8UPYEnzfJ0kTIm0xSqATQ2IssdVkHYkVbX+rMG2PqFEChE
Cq9mmXJBEJGm0NvjV4novjZsAqkiRFYzEiBSkxniIXZFA4GCEgjbGOws9mMBKE0SQ2BmFoxCk6MQTawbEBb
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 864x288 with 12 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {
|
|
|
|
|
"tags": [],
|
|
|
|
|
"needs_background": "light"
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
]
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "YUJ3_xoPLjQv",
|
|
|
|
|
"colab_type": "code",
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"colab": {
|
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
|
"height": 72
|
|
|
|
|
},
|
|
|
|
|
"outputId": "a2b962aa-5f10-4c43-a158-2aea21369e26"
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"# load STL-10 test data\n",
|
|
|
|
|
"X_test, y_test = _load_stl10(\"test\")"
|
|
|
|
|
],
|
2020-04-25 18:24:22 -03:00
|
|
|
|
"execution_count": 12,
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"test images\n",
|
|
|
|
|
"(8000, 3, 96, 96)\n",
|
|
|
|
|
"(8000,)\n"
|
|
|
|
|
],
|
|
|
|
|
"name": "stdout"
|
|
|
|
|
}
|
|
|
|
|
]
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "QE8sEe_qLjQz",
|
|
|
|
|
"colab_type": "text"
|
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"## Test protocol #1 PCA features"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "WFmUZzKoLjQ4",
|
|
|
|
|
"colab_type": "code",
|
2020-03-14 07:16:36 -03:00
|
|
|
|
"colab": {}
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"# Flatten each image into a pixel vector once and reuse it (X_* are non-contiguous, so each reshape copies)\n",
"X_train_flat = X_train.reshape((X_train.shape[0], -1))\n",
"X_test_flat = X_test.reshape((X_test.shape[0], -1))\n",
"\n",
"scaler = preprocessing.StandardScaler()\n",
"scaler.fit(X_train_flat)\n",
"\n",
"# PCA's default 'auto' solver may pick randomized SVD for large inputs; fix the seed for reproducible runs\n",
"pca = PCA(n_components=config['model']['out_dim'], random_state=0)\n",
"\n",
"X_train_pca = pca.fit_transform(scaler.transform(X_train_flat))\n",
"X_test_pca = pca.transform(scaler.transform(X_test_flat))\n",
"\n",
"print(\"PCA features\")\n",
"print(X_train_pca.shape)\n",
"print(X_test_pca.shape)\n",
"\n",
"## clean up the flattened pixel copies\n",
"del X_train_flat, X_test_flat"
|
|
|
|
|
],
|
2020-03-14 07:16:36 -03:00
|
|
|
|
"execution_count": 0,
|
|
|
|
|
"outputs": []
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "Yq2N_FpVLjQ8",
|
|
|
|
|
"colab_type": "code",
|
|
|
|
|
"colab": {}
|
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"def linear_model_eval(X_train, y_train, X_test, y_test):\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" clf = LogisticRegression(random_state=0, max_iter=1200, solver='lbfgs', C=1.0)\n",
|
|
|
|
|
" clf.fit(X_train, y_train)\n",
|
|
|
|
|
" print(\"Logistic Regression feature eval\")\n",
|
|
|
|
|
" print(\"Train score:\", clf.score(X_train, y_train))\n",
|
|
|
|
|
" print(\"Test score:\", clf.score(X_test, y_test))\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" print(\"-------------------------------\")\n",
|
|
|
|
|
" neigh = KNeighborsClassifier(n_neighbors=10)\n",
|
|
|
|
|
" neigh.fit(X_train, y_train)\n",
|
|
|
|
|
" print(\"KNN feature eval\")\n",
|
|
|
|
|
" print(\"Train score:\", neigh.score(X_train, y_train))\n",
|
|
|
|
|
" print(\"Test score:\", neigh.score(X_test, y_test))"
|
|
|
|
|
],
|
|
|
|
|
"execution_count": 0,
|
|
|
|
|
"outputs": []
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "6VTolghbLjRA",
|
|
|
|
|
"colab_type": "code",
|
2020-03-14 07:16:36 -03:00
|
|
|
|
"colab": {}
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"## clean up resources\n",
|
|
|
|
|
"del X_train_pca\n",
|
|
|
|
|
"del X_test_pca"
|
|
|
|
|
],
|
2020-03-14 07:16:36 -03:00
|
|
|
|
"execution_count": 0,
|
|
|
|
|
"outputs": []
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "5nf4rDtWLjRE",
|
|
|
|
|
"colab_type": "text"
|
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"## Protocol #2 Logistic Regression"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "fYezlvoNVpeT",
|
|
|
|
|
"colab_type": "code",
|
|
|
|
|
"colab": {}
|
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"# Load the neural net module\n",
|
|
|
|
|
"spec = importlib.util.spec_from_file_location(\"model\", os.path.join(checkpoints_folder, 'resnet_simclr.py'))\n",
|
|
|
|
|
"resnet_module = importlib.util.module_from_spec(spec)\n",
|
|
|
|
|
"spec.loader.exec_module(resnet_module)"
|
|
|
|
|
],
|
|
|
|
|
"execution_count": 0,
|
|
|
|
|
"outputs": []
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "AxhfD0c7LjRF",
|
|
|
|
|
"colab_type": "code",
|
2020-03-14 07:16:36 -03:00
|
|
|
|
"colab": {}
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"model = resnet_module.ResNetSimCLR(**config['model'])\n",
|
|
|
|
|
"model.eval()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=torch.device('cpu'))\n",
|
|
|
|
|
"model.load_state_dict(state_dict)\n",
|
|
|
|
|
"model = model.to(device)"
|
|
|
|
|
],
|
2020-03-14 07:16:36 -03:00
|
|
|
|
"execution_count": 0,
|
|
|
|
|
"outputs": []
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "ro6yG6ngLjRI",
|
|
|
|
|
"colab_type": "code",
|
|
|
|
|
"colab": {}
|
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"def next_batch(X, y, batch_size):\n",
|
|
|
|
|
" for i in range(0, X.shape[0], batch_size):\n",
|
|
|
|
|
" X_batch = torch.tensor(X[i: i+batch_size]) / 255.\n",
|
|
|
|
|
" y_batch = torch.tensor(y[i: i+batch_size])\n",
|
|
|
|
|
" yield X_batch.to(device), y_batch.to(device)"
|
|
|
|
|
],
|
|
|
|
|
"execution_count": 0,
|
|
|
|
|
"outputs": []
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "oftbHXcdLjRM",
|
|
|
|
|
"colab_type": "code",
|
2020-03-14 07:16:36 -03:00
|
|
|
|
"colab": {}
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"X_train_feature = []\n",
"\n",
"# Inference only: no_grad avoids building an autograd graph for every batch\n",
"with torch.no_grad():\n",
"    for batch_x, batch_y in next_batch(X_train, y_train, batch_size=config['batch_size']):\n",
"        features, _ = model(batch_x)\n",
"        X_train_feature.extend(features.cpu().numpy())\n",
"\n",
"X_train_feature = np.array(X_train_feature)\n",
"\n",
"print(\"Train features\")\n",
"print(X_train_feature.shape)"
|
|
|
|
|
],
|
2020-03-14 07:16:36 -03:00
|
|
|
|
"execution_count": 0,
|
|
|
|
|
"outputs": []
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "sverVlKPLjRP",
|
|
|
|
|
"colab_type": "code",
|
2020-03-14 07:16:36 -03:00
|
|
|
|
"colab": {}
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"X_test_feature = []\n",
"\n",
"# Inference only: no_grad avoids building an autograd graph for every batch\n",
"with torch.no_grad():\n",
"    for batch_x, batch_y in next_batch(X_test, y_test, batch_size=config['batch_size']):\n",
"        features, _ = model(batch_x)\n",
"        X_test_feature.extend(features.cpu().numpy())\n",
"\n",
"X_test_feature = np.array(X_test_feature)\n",
"\n",
"print(\"Test features\")\n",
"print(X_test_feature.shape)"
|
|
|
|
|
],
|
2020-03-14 07:16:36 -03:00
|
|
|
|
"execution_count": 0,
|
|
|
|
|
"outputs": []
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
|
|
|
|
"id": "91jHpRQyLjRT",
|
|
|
|
|
"colab_type": "code",
|
2020-03-14 07:16:36 -03:00
|
|
|
|
"colab": {}
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
"scaler = preprocessing.StandardScaler()\n",
|
|
|
|
|
"scaler.fit(X_train_feature)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"linear_model_eval(scaler.transform(X_train_feature), y_train, scaler.transform(X_test_feature), y_test)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"del X_train_feature\n",
|
|
|
|
|
"del X_test_feature"
|
|
|
|
|
],
|
2020-03-14 07:16:36 -03:00
|
|
|
|
"execution_count": 0,
|
|
|
|
|
"outputs": []
|
2020-03-14 07:11:29 -03:00
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"metadata": {
|
2020-03-14 11:34:51 -03:00
|
|
|
|
"id": "fXy_YX8_b7gL",
|
2020-03-14 07:11:29 -03:00
|
|
|
|
"colab_type": "code",
|
|
|
|
|
"colab": {}
|
|
|
|
|
},
|
|
|
|
|
"source": [
|
|
|
|
|
""
|
|
|
|
|
],
|
|
|
|
|
"execution_count": 0,
|
|
|
|
|
"outputs": []
|
2020-02-17 23:17:10 -03:00
|
|
|
|
}
|
2020-03-14 07:11:29 -03:00
|
|
|
|
]
|
|
|
|
|
}
|