{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from model import Encoder, ResNet18\n", "import torchvision.transforms as transforms\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "batch_size = 256\n", "out_dim = 64" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training images\n", "(5000, 96, 96, 3)\n", "(5000,)\n" ] } ], "source": [ "X_train = np.fromfile('data/stl10_binary/train_X.bin', dtype=np.uint8)\n", "y_train = np.fromfile('data/stl10_binary/train_y.bin', dtype=np.uint8)\n", "\n", "X_train = np.reshape(X_train, (-1, 3, 96, 96))\n", "X_train = np.transpose(X_train, (0, 3, 2, 1))\n", "print(\"Training images\")\n", "print(X_train.shape)\n", "print(y_train.shape)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test images\n", "(8000, 96, 96, 3)\n", "(8000,)\n" ] } ], "source": [ "X_test = np.fromfile('data/stl10_binary/test_X.bin', dtype=np.uint8)\n", "y_test = np.fromfile('data/stl10_binary/test_y.bin', dtype=np.uint8)\n", "\n", "X_test = np.reshape(X_test, (-1, 3, 96, 96))\n", "X_test = np.transpose(X_test, (0, 3, 2, 1))\n", "print(\"Test images\")\n", "print(X_test.shape)\n", "print(y_test.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test protocol #1 PCA features" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from sklearn.decomposition import PCA\n", "from sklearn.linear_model import LogisticRegression" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PCA features\n", "(5000, 64)\n", "(8000, 64)\n" ] } ], "source": [ "pca = PCA(n_components=out_dim)\n", "X_train_pca = pca.fit_transform(X_train.reshape((X_train.shape[0],-1)))\n", "X_test_pca = pca.transform(X_test.reshape((X_test.shape[0],-1)))\n", "\n", "print(\"PCA features\")\n", "print(X_train_pca.shape)\n", "print(X_test_pca.shape)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PCA feature evaluation\n", "Train score: 0.3984\n", "Test score: 0.353125\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/thalles/anaconda3/envs/pytorch/lib/python3.6/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. 
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n" ] } ], "source": [ "# NOTE: RandomResizedCrop is stochastic even at evaluation time, so the extracted\n", "# features vary from run to run; a deterministic alternative is sketched under Protocol #2\n", "data_augment = transforms.Compose([transforms.RandomResizedCrop(96),\n", "                                   transforms.ToTensor()])\n", "\n", "train_dataset = datasets.STL10('data', split='train', download=True, transform=data_augment)\n", "# shuffle=False keeps the feature order aligned with y_train read from the binaries\n", "train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, drop_last=False, shuffle=False)" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n" ] } ], "source": [ "test_dataset = datasets.STL10('data', split='test', download=True, transform=data_augment)\n", "test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=1, drop_last=False, shuffle=False)" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ResNet18(\n", " (features): Sequential(\n", " (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU(inplace=True)\n", " (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (4): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (5): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(128, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (6): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (7): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (8): AdaptiveAvgPool2d(output_size=(1, 1))\n", " )\n", " (l1): Linear(in_features=512, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=64, bias=True)\n", ")\n", "odict_keys(['features.0.weight', 'features.1.weight', 'features.1.bias', 'features.1.running_mean', 'features.1.running_var', 'features.1.num_batches_tracked', 'features.4.0.conv1.weight', 'features.4.0.bn1.weight', 'features.4.0.bn1.bias', 'features.4.0.bn1.running_mean', 'features.4.0.bn1.running_var', 'features.4.0.bn1.num_batches_tracked', 'features.4.0.conv2.weight', 'features.4.0.bn2.weight', 'features.4.0.bn2.bias', 'features.4.0.bn2.running_mean', 'features.4.0.bn2.running_var', 'features.4.0.bn2.num_batches_tracked', 'features.4.1.conv1.weight', 'features.4.1.bn1.weight', 
'features.4.1.bn1.bias', 'features.4.1.bn1.running_mean', 'features.4.1.bn1.running_var', 'features.4.1.bn1.num_batches_tracked', 'features.4.1.conv2.weight', 'features.4.1.bn2.weight', 'features.4.1.bn2.bias', 'features.4.1.bn2.running_mean', 'features.4.1.bn2.running_var', 'features.4.1.bn2.num_batches_tracked', 'features.5.0.conv1.weight', 'features.5.0.bn1.weight', 'features.5.0.bn1.bias', 'features.5.0.bn1.running_mean', 'features.5.0.bn1.running_var', 'features.5.0.bn1.num_batches_tracked', 'features.5.0.conv2.weight', 'features.5.0.bn2.weight', 'features.5.0.bn2.bias', 'features.5.0.bn2.running_mean', 'features.5.0.bn2.running_var', 'features.5.0.bn2.num_batches_tracked', 'features.5.0.downsample.0.weight', 'features.5.0.downsample.1.weight', 'features.5.0.downsample.1.bias', 'features.5.0.downsample.1.running_mean', 'features.5.0.downsample.1.running_var', 'features.5.0.downsample.1.num_batches_tracked', 'features.5.1.conv1.weight', 'features.5.1.bn1.weight', 'features.5.1.bn1.bias', 'features.5.1.bn1.running_mean', 'features.5.1.bn1.running_var', 'features.5.1.bn1.num_batches_tracked', 'features.5.1.conv2.weight', 'features.5.1.bn2.weight', 'features.5.1.bn2.bias', 'features.5.1.bn2.running_mean', 'features.5.1.bn2.running_var', 'features.5.1.bn2.num_batches_tracked', 'features.6.0.conv1.weight', 'features.6.0.bn1.weight', 'features.6.0.bn1.bias', 'features.6.0.bn1.running_mean', 'features.6.0.bn1.running_var', 'features.6.0.bn1.num_batches_tracked', 'features.6.0.conv2.weight', 'features.6.0.bn2.weight', 'features.6.0.bn2.bias', 'features.6.0.bn2.running_mean', 'features.6.0.bn2.running_var', 'features.6.0.bn2.num_batches_tracked', 'features.6.0.downsample.0.weight', 'features.6.0.downsample.1.weight', 'features.6.0.downsample.1.bias', 'features.6.0.downsample.1.running_mean', 'features.6.0.downsample.1.running_var', 'features.6.0.downsample.1.num_batches_tracked', 'features.6.1.conv1.weight', 'features.6.1.bn1.weight', 'features.6.1.bn1.bias', 'features.6.1.bn1.running_mean', 'features.6.1.bn1.running_var', 'features.6.1.bn1.num_batches_tracked', 'features.6.1.conv2.weight', 'features.6.1.bn2.weight', 'features.6.1.bn2.bias', 'features.6.1.bn2.running_mean', 'features.6.1.bn2.running_var', 'features.6.1.bn2.num_batches_tracked', 'features.7.0.conv1.weight', 'features.7.0.bn1.weight', 'features.7.0.bn1.bias', 'features.7.0.bn1.running_mean', 'features.7.0.bn1.running_var', 'features.7.0.bn1.num_batches_tracked', 'features.7.0.conv2.weight', 'features.7.0.bn2.weight', 'features.7.0.bn2.bias', 'features.7.0.bn2.running_mean', 'features.7.0.bn2.running_var', 'features.7.0.bn2.num_batches_tracked', 'features.7.0.downsample.0.weight', 'features.7.0.downsample.1.weight', 'features.7.0.downsample.1.bias', 'features.7.0.downsample.1.running_mean', 'features.7.0.downsample.1.running_var', 'features.7.0.downsample.1.num_batches_tracked', 'features.7.1.conv1.weight', 'features.7.1.bn1.weight', 'features.7.1.bn1.bias', 'features.7.1.bn1.running_mean', 'features.7.1.bn1.running_var', 'features.7.1.bn1.num_batches_tracked', 'features.7.1.conv2.weight', 'features.7.1.bn2.weight', 'features.7.1.bn2.bias', 'features.7.1.bn2.running_mean', 'features.7.1.bn2.running_var', 'features.7.1.bn2.num_batches_tracked', 'l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'])\n" ] }, { "data": { "text/plain": [ "<All keys matched successfully>" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# build the ResNet18 encoder and switch to eval mode (use BatchNorm running statistics)\n", "model = ResNet18(out_dim=out_dim)\n", "model.eval()\n", "print(model)\n", "\n", "# load the SimCLR pre-trained checkpoint into the encoder\n", "state_dict = torch.load('model/checkpoint.pth')\n", "print(state_dict.keys())\n", "\n", "model.load_state_dict(state_dict)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Protocol #2: Linear separability evaluation" ] },
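{ "cell_type": "markdown", "metadata": {}, "source": [ "The two cells below extract features with the stochastic `data_augment` transform defined earlier. A minimal sketch of a deterministic, gradient-free alternative, assuming the same `datasets.STL10` loaders and the `(features, projection)` return signature of the model (not part of the recorded run):" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hedged sketch: deterministic transform plus no-grad feature extraction.\n", "# CenterCrop(96) is a no-op on 96x96 STL-10 images but keeps the pipeline explicit.\n", "eval_transform = transforms.Compose([transforms.CenterCrop(96),\n", "                                     transforms.ToTensor()])\n", "\n", "def extract_features(loader, model):\n", "    feats = []\n", "    with torch.no_grad():  # the encoder stays frozen during evaluation\n", "        for batch_x, _ in loader:\n", "            h, _ = model(batch_x)  # keep the 512-d features, drop the projection\n", "            feats.append(h.numpy())\n", "    return np.concatenate(feats)" ] },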
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train features\n", "(5000, 512)\n" ] } ], "source": [ "X_train_feature = []\n", "\n", "# collect the 512-d encoder features, discarding the projection head output\n", "for step, (batch_x, batch_y) in enumerate(train_loader):\n", "    features, _ = model(batch_x)\n", "    X_train_feature.extend(features.detach().numpy())\n", "\n", "X_train_feature = np.array(X_train_feature)\n", "\n", "print(\"Train features\")\n", "print(X_train_feature.shape)" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test features\n", "(8000, 512)\n" ] } ], "source": [ "X_test_feature = []\n", "\n", "for step, (batch_x, batch_y) in enumerate(test_loader):\n", "    features, _ = model(batch_x)\n", "    X_test_feature.extend(features.detach().numpy())\n", "\n", "X_test_feature = np.array(X_test_feature)\n", "\n", "print(\"Test features\")\n", "print(X_test_feature.shape)" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/thalles/anaconda3/envs/pytorch/lib/python3.6/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n" ] } ], "source": [ "# same ConvergenceWarning as in Protocol #1; the scaling/max_iter remedy applies here too\n", "from sklearn.linear_model import LogisticRegression\n", "clf = LogisticRegression(random_state=0).fit(X_train_feature, y_train)" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SimCLR feature evaluation\n", "Train score: 0.7444\n", "Test score: 0.62625\n" ] } ], "source": [ "print(\"SimCLR feature evaluation\")\n", "print(\"Train score:\", clf.score(X_train_feature, y_train))\n", "print(\"Test score:\", clf.score(X_test_feature, y_test))\n", "# Reference scores from a previous run:\n", "# SimCLR feature evaluation\n", "# Train score: 0.5298\n", "# Test score: 0.52075" ] }
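, { "cell_type": "markdown", "metadata": {}, "source": [ "Summary: the frozen SimCLR encoder's 512-d features are substantially more linearly separable than the 64-d raw-pixel PCA baseline (test accuracy 0.626 vs. 0.353) under the same logistic-regression classifier." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "pytorch", "language": "python", "name": "pytorch" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 2 }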