{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "kernelspec": { "display_name": "pytorch", "language": "python", "name": "pytorch" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" }, "colab": { "name": "Copy of mini-batch-logistic-regression-evaluator.ipynb", "provenance": [], "include_colab_link": true }, "accelerator": "GPU", "widgets": { "application/vnd.jupyter.widget-state+json": {} } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "code", "metadata": { "id": "YUemQib7ZE4D" }, "source": [ "import torch\n", "import sys\n", "import numpy as np\n", "import os\n", "import yaml\n", "import matplotlib.pyplot as plt\n", "import torchvision" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "WSgRE1CcLqdS", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "e44ac358-6480-4a5f-a358-6eb6ace26c8b" }, "source": [ "!pip install gdown" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Requirement already satisfied: gdown in /usr/local/lib/python3.6/dist-packages (3.6.4)\n", "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from gdown) (1.15.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gdown) (2.23.0)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gdown) (4.41.1)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (1.24.3)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (3.0.4)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2020.12.5)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2.10)\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "NOIJEui1ZziV" }, "source": [ "def get_file_id_by_model(folder_name):\n", " file_id = {'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',\n", " 'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C'}\n", " return file_id.get(folder_name, \"Model not found.\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "G7YMxsvEZMrX", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "36932a7d-c7e5-492a-f37d-8be6b18f787a" }, "source": [ "folder_name = 'resnet18_100-epochs_stl10'\n", "file_id = get_file_id_by_model(folder_name)\n", "print(folder_name, file_id)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "resnet18_100-epochs_stl10 14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "PWZ8fet_YoJm", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "8d52756d-707b-4a3f-9e8c-0d191408deab" }, "source": [ "# download and extract model files\n", "os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n", "os.system('unzip {}'.format(folder_name))\n", "!ls" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "checkpoint_0100.pth.tar\n", "config.yml\n", "events.out.tfevents.1610901470.4cb2c837708d.2683858.0\n", "resnet18_100-epochs_stl10.zip\n", "sample_data\n", "training.log\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "ooyhd8piZ1w1", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "6ffb73aa-35c5-4df2-bd1f-6de6a235a9e5" }, "source": [ "!unzip resnet18_100-epochs_stl10" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Archive: resnet18_100-epochs_stl10.zip\n", "replace checkpoint_0100.pth.tar? [y]es, [n]o, [A]ll, [N]one, [r]ename: A\n", " inflating: checkpoint_0100.pth.tar \n", " inflating: config.yml \n", " inflating: events.out.tfevents.1610901470.4cb2c837708d.2683858.0 \n", " inflating: training.log \n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "3_nypQVEv-hn" }, "source": [ "from torch.utils.data import DataLoader\n", "import torchvision.transforms as transforms\n", "from torchvision import datasets" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "lDfbL3w_Z0Od", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "5f58bd9b-4428-4b8c-e271-b47ca6694f34" }, "source": [ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(\"Using device:\", device)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Using device: cuda\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "BfIPl0G6_RrT" }, "source": [ "def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n", " train_dataset = datasets.STL10('./data', split='train', download=download,\n", " transform=transforms.ToTensor())\n", "\n", " train_loader = DataLoader(train_dataset, batch_size=batch_size,\n", " num_workers=0, drop_last=False, shuffle=shuffle)\n", " \n", " test_dataset = datasets.STL10('./data', split='test', download=download,\n", " transform=transforms.ToTensor())\n", "\n", " test_loader = DataLoader(test_dataset, batch_size=2*batch_size,\n", " num_workers=10, drop_last=False, shuffle=shuffle)\n", " return train_loader, test_loader\n", "\n", "def get_cifar10_data_loaders(download, shuffle=False, batch_size=128):\n", " train_dataset = datasets.CIFAR10('./data', train=True, download=download,\n", " transform=transforms.ToTensor())\n", "\n", " train_loader = DataLoader(train_dataset, batch_size=batch_size,\n", " num_workers=0, drop_last=False, shuffle=shuffle)\n", " \n", " test_dataset = datasets.CIFAR10('./data', train=False, download=download,\n", " transform=transforms.ToTensor())\n", "\n", " test_loader = DataLoader(test_dataset, batch_size=2*batch_size,\n", " num_workers=10, drop_last=False, shuffle=shuffle)\n", " return train_loader, test_loader" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "6N8lYkbmDTaK" }, "source": [ "with open(os.path.join('./config.yml')) as file:\n", " config = yaml.load(file)" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "a18lPD-tIle6" }, "source": [ "if config.arch == 'resnet18':\n", " model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)\n", "elif config.arch == 'resnet50':\n", " model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)" ], "execution_count": 11, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "4AIfgq41GuTT" }, "source": [ "checkpoint = torch.load('checkpoint_0100.pth.tar', map_location=device)\n", "state_dict = checkpoint['state_dict']\n", "\n", "for k in list(state_dict.keys()):\n", "\n", " if k.startswith('backbone.'):\n", " if k.startswith('backbone') and not k.startswith('backbone.fc'):\n", " # remove prefix\n", " state_dict[k[len(\"backbone.\"):]] = state_dict[k]\n", " del state_dict[k]" ], "execution_count": 12, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "VVjA83PPJYWl" }, "source": [ "log = model.load_state_dict(state_dict, strict=False)\n", "assert log.missing_keys == ['fc.weight', 'fc.bias']" ], "execution_count": 13, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "_GC0a14uWRr6", "colab": { "base_uri": "https://localhost:8080/", "height": 102, "referenced_widgets": [ "48ebf2f69d1f4f5a9208cd2923eb5eac" ] }, "outputId": "6c3b86ad-b568-4c68-c1fb-1f7b2abbb6aa" }, "source": [ "if config.dataset_name == 'cifar10':\n", " train_loader, test_loader = get_cifar10_data_loaders(download=True)\n", "elif config.dataset_name == 'stl10':\n", " train_loader, test_loader = get_stl10_data_loaders(download=True)\n", "print(\"Dataset:\", config.dataset_name)" ], "execution_count": 14, "outputs": [ { "output_type": "stream", "text": [ "Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "48ebf2f69d1f4f5a9208cd2923eb5eac", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "Extracting ./data/stl10_binary.tar.gz to ./data\n", "Files already downloaded and verified\n", "Dataset: stl10\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "pYT_KsM0Mnnr" }, "source": [ "# freeze all layers but the last fc\n", "for name, param in model.named_parameters():\n", " if name not in ['fc.weight', 'fc.bias']:\n", " param.requires_grad = False\n", "\n", "parameters = list(filter(lambda p: p.requires_grad, model.parameters()))\n", "assert len(parameters) == 2 # fc.weight, fc.bias" ], "execution_count": 15, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "aPVh1S_eMRDU" }, "source": [ "optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)\n", "criterion = torch.nn.CrossEntropyLoss().to(device)" ], "execution_count": 16, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "edr6RhP2PdVq" }, "source": [ "def accuracy(output, target, topk=(1,)):\n", " \"\"\"Computes the accuracy over the k top predictions for the specified values of k\"\"\"\n", " with torch.no_grad():\n", " maxk = max(topk)\n", " batch_size = target.size(0)\n", "\n", " _, pred = output.topk(maxk, 1, True, True)\n", " pred = pred.t()\n", " correct = pred.eq(target.view(1, -1).expand_as(pred))\n", "\n", " res = []\n", " for k in topk:\n", " correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n", " res.append(correct_k.mul_(100.0 / batch_size))\n", " return res" ], "execution_count": 17, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "qOder0dAMI7X", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "d6127d8e-836f-4e69-a344-fee7e836d63a" }, "source": [ "epochs = 100\n", "for epoch in range(epochs):\n", " top1_train_accuracy = 0\n", " for counter, (x_batch, y_batch) in enumerate(train_loader):\n", " x_batch = x_batch.to(device)\n", " y_batch = y_batch.to(device)\n", "\n", " logits = model(x_batch)\n", " loss = criterion(logits, y_batch)\n", " top1 = accuracy(logits, y_batch, topk=(1,))\n", " top1_train_accuracy += top1[0]\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " top1_train_accuracy /= (counter + 1)\n", " top1_accuracy = 0\n", " top5_accuracy = 0\n", " for counter, (x_batch, y_batch) in enumerate(test_loader):\n", " x_batch = x_batch.to(device)\n", " y_batch = y_batch.to(device)\n", "\n", " logits = model(x_batch)\n", " \n", " top1, top5 = accuracy(logits, y_batch, topk=(1,5))\n", " top1_accuracy += top1[0]\n", " top5_accuracy += top5[0]\n", " \n", " top1_accuracy /= (counter + 1)\n", " top5_accuracy /= (counter + 1)\n", " print(f\"Epoch {epoch}\\tTop1 Train accuracy {top1_train_accuracy.item()}\\tTop1 Test accuracy: {top1_accuracy.item()}\\tTop5 test acc: {top5_accuracy.item()}\")" ], "execution_count": 18, "outputs": [ { "output_type": "stream", "text": [ "Epoch 0\tTop1 Train accuracy 27.890625\tTop1 Test accuracy: 42.05322265625\tTop5 test acc: 93.29833984375\n", "Epoch 1\tTop1 Train accuracy 49.921875\tTop1 Test accuracy: 54.45556640625\tTop5 test acc: 96.1181640625\n", "Epoch 2\tTop1 Train accuracy 57.3828125\tTop1 Test accuracy: 58.9599609375\tTop5 test acc: 96.9482421875\n", "Epoch 3\tTop1 Train accuracy 60.01953125\tTop1 Test accuracy: 60.38818359375\tTop5 test acc: 97.03369140625\n", "Epoch 4\tTop1 Train accuracy 61.7578125\tTop1 Test accuracy: 61.572265625\tTop5 test acc: 97.1923828125\n", "Epoch 5\tTop1 Train accuracy 62.91015625\tTop1 Test accuracy: 62.21923828125\tTop5 test acc: 97.30224609375\n", "Epoch 6\tTop1 Train accuracy 63.57421875\tTop1 Test accuracy: 62.6220703125\tTop5 test acc: 97.4365234375\n", "Epoch 7\tTop1 Train accuracy 64.12109375\tTop1 Test accuracy: 63.18359375\tTop5 test acc: 97.55859375\n", "Epoch 8\tTop1 Train accuracy 64.82421875\tTop1 Test accuracy: 63.51318359375\tTop5 test acc: 97.57080078125\n", "Epoch 9\tTop1 Train accuracy 65.17578125\tTop1 Test accuracy: 63.80615234375\tTop5 test acc: 97.59521484375\n", "Epoch 10\tTop1 Train accuracy 65.5859375\tTop1 Test accuracy: 64.14794921875\tTop5 test acc: 97.6318359375\n", "Epoch 11\tTop1 Train accuracy 65.80078125\tTop1 Test accuracy: 64.51416015625\tTop5 test acc: 97.61962890625\n", "Epoch 12\tTop1 Train accuracy 66.03515625\tTop1 Test accuracy: 64.70947265625\tTop5 test acc: 97.69287109375\n", "Epoch 13\tTop1 Train accuracy 66.42578125\tTop1 Test accuracy: 64.88037109375\tTop5 test acc: 97.705078125\n", "Epoch 14\tTop1 Train accuracy 66.9140625\tTop1 Test accuracy: 65.07568359375\tTop5 test acc: 97.76611328125\n", "Epoch 15\tTop1 Train accuracy 67.265625\tTop1 Test accuracy: 65.24658203125\tTop5 test acc: 97.81494140625\n", "Epoch 16\tTop1 Train accuracy 67.48046875\tTop1 Test accuracy: 65.46630859375\tTop5 test acc: 97.8515625\n", "Epoch 17\tTop1 Train accuracy 67.6171875\tTop1 Test accuracy: 65.71044921875\tTop5 test acc: 97.86376953125\n", "Epoch 18\tTop1 Train accuracy 67.83203125\tTop1 Test accuracy: 65.966796875\tTop5 test acc: 97.8759765625\n", "Epoch 19\tTop1 Train accuracy 68.0078125\tTop1 Test accuracy: 66.05224609375\tTop5 test acc: 97.88818359375\n", "Epoch 20\tTop1 Train accuracy 68.1640625\tTop1 Test accuracy: 66.17431640625\tTop5 test acc: 97.88818359375\n", "Epoch 21\tTop1 Train accuracy 68.37890625\tTop1 Test accuracy: 66.30859375\tTop5 test acc: 97.900390625\n", "Epoch 22\tTop1 Train accuracy 68.49609375\tTop1 Test accuracy: 66.50390625\tTop5 test acc: 97.88818359375\n", "Epoch 23\tTop1 Train accuracy 68.75\tTop1 Test accuracy: 66.6259765625\tTop5 test acc: 97.91259765625\n", "Epoch 24\tTop1 Train accuracy 68.90625\tTop1 Test accuracy: 66.68701171875\tTop5 test acc: 97.96142578125\n", "Epoch 25\tTop1 Train accuracy 68.984375\tTop1 Test accuracy: 66.8212890625\tTop5 test acc: 97.998046875\n", "Epoch 26\tTop1 Train accuracy 69.39453125\tTop1 Test accuracy: 66.9677734375\tTop5 test acc: 98.0224609375\n", "Epoch 27\tTop1 Train accuracy 69.4921875\tTop1 Test accuracy: 67.1142578125\tTop5 test acc: 98.01025390625\n", "Epoch 28\tTop1 Train accuracy 69.6484375\tTop1 Test accuracy: 67.1630859375\tTop5 test acc: 98.0224609375\n", "Epoch 29\tTop1 Train accuracy 69.7265625\tTop1 Test accuracy: 67.19970703125\tTop5 test acc: 98.03466796875\n", "Epoch 30\tTop1 Train accuracy 69.74609375\tTop1 Test accuracy: 67.24853515625\tTop5 test acc: 98.05908203125\n", "Epoch 31\tTop1 Train accuracy 69.921875\tTop1 Test accuracy: 67.37060546875\tTop5 test acc: 98.03466796875\n", "Epoch 32\tTop1 Train accuracy 70.078125\tTop1 Test accuracy: 67.46826171875\tTop5 test acc: 98.03466796875\n", "Epoch 33\tTop1 Train accuracy 70.25390625\tTop1 Test accuracy: 67.5048828125\tTop5 test acc: 98.0712890625\n", "Epoch 34\tTop1 Train accuracy 70.33203125\tTop1 Test accuracy: 67.59033203125\tTop5 test acc: 98.095703125\n", "Epoch 35\tTop1 Train accuracy 70.48828125\tTop1 Test accuracy: 67.73681640625\tTop5 test acc: 98.13232421875\n", "Epoch 36\tTop1 Train accuracy 70.5859375\tTop1 Test accuracy: 67.83447265625\tTop5 test acc: 98.1201171875\n", "Epoch 37\tTop1 Train accuracy 70.625\tTop1 Test accuracy: 67.85888671875\tTop5 test acc: 98.13232421875\n", "Epoch 38\tTop1 Train accuracy 70.78125\tTop1 Test accuracy: 67.88330078125\tTop5 test acc: 98.13232421875\n", "Epoch 39\tTop1 Train accuracy 70.91796875\tTop1 Test accuracy: 67.919921875\tTop5 test acc: 98.10791015625\n", "Epoch 40\tTop1 Train accuracy 70.95703125\tTop1 Test accuracy: 67.95654296875\tTop5 test acc: 98.10791015625\n", "Epoch 41\tTop1 Train accuracy 71.03515625\tTop1 Test accuracy: 68.00537109375\tTop5 test acc: 98.1201171875\n", "Epoch 42\tTop1 Train accuracy 71.07421875\tTop1 Test accuracy: 68.06640625\tTop5 test acc: 98.15673828125\n", "Epoch 43\tTop1 Train accuracy 71.15234375\tTop1 Test accuracy: 68.12744140625\tTop5 test acc: 98.15673828125\n", "Epoch 44\tTop1 Train accuracy 71.2109375\tTop1 Test accuracy: 68.1396484375\tTop5 test acc: 98.1689453125\n", "Epoch 45\tTop1 Train accuracy 71.25\tTop1 Test accuracy: 68.1396484375\tTop5 test acc: 98.1689453125\n", "Epoch 46\tTop1 Train accuracy 71.46484375\tTop1 Test accuracy: 68.15185546875\tTop5 test acc: 98.193359375\n", "Epoch 47\tTop1 Train accuracy 71.58203125\tTop1 Test accuracy: 68.22509765625\tTop5 test acc: 98.2177734375\n", "Epoch 48\tTop1 Train accuracy 71.6796875\tTop1 Test accuracy: 68.27392578125\tTop5 test acc: 98.22998046875\n", "Epoch 49\tTop1 Train accuracy 71.8359375\tTop1 Test accuracy: 68.3349609375\tTop5 test acc: 98.22998046875\n", "Epoch 50\tTop1 Train accuracy 71.93359375\tTop1 Test accuracy: 68.44482421875\tTop5 test acc: 98.2421875\n", "Epoch 51\tTop1 Train accuracy 72.01171875\tTop1 Test accuracy: 68.4814453125\tTop5 test acc: 98.2177734375\n", "Epoch 52\tTop1 Train accuracy 72.0703125\tTop1 Test accuracy: 68.505859375\tTop5 test acc: 98.2177734375\n", "Epoch 53\tTop1 Train accuracy 72.2265625\tTop1 Test accuracy: 68.54248046875\tTop5 test acc: 98.22998046875\n", "Epoch 54\tTop1 Train accuracy 72.24609375\tTop1 Test accuracy: 68.5791015625\tTop5 test acc: 98.22998046875\n", "Epoch 55\tTop1 Train accuracy 72.34375\tTop1 Test accuracy: 68.65234375\tTop5 test acc: 98.25439453125\n", "Epoch 56\tTop1 Train accuracy 72.421875\tTop1 Test accuracy: 68.71337890625\tTop5 test acc: 98.3154296875\n", "Epoch 57\tTop1 Train accuracy 72.51953125\tTop1 Test accuracy: 68.71337890625\tTop5 test acc: 98.3154296875\n", "Epoch 58\tTop1 Train accuracy 72.94921875\tTop1 Test accuracy: 68.76220703125\tTop5 test acc: 98.3154296875\n", "Epoch 59\tTop1 Train accuracy 72.98828125\tTop1 Test accuracy: 68.83544921875\tTop5 test acc: 98.3154296875\n", "Epoch 60\tTop1 Train accuracy 73.0859375\tTop1 Test accuracy: 68.88427734375\tTop5 test acc: 98.30322265625\n", "Epoch 61\tTop1 Train accuracy 73.18359375\tTop1 Test accuracy: 68.896484375\tTop5 test acc: 98.32763671875\n", "Epoch 62\tTop1 Train accuracy 73.3984375\tTop1 Test accuracy: 68.88427734375\tTop5 test acc: 98.33984375\n", "Epoch 63\tTop1 Train accuracy 73.4375\tTop1 Test accuracy: 68.95751953125\tTop5 test acc: 98.33984375\n", "Epoch 64\tTop1 Train accuracy 73.515625\tTop1 Test accuracy: 68.994140625\tTop5 test acc: 98.32763671875\n", "Epoch 65\tTop1 Train accuracy 73.57421875\tTop1 Test accuracy: 68.9697265625\tTop5 test acc: 98.3154296875\n", "Epoch 66\tTop1 Train accuracy 73.61328125\tTop1 Test accuracy: 69.03076171875\tTop5 test acc: 98.32763671875\n", "Epoch 67\tTop1 Train accuracy 73.671875\tTop1 Test accuracy: 69.07958984375\tTop5 test acc: 98.3154296875\n", "Epoch 68\tTop1 Train accuracy 73.7109375\tTop1 Test accuracy: 69.12841796875\tTop5 test acc: 98.3154296875\n", "Epoch 69\tTop1 Train accuracy 73.8671875\tTop1 Test accuracy: 69.20166015625\tTop5 test acc: 98.3154296875\n", "Epoch 70\tTop1 Train accuracy 73.984375\tTop1 Test accuracy: 69.25048828125\tTop5 test acc: 98.33984375\n", "Epoch 71\tTop1 Train accuracy 74.00390625\tTop1 Test accuracy: 69.2626953125\tTop5 test acc: 98.35205078125\n", "Epoch 72\tTop1 Train accuracy 74.00390625\tTop1 Test accuracy: 69.3115234375\tTop5 test acc: 98.33984375\n", "Epoch 73\tTop1 Train accuracy 74.0234375\tTop1 Test accuracy: 69.34814453125\tTop5 test acc: 98.35205078125\n", "Epoch 74\tTop1 Train accuracy 74.140625\tTop1 Test accuracy: 69.37255859375\tTop5 test acc: 98.33984375\n", "Epoch 75\tTop1 Train accuracy 74.23828125\tTop1 Test accuracy: 69.4091796875\tTop5 test acc: 98.35205078125\n", "Epoch 76\tTop1 Train accuracy 74.31640625\tTop1 Test accuracy: 69.4091796875\tTop5 test acc: 98.37646484375\n", "Epoch 77\tTop1 Train accuracy 74.43359375\tTop1 Test accuracy: 69.4091796875\tTop5 test acc: 98.3642578125\n", "Epoch 78\tTop1 Train accuracy 74.55078125\tTop1 Test accuracy: 69.3603515625\tTop5 test acc: 98.3642578125\n", "Epoch 79\tTop1 Train accuracy 74.58984375\tTop1 Test accuracy: 69.37255859375\tTop5 test acc: 98.3642578125\n", "Epoch 80\tTop1 Train accuracy 74.609375\tTop1 Test accuracy: 69.42138671875\tTop5 test acc: 98.3642578125\n", "Epoch 81\tTop1 Train accuracy 74.6484375\tTop1 Test accuracy: 69.49462890625\tTop5 test acc: 98.3642578125\n", "Epoch 82\tTop1 Train accuracy 74.6875\tTop1 Test accuracy: 69.47021484375\tTop5 test acc: 98.35205078125\n", "Epoch 83\tTop1 Train accuracy 74.7265625\tTop1 Test accuracy: 69.5556640625\tTop5 test acc: 98.35205078125\n", "Epoch 84\tTop1 Train accuracy 74.78515625\tTop1 Test accuracy: 69.59228515625\tTop5 test acc: 98.35205078125\n", "Epoch 85\tTop1 Train accuracy 74.8828125\tTop1 Test accuracy: 69.6533203125\tTop5 test acc: 98.35205078125\n", "Epoch 86\tTop1 Train accuracy 74.94140625\tTop1 Test accuracy: 69.677734375\tTop5 test acc: 98.3642578125\n", "Epoch 87\tTop1 Train accuracy 75.0390625\tTop1 Test accuracy: 69.7509765625\tTop5 test acc: 98.35205078125\n", "Epoch 88\tTop1 Train accuracy 75.0390625\tTop1 Test accuracy: 69.71435546875\tTop5 test acc: 98.35205078125\n", "Epoch 89\tTop1 Train accuracy 75.1171875\tTop1 Test accuracy: 69.775390625\tTop5 test acc: 98.33984375\n", "Epoch 90\tTop1 Train accuracy 75.21484375\tTop1 Test accuracy: 69.7509765625\tTop5 test acc: 98.33984375\n", "Epoch 91\tTop1 Train accuracy 75.25390625\tTop1 Test accuracy: 69.82421875\tTop5 test acc: 98.32763671875\n", "Epoch 92\tTop1 Train accuracy 75.29296875\tTop1 Test accuracy: 69.86083984375\tTop5 test acc: 98.33984375\n", "Epoch 93\tTop1 Train accuracy 75.33203125\tTop1 Test accuracy: 69.88525390625\tTop5 test acc: 98.35205078125\n", "Epoch 94\tTop1 Train accuracy 75.37109375\tTop1 Test accuracy: 69.81201171875\tTop5 test acc: 98.3642578125\n", "Epoch 95\tTop1 Train accuracy 75.37109375\tTop1 Test accuracy: 69.83642578125\tTop5 test acc: 98.37646484375\n", "Epoch 96\tTop1 Train accuracy 75.37109375\tTop1 Test accuracy: 69.83642578125\tTop5 test acc: 98.37646484375\n", "Epoch 97\tTop1 Train accuracy 75.41015625\tTop1 Test accuracy: 69.86083984375\tTop5 test acc: 98.37646484375\n", "Epoch 98\tTop1 Train accuracy 75.41015625\tTop1 Test accuracy: 69.90966796875\tTop5 test acc: 98.37646484375\n", "Epoch 99\tTop1 Train accuracy 75.46875\tTop1 Test accuracy: 69.921875\tTop5 test acc: 98.37646484375\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "dtYqHZirMNZk" }, "source": [ "" ], "execution_count": 18, "outputs": [] } ] }