{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import sys\n", "sys.path.insert(1, '../')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from models.resnet_simclr import ResNetSimCLR\n", "import torchvision.transforms as transforms\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "import numpy as np\n", "import os\n", "from sklearn.neighbors import KNeighborsClassifier\n", "import yaml" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "folder_name = 'Mar10_21-50-05_thallessilva'\n", "checkpoints_folder = os.path.join('../runs', folder_name, 'checkpoints')\n", "config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"), Loader=yaml.FullLoader)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'batch_size': 512,\n", " 'out_dim': 256,\n", " 's': 1,\n", " 'temperature': 0.5,\n", " 'base_convnet': 'resnet18',\n", " 'use_cosine_similarity': True,\n", " 'epochs': 50,\n", " 'num_workers': 4,\n", " 'valid_size': 0.05,\n", " 'eval_every_n_epochs': 2}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def _load_stl10(prefix=\"train\"):\n", " X_train = np.fromfile('../data/stl10_binary/' + prefix + '_X.bin', dtype=np.uint8)\n", " y_train = np.fromfile('../data/stl10_binary/' + prefix + '_y.bin', dtype=np.uint8)\n", "\n", " X_train = np.reshape(X_train, (-1, 3, 96, 96))\n", " X_train = np.transpose(X_train, (0, 3, 2, 1))\n", " print(\"{} images\".format(prefix))\n", " print(X_train.shape)\n", " print(y_train.shape)\n", " return X_train, y_train - 1" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train images\n", "(5000, 96, 96, 3)\n", "(5000,)\n" ] } ], "source": [ "# load STL-10 train data\n", "X_train, y_train = _load_stl10(\"train\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test images\n", "(8000, 96, 96, 3)\n", "(8000,)\n" ] } ], "source": [ "# load STL-10 test data\n", "X_test, y_test = _load_stl10(\"test\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test protocol #1 PCA features" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from sklearn.decomposition import PCA\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn import preprocessing" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PCA features\n", "(5000, 256)\n", "(8000, 256)\n" ] } ], "source": [ "scaler = preprocessing.StandardScaler()\n", "scaler.fit(X_train.reshape((X_train.shape[0],-1)))\n", "\n", "pca = PCA(n_components=config['out_dim'])\n", "\n", "X_train_pca = pca.fit_transform(scaler.transform(X_train.reshape(X_train.shape[0], -1)))\n", "X_test_pca = pca.transform(scaler.transform(X_test.reshape(X_test.shape[0], -1)))\n", "\n", "print(\"PCA features\")\n", "print(X_train_pca.shape)\n", "print(X_test_pca.shape)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def linear_model_eval(X_train, y_train, X_test, 
y_test):\n", " \n", " clf = LogisticRegression(random_state=0, max_iter=1000, solver='lbfgs', C=1.0)\n", " clf.fit(X_train, y_train)\n", " print(\"Logistic Regression feature eval\")\n", " print(\"Train score:\", clf.score(X_train, y_train))\n", " print(\"Test score:\", clf.score(X_test, y_test))\n", " \n", " print(\"-------------------------------\")\n", " neigh = KNeighborsClassifier(n_neighbors=10)\n", " neigh.fit(X_train, y_train)\n", " print(\"KNN feature eval\")\n", " print(\"Train score:\", neigh.score(X_train, y_train))\n", " print(\"Test score:\", neigh.score(X_test, y_test))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Logistic Regression feature eval\n", "Train score: 0.4966\n", "Test score: 0.35\n", "-------------------------------\n", "KNN feature eval\n", "Train score: 0.4036\n", "Test score: 0.300125\n" ] } ], "source": [ "linear_model_eval(X_train_pca, y_train, X_test_pca, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Protocol #2 Logisitc Regression" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = ResNetSimCLR(out_dim=config['out_dim'])\n", "model.eval()\n", "\n", "state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=torch.device('cpu'))\n", "model.load_state_dict(state_dict)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def next_batch(X, y, batch_size):\n", " for i in range(0, X.shape[0], batch_size):\n", " X_batch = torch.tensor(X[i: i+batch_size]) / 255.\n", " y_batch = torch.tensor(y[i: i+batch_size])\n", " yield X_batch.permute((0,3,1,2)), y_batch" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train features\n", "(5000, 512)\n" ] } ], "source": [ "X_train_feature = []\n", "\n", "for batch_x, batch_y in next_batch(X_train, y_train, batch_size=128):\n", " features, _ = model(batch_x)\n", " X_train_feature.extend(features.detach().numpy())\n", " \n", "X_train_feature = np.array(X_train_feature)\n", "\n", "print(\"Train features\")\n", "print(X_train_feature.shape)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test features\n", "(8000, 512)\n" ] } ], "source": [ "X_test_feature = []\n", "\n", "for batch_x, batch_y in next_batch(X_test, y_test, batch_size=256):\n", " features, _ = model(batch_x)\n", " X_test_feature.extend(features.detach().numpy())\n", " \n", "X_test_feature = np.array(X_test_feature)\n", "\n", "print(\"Test features\")\n", "print(X_test_feature.shape)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/thalles/anaconda3/envs/pytorch/lib/python3.6/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. 
 { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/thalles/anaconda3/envs/pytorch/lib/python3.6/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", "    https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", "  extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Logistic Regression feature eval\n", "Train score: 0.9628\n", "Test score: 0.75\n", "-------------------------------\n", "KNN feature eval\n", "Train score: 0.7764\n", "Test score: 0.709125\n" ] } ], "source": [ "# standardize the SimCLR features, then run the same evaluation protocol\n", "scaler = preprocessing.StandardScaler()\n", "scaler.fit(X_train_feature)\n", "\n", "linear_model_eval(scaler.transform(X_train_feature), y_train, scaler.transform(X_test_feature), y_test)" ] },
 { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# SimCLR feature evaluation\n", "# Train score: 0.8966\n", "# Test score: 0.634125" ] }
 ], "metadata": { "kernelspec": { "display_name": "pytorch", "language": "python", "name": "pytorch" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }