{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
},
"colab": {
"name": "linear_feature_eval.ipynb",
"provenance": [],
"include_colab_link": true
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/master/feature_eval/linear_feature_eval.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "WSgRE1CcLqdS",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 163
},
"outputId": "60a2169b-a652-4d4b-e441-18a1a77ba53a"
},
"source": [
"!pip install gdown"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: gdown in /usr/local/lib/python3.6/dist-packages (3.6.4)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gdown) (2.21.0)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gdown) (4.38.0)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from gdown) (1.12.0)\n",
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (1.24.3)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2.8)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2020.4.5.1)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (3.0.4)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "G7YMxsvEZMrX",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 219
},
"outputId": "26a6b27f-320c-4562-9b7c-f8b9262f6033"
},
"source": [
"folder_name = 'Mar14_05-52-52_thallessilva'\n",
"\n",
"# gdown needs the direct uc?id= link; the share URL only fetches the HTML preview page\n",
"# !gdown https://drive.google.com/uc?id=1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk -O Mar14_05-52-52_thallessilva.zip # ResNet 18 --> 40 epochs trained\n",
"!gdown https://drive.google.com/uc?id=1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT -O Mar14_05-52-52_thallessilva.zip # ResNet 18 --> 80 epochs trained\n",
"!unzip Mar14_05-52-52_thallessilva.zip\n",
"!ls"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"unzip: cannot find or open Mar14_05-52-52_thallessilva, Mar14_05-52-52_thallessilva.zip or Mar14_05-52-52_thallessilva.ZIP.\n",
" sample_data 'view?usp=sharing'\n",
"/usr/local/lib/python2.7/dist-packages/gdown/parse_url.py:31: UserWarning: You specified Google Drive Link but it is not the correct link to download the file. Maybe you should try: https://drive.google.com/uc?id=1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT\n",
" .format(url='https://drive.google.com/uc?id={}'.format(file_id))\n",
"Downloading...\n",
"From: https://drive.google.com/file/d/1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT/view?usp=sharing\n",
"To: /content/view?usp=sharing\n",
"69.7kB [00:00, 613kB/s]\n",
"unzip: cannot find or open Mar14_05-52-52_thallessilva, Mar14_05-52-52_thallessilva.zip or Mar14_05-52-52_thallessilva.ZIP.\n",
" sample_data 'view?usp=sharing'\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "vEoblAn6RsO7",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 174
},
"outputId": "f32e47cc-6195-4b92-a1aa-ad4328e39269"
},
"source": [
"# download and extract stl10\n",
"!wget http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz\n",
"!tar -zxvf stl10_binary.tar.gz\n",
"!ls"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"--2020-04-25 20:09:26-- http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz\n",
"Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10\n",
"Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:80... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 2640397119 (2.5G) [application/x-gzip]\n",
"Saving to: ‘stl10_binary.tar.gz’\n",
"\n",
"stl10_binary.tar.gz 62%[===========> ] 1.54G 19.3MB/s eta 51s "
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aFnFqIFLLjQZ",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"import sys\n",
"import numpy as np\n",
"import os\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"import yaml\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn import preprocessing\n",
"import importlib.util"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lDfbL3w_Z0Od",
"colab_type": "code",
"colab": {}
},
"source": [
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"print(\"Using device:\", device)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "IQMIryc6LjQd",
"colab_type": "code",
"colab": {}
},
"source": [
"checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
"\n",
"# safe_load works across PyYAML versions; a bare yaml.load call needs an explicit Loader on newer releases\n",
"config = yaml.safe_load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n",
"config"
],
"execution_count": 0,
"outputs": []
},
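{
"cell_type": "markdown",
"metadata": {},
"source": [
"The evaluation below reads a few keys from this config, namely `config['batch_size']` and `config['model']` (including `out_dim`). The next cell is a minimal sanity check, assuming the standard SimCLR config layout, that fails early if the checkpoint shipped a different structure."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Minimal sanity check, assuming the standard SimCLR config layout:\n",
"# the keys below are the ones used later in this notebook.\n",
"for key in ['batch_size', 'model']:\n",
"    assert key in config, 'missing key: ' + key\n",
"assert 'out_dim' in config['model'], 'missing key: model.out_dim'"
],
"execution_count": 0,
"outputs": []
},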
{
"cell_type": "code",
"metadata": {
"id": "GxuiXvAKLjQm",
"colab_type": "code",
"colab": {}
},
"source": [
"def _load_stl10(prefix=\"train\"):\n",
"    X_train = np.fromfile('./stl10_binary/' + prefix + '_X.bin', dtype=np.uint8)\n",
"    y_train = np.fromfile('./stl10_binary/' + prefix + '_y.bin', dtype=np.uint8)\n",
"\n",
"    X_train = np.reshape(X_train, (-1, 3, 96, 96))  # N, C, W, H (column-major storage)\n",
"    X_train = np.transpose(X_train, (0, 1, 3, 2))  # N, C, H, W\n",
"\n",
"    print(\"{} images\".format(prefix))\n",
"    print(X_train.shape)\n",
"    print(y_train.shape)\n",
"    return X_train, y_train - 1  # STL-10 labels are 1..10; shift to 0..9"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Xn0xslbELjQq",
"colab_type": "code",
"colab": {}
},
"source": [
"# load STL-10 train data\n",
"X_train, y_train = _load_stl10(\"train\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "7shAS6fvXtPG",
"colab_type": "code",
"colab": {}
},
"source": [
"fig, axs = plt.subplots(nrows=2, ncols=6, constrained_layout=False, figsize=(12,4))\n",
"\n",
"for i, ax in enumerate(axs.flat):\n",
"    ax.imshow(X_train[i].transpose(1,2,0))\n",
"plt.show()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "YUJ3_xoPLjQv",
"colab_type": "code",
"colab": {}
},
"source": [
"# load STL-10 test data\n",
"X_test, y_test = _load_stl10(\"test\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "QE8sEe_qLjQz",
"colab_type": "text"
},
"source": [
"## Test protocol #1: PCA features"
]
},
{
"cell_type": "code",
"metadata": {
"id": "WFmUZzKoLjQ4",
"colab_type": "code",
"colab": {}
},
"source": [
"scaler = preprocessing.StandardScaler()\n",
"scaler.fit(X_train.reshape((X_train.shape[0], -1)))\n",
"\n",
"pca = PCA(n_components=config['model']['out_dim'])\n",
"\n",
"X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))\n",
"X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))\n",
"\n",
"print(\"PCA features\")\n",
"print(X_train_pca.shape)\n",
"print(X_test_pca.shape)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Yq2N_FpVLjQ8",
"colab_type": "code",
"colab": {}
},
"source": [
"def linear_model_eval(X_train, y_train, X_test, y_test):\n",
"\n",
"    clf = LogisticRegression(random_state=0, max_iter=1200, solver='lbfgs', C=1.0)\n",
"    clf.fit(X_train, y_train)\n",
"    print(\"Logistic Regression feature eval\")\n",
"    print(\"Train score:\", clf.score(X_train, y_train))\n",
"    print(\"Test score:\", clf.score(X_test, y_test))\n",
"\n",
"    print(\"-------------------------------\")\n",
"    neigh = KNeighborsClassifier(n_neighbors=10)\n",
"    neigh.fit(X_train, y_train)\n",
"    print(\"KNN feature eval\")\n",
"    print(\"Train score:\", neigh.score(X_train, y_train))\n",
"    print(\"Test score:\", neigh.score(X_test, y_test))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "6VTolghbLjRA",
"colab_type": "code",
"colab": {}
},
"source": [
"linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)\n",
"\n",
"# clean up resources\n",
"del X_train_pca\n",
"del X_test_pca"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "5nf4rDtWLjRE",
"colab_type": "text"
},
"source": [
"## Protocol #2: Logistic Regression"
]
},
{
"cell_type": "code",
"metadata": {
"id": "fYezlvoNVpeT",
"colab_type": "code",
"colab": {}
},
"source": [
"# Load the neural net module\n",
"spec = importlib.util.spec_from_file_location(\"model\", os.path.join(checkpoints_folder, 'resnet_simclr.py'))\n",
"resnet_module = importlib.util.module_from_spec(spec)\n",
"spec.loader.exec_module(resnet_module)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "AxhfD0c7LjRF",
"colab_type": "code",
"colab": {}
},
"source": [
"model = resnet_module.ResNetSimCLR(**config['model'])\n",
"model.eval()\n",
"\n",
"state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=torch.device('cpu'))\n",
"model.load_state_dict(state_dict)\n",
"model = model.to(device)"
],
"execution_count": 0,
"outputs": []
},
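{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to confirm the frozen encoder behaves as expected is to push a dummy batch through it. This is a minimal sketch assuming `ResNetSimCLR` returns a `(features, projections)` tuple, which is what the feature-extraction loops below rely on (`features, _ = model(batch_x)`)."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Minimal sanity check, assuming model(x) returns (backbone_features, projection_head_output).\n",
"# STL-10 images are 3 x 96 x 96, matching the loader above.\n",
"with torch.no_grad():\n",
"    dummy = torch.zeros(2, 3, 96, 96, device=device)\n",
"    h, z = model(dummy)\n",
"print('backbone features:', h.shape)\n",
"print('projection head output:', z.shape)"
],
"execution_count": 0,
"outputs": []
},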
{
"cell_type": "code",
"metadata": {
"id": "ro6yG6ngLjRI",
"colab_type": "code",
"colab": {}
},
"source": [
"def next_batch(X, y, batch_size):\n",
"    for i in range(0, X.shape[0], batch_size):\n",
"        # cast the uint8 images to float and scale to [0, 1]\n",
"        X_batch = torch.tensor(X[i: i+batch_size]).float() / 255.\n",
"        y_batch = torch.tensor(y[i: i+batch_size])\n",
"        yield X_batch.to(device), y_batch.to(device)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "oftbHXcdLjRM",
"colab_type": "code",
"colab": {}
},
"source": [
"X_train_feature = []\n",
"\n",
"for batch_x, batch_y in next_batch(X_train, y_train, batch_size=config['batch_size']):\n",
"    features, _ = model(batch_x)\n",
"    X_train_feature.extend(features.cpu().detach().numpy())\n",
"\n",
"X_train_feature = np.array(X_train_feature)\n",
"\n",
"print(\"Train features\")\n",
"print(X_train_feature.shape)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sverVlKPLjRP",
"colab_type": "code",
"colab": {}
},
"source": [
"X_test_feature = []\n",
"\n",
"for batch_x, batch_y in next_batch(X_test, y_test, batch_size=config['batch_size']):\n",
"    features, _ = model(batch_x)\n",
"    X_test_feature.extend(features.cpu().detach().numpy())\n",
"\n",
"X_test_feature = np.array(X_test_feature)\n",
"\n",
"print(\"Test features\")\n",
"print(X_test_feature.shape)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "91jHpRQyLjRT",
"colab_type": "code",
"colab": {}
},
"source": [
"scaler = preprocessing.StandardScaler()\n",
"scaler.fit(X_train_feature)\n",
"\n",
"linear_model_eval(scaler.transform(X_train_feature), y_train, scaler.transform(X_test_feature), y_test)\n",
"\n",
"del X_train_feature\n",
"del X_test_feature"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "fXy_YX8_b7gL",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}