SimCLR/feature_eval/mini_batch_logistic_regress...

821 lines
36 KiB
Plaintext
Raw Blame History

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
},
"colab": {
"name": "Copy of mini-batch-logistic-regression-evaluator.ipynb",
"provenance": [],
"include_colab_link": true
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"149b9ce8fb68473a837a77431c12281a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_88cd3db2831e4c13a4a634709700d6b2",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_a88c31d74f5c40a2b24bcff5a35d216c",
"IPY_MODEL_60c6150177694717a622936b830427b5"
]
}
},
"88cd3db2831e4c13a4a634709700d6b2": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"a88c31d74f5c40a2b24bcff5a35d216c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_dba019efadee4fdc8c799f309b9a7e70",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "info",
"max": 1,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 1,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_5901c2829a554c8ebbd5926610088041"
}
},
"60c6150177694717a622936b830427b5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_957362a11d174407979cf17012bf9208",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 2640404480/? [00:51<00:00, 32685718.58it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_a4f82234388e4701a02a9f68a177193a"
}
},
"dba019efadee4fdc8c799f309b9a7e70": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"5901c2829a554c8ebbd5926610088041": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"957362a11d174407979cf17012bf9208": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"a4f82234388e4701a02a9f68a177193a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "YUemQib7ZE4D"
},
"source": [
"import torch\n",
"import sys\n",
"import numpy as np\n",
"import os\n",
"import yaml\n",
"import matplotlib.pyplot as plt\n",
"import torchvision"
],
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WSgRE1CcLqdS",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "48a2ae15-f672-495b-8d43-9a23b85fa3b8"
},
"source": [
"# %pip (unlike !pip) always installs into the environment of the running kernel.\n",
"%pip install gdown"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: gdown in /usr/local/lib/python3.6/dist-packages (3.6.4)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from gdown) (1.15.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gdown) (2.23.0)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gdown) (4.41.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2020.12.5)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (1.24.3)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (3.0.4)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2.10)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "NOIJEui1ZziV"
},
"source": [
"def get_file_id_by_model(folder_name):\n",
"    \"\"\"Return the Google Drive file id of the archive published for `folder_name`.\n",
"\n",
"    Names that are not listed yield the string \"Model not found.\" instead of raising.\n",
"    \"\"\"\n",
"    drive_ids = {\n",
"        'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',\n",
"        'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C',\n",
"        'resnet50_50-epochs_stl10': '1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu',\n",
"    }\n",
"    if folder_name in drive_ids:\n",
"        return drive_ids[folder_name]\n",
"    return \"Model not found.\""
],
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "G7YMxsvEZMrX",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "59475430-69d2-45a2-b61b-ae755d5d6e88"
},
"source": [
"# Pretrained run to evaluate; the same name is used for the archive that is unzipped later.\n",
"folder_name = 'resnet50_50-epochs_stl10'\n",
"file_id = get_file_id_by_model(folder_name)\n",
"print(folder_name, file_id)"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": [
"resnet50_50-epochs_stl10 1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "PWZ8fet_YoJm",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "fbaeb858-221b-4d1b-dd90-001a6e713b75"
},
"source": [
"# download and extract model files\n",
"# os.system only hands back an exit status, so check it here instead of letting a\n",
"# failed download surface later as a missing checkpoint file.\n",
"# The URL is quoted because it contains '?', which the shell treats as a glob character.\n",
"download_status = os.system('gdown \"https://drive.google.com/uc?id={}\"'.format(file_id))\n",
"if download_status != 0:\n",
"    raise RuntimeError('gdown failed with status {}'.format(download_status))\n",
"# -o overwrites files left by a previous run, so re-running the cell never waits on a prompt\n",
"unzip_status = os.system('unzip -o {}'.format(folder_name))\n",
"if unzip_status != 0:\n",
"    raise RuntimeError('unzip failed with status {}'.format(unzip_status))\n",
"!ls"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": [
"checkpoint_0040.pth.tar\n",
"config.yml\n",
"events.out.tfevents.1610927742.4cb2c837708d.2694093.0\n",
"resnet50_50-epochs_stl10.zip\n",
"sample_data\n",
"training.log\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "3_nypQVEv-hn"
},
"source": [
"from torch.utils.data import DataLoader\n",
"import torchvision.transforms as transforms\n",
"from torchvision import datasets"
],
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lDfbL3w_Z0Od",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7532966e-1c4a-4641-c928-4cda14c53389"
},
"source": [
"# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
"if torch.cuda.is_available():\n",
"    device = 'cuda'\n",
"else:\n",
"    device = 'cpu'\n",
"print(\"Using device:\", device)"
],
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"text": [
"Using device: cuda\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "BfIPl0G6_RrT"
},
"source": [
"def get_stl10_data_loaders(download, shuffle=False, batch_size=256):\n",
"    \"\"\"Build the STL10 train and test DataLoaders over ToTensor-converted images.\n",
"\n",
"    Both splits are read from './data'; `download` is forwarded to the dataset.\n",
"    `shuffle` is applied to both loaders and the test loader uses twice `batch_size`.\n",
"    Returns the tuple (train_loader, test_loader).\n",
"    \"\"\"\n",
"    splits = {\n",
"        name: datasets.STL10('./data', split=name, download=download,\n",
"                             transform=transforms.ToTensor())\n",
"        for name in ('train', 'test')\n",
"    }\n",
"\n",
"    train_loader = DataLoader(splits['train'], batch_size=batch_size, shuffle=shuffle,\n",
"                              num_workers=0, drop_last=False)\n",
"    test_loader = DataLoader(splits['test'], batch_size=2*batch_size, shuffle=shuffle,\n",
"                             num_workers=10, drop_last=False)\n",
"    return train_loader, test_loader\n",
"\n",
"def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):\n",
"    \"\"\"Build the CIFAR10 train and test DataLoaders over ToTensor-converted images.\n",
"\n",
"    Both splits are read from './data'; `download` is forwarded to the dataset.\n",
"    `shuffle` is applied to both loaders and the test loader uses twice `batch_size`.\n",
"    Returns the tuple (train_loader, test_loader).\n",
"    \"\"\"\n",
"    root = './data'\n",
"\n",
"    train_dataset = datasets.CIFAR10(root, train=True, download=download,\n",
"                                     transform=transforms.ToTensor())\n",
"    train_loader = DataLoader(train_dataset, shuffle=shuffle, drop_last=False,\n",
"                              num_workers=0, batch_size=batch_size)\n",
"\n",
"    test_dataset = datasets.CIFAR10(root, train=False, download=download,\n",
"                                    transform=transforms.ToTensor())\n",
"    test_loader = DataLoader(test_dataset, shuffle=shuffle, drop_last=False,\n",
"                             num_workers=10, batch_size=2*batch_size)\n",
"\n",
"    return train_loader, test_loader"
],
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "6N8lYkbmDTaK"
},
"source": [
"# yaml.load() without an explicit Loader is deprecated and raises a TypeError on PyYAML 6.\n",
"# The config is read through attributes later (config.arch, config.dataset_name), so it is\n",
"# presumably a serialized Python object (e.g. an argparse.Namespace) that needs the full Loader.\n",
"# NOTE: yaml.Loader can construct arbitrary Python objects - only load a trusted config.yml.\n",
"with open('./config.yml') as file:\n",
"    config = yaml.load(file, Loader=yaml.Loader)"
],
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "a18lPD-tIle6"
},
"source": [
"# Architectures this notebook can evaluate, keyed by the value stored in the config.\n",
"backbones = {\n",
"    'resnet18': torchvision.models.resnet18,\n",
"    'resnet50': torchvision.models.resnet50,\n",
"}\n",
"if config.arch not in backbones:\n",
"    # fail here instead of with a NameError on `model` in a later cell\n",
"    raise ValueError('Unsupported architecture: {!r}'.format(config.arch))\n",
"model = backbones[config.arch](pretrained=False, num_classes=10).to(device)"
],
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4AIfgq41GuTT"
},
"source": [
"# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.\n",
"checkpoint = torch.load('checkpoint_0040.pth.tar', map_location=device)\n",
"\n",
"# Keep only the encoder weights: strip the 'backbone.' prefix so the keys line up with\n",
"# the model built above, and leave out 'backbone.fc*' so that the classifier `fc` is\n",
"# not loaded from the checkpoint (load_state_dict then reports it as missing).\n",
"prefix = 'backbone.'\n",
"state_dict = {\n",
"    key[len(prefix):]: weight\n",
"    for key, weight in checkpoint['state_dict'].items()\n",
"    if key.startswith(prefix) and not key.startswith('backbone.fc')\n",
"}"
],
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VVjA83PPJYWl"
},
"source": [
"# strict=False because the checkpoint is expected to lack the classifier weights\n",
"log = model.load_state_dict(state_dict, strict=False)\n",
"# the classifier weights must be the only entries the checkpoint did not provide\n",
"assert log.missing_keys == ['fc.weight', 'fc.bias']"
],
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_GC0a14uWRr6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 117,
"referenced_widgets": [
"149b9ce8fb68473a837a77431c12281a",
"88cd3db2831e4c13a4a634709700d6b2",
"a88c31d74f5c40a2b24bcff5a35d216c",
"60c6150177694717a622936b830427b5",
"dba019efadee4fdc8c799f309b9a7e70",
"5901c2829a554c8ebbd5926610088041",
"957362a11d174407979cf17012bf9208",
"a4f82234388e4701a02a9f68a177193a"
]
},
"outputId": "4c2558db-921c-425e-f947-6cc746d8c749"
},
"source": [
"if config.dataset_name == 'cifar10':\n",
"    train_loader, test_loader = get_cifar10_data_loaders(download=True)\n",
"elif config.dataset_name == 'stl10':\n",
"    train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
"else:\n",
"    # fail here instead of with a NameError on the loaders in the training cell\n",
"    raise ValueError('Unsupported dataset: {!r}'.format(config.dataset_name))\n",
"print(\"Dataset:\", config.dataset_name)"
],
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "149b9ce8fb68473a837a77431c12281a",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Extracting ./data/stl10_binary.tar.gz to ./data\n",
"Files already downloaded and verified\n",
"Dataset: stl10\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "pYT_KsM0Mnnr"
},
"source": [
"# freeze all layers but the last fc\n",
"head_names = ('fc.weight', 'fc.bias')\n",
"for name, param in model.named_parameters():\n",
"    if name in head_names:\n",
"        continue\n",
"    param.requires_grad = False\n",
"\n",
"# what is left trainable must be exactly the two fc tensors\n",
"parameters = [p for p in model.parameters() if p.requires_grad]\n",
"assert len(parameters) == 2 # fc.weight, fc.bias"
],
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "aPVh1S_eMRDU"
},
"source": [
"# Optimise only the trainable tensors collected in `parameters` (fc.weight, fc.bias);\n",
"# the frozen layers never receive gradients, so they do not belong in the optimizer.\n",
"optimizer = torch.optim.Adam(parameters, lr=3e-4, weight_decay=0.0008)\n",
"criterion = torch.nn.CrossEntropyLoss().to(device)"
],
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "edr6RhP2PdVq"
},
"source": [
"def accuracy(output, target, topk=(1,)):\n",
"    \"\"\"Return the top-k accuracies, in percent, for every k listed in `topk`.\n",
"\n",
"    `output` holds per-class scores with one row per sample and `target` the matching\n",
"    class indices. The result is a list of one-element tensors ordered like `topk`.\n",
"    \"\"\"\n",
"    with torch.no_grad():\n",
"        n_samples = target.size(0)\n",
"\n",
"        # indices of the max(topk) best-scoring classes, best first; after the\n",
"        # transpose row i holds the rank-i prediction of every sample\n",
"        _, top_classes = output.topk(max(topk), 1, True, True)\n",
"        top_classes = top_classes.t()\n",
"        hits = top_classes.eq(target.view(1, -1).expand_as(top_classes))\n",
"\n",
"        # a sample counts for k when its target is among the first k ranks\n",
"        return [\n",
"            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)\n",
"            for k in topk\n",
"        ]"
],
"execution_count": 26,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qOder0dAMI7X",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "5f723b91-5a5e-43eb-ca01-a9b5ae2f1346"
},
"source": [
"epochs = 100\n",
"for epoch in range(epochs):\n",
"    # ---- one pass over the training split ----\n",
"    top1_train_accuracy = 0\n",
"    for counter, (x_batch, y_batch) in enumerate(train_loader):\n",
"        x_batch = x_batch.to(device)\n",
"        y_batch = y_batch.to(device)\n",
"\n",
"        logits = model(x_batch)\n",
"        loss = criterion(logits, y_batch)\n",
"        top1 = accuracy(logits, y_batch, topk=(1,))\n",
"        top1_train_accuracy += top1[0]\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    # mean of the per-batch accuracies\n",
"    top1_train_accuracy /= (counter + 1)\n",
"\n",
"    # ---- evaluation on the test split ----\n",
"    # NOTE(review): the model is never switched to eval(), so layers such as BatchNorm\n",
"    # presumably still run in training mode here; left as-is so the reported numbers\n",
"    # stay comparable - confirm whether that is intended.\n",
"    top1_accuracy = 0\n",
"    top5_accuracy = 0\n",
"    # evaluation needs no gradients; skipping the autograd graph saves memory and time\n",
"    with torch.no_grad():\n",
"        for counter, (x_batch, y_batch) in enumerate(test_loader):\n",
"            x_batch = x_batch.to(device)\n",
"            y_batch = y_batch.to(device)\n",
"\n",
"            logits = model(x_batch)\n",
"\n",
"            top1, top5 = accuracy(logits, y_batch, topk=(1,5))\n",
"            top1_accuracy += top1[0]\n",
"            top5_accuracy += top5[0]\n",
"\n",
"    top1_accuracy /= (counter + 1)\n",
"    top5_accuracy /= (counter + 1)\n",
"    print(f\"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}\")"
],
"execution_count": 27,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch 0\tTop1 Train accuracy 28.7109375\tTop1 Test accuracy: 43.75\tTop5 test acc: 93.837890625\n",
"Epoch 1\tTop1 Train accuracy 49.37959671020508\tTop1 Test accuracy: 52.8662109375\tTop5 test acc: 95.439453125\n",
"Epoch 2\tTop1 Train accuracy 55.257354736328125\tTop1 Test accuracy: 56.45263671875\tTop5 test acc: 95.91796875\n",
"Epoch 3\tTop1 Train accuracy 57.51838302612305\tTop1 Test accuracy: 57.39013671875\tTop5 test acc: 96.19384765625\n",
"Epoch 4\tTop1 Train accuracy 58.727020263671875\tTop1 Test accuracy: 58.2568359375\tTop5 test acc: 96.435546875\n",
"Epoch 5\tTop1 Train accuracy 59.677162170410156\tTop1 Test accuracy: 58.7353515625\tTop5 test acc: 96.50390625\n",
"Epoch 6\tTop1 Train accuracy 60.065486907958984\tTop1 Test accuracy: 59.17724609375\tTop5 test acc: 96.708984375\n",
"Epoch 7\tTop1 Train accuracy 60.612361907958984\tTop1 Test accuracy: 59.482421875\tTop5 test acc: 96.74560546875\n",
"Epoch 8\tTop1 Train accuracy 60.827205657958984\tTop1 Test accuracy: 59.66064453125\tTop5 test acc: 96.77490234375\n",
"Epoch 9\tTop1 Train accuracy 61.100643157958984\tTop1 Test accuracy: 60.09521484375\tTop5 test acc: 96.82373046875\n",
"Epoch 10\tTop1 Train accuracy 61.52803421020508\tTop1 Test accuracy: 60.3466796875\tTop5 test acc: 96.82861328125\n",
"Epoch 11\tTop1 Train accuracy 61.80147171020508\tTop1 Test accuracy: 60.6640625\tTop5 test acc: 96.8896484375\n",
"Epoch 12\tTop1 Train accuracy 62.09444046020508\tTop1 Test accuracy: 60.96435546875\tTop5 test acc: 96.99462890625\n",
"Epoch 13\tTop1 Train accuracy 62.541358947753906\tTop1 Test accuracy: 61.13037109375\tTop5 test acc: 97.0068359375\n",
"Epoch 14\tTop1 Train accuracy 62.853858947753906\tTop1 Test accuracy: 61.34033203125\tTop5 test acc: 97.01904296875\n",
"Epoch 15\tTop1 Train accuracy 62.951515197753906\tTop1 Test accuracy: 61.5673828125\tTop5 test acc: 96.99951171875\n",
"Epoch 16\tTop1 Train accuracy 63.400733947753906\tTop1 Test accuracy: 61.806640625\tTop5 test acc: 97.0361328125\n",
"Epoch 17\tTop1 Train accuracy 63.66958236694336\tTop1 Test accuracy: 61.98974609375\tTop5 test acc: 97.0849609375\n",
"Epoch 18\tTop1 Train accuracy 63.82583236694336\tTop1 Test accuracy: 62.265625\tTop5 test acc: 97.07275390625\n",
"Epoch 19\tTop1 Train accuracy 64.1187973022461\tTop1 Test accuracy: 62.412109375\tTop5 test acc: 97.09716796875\n",
"Epoch 20\tTop1 Train accuracy 64.2750473022461\tTop1 Test accuracy: 62.56591796875\tTop5 test acc: 97.12158203125\n",
"Epoch 21\tTop1 Train accuracy 64.4140625\tTop1 Test accuracy: 62.724609375\tTop5 test acc: 97.20703125\n",
"Epoch 22\tTop1 Train accuracy 64.53125\tTop1 Test accuracy: 62.90771484375\tTop5 test acc: 97.255859375\n",
"Epoch 23\tTop1 Train accuracy 64.6484375\tTop1 Test accuracy: 62.95654296875\tTop5 test acc: 97.29248046875\n",
"Epoch 24\tTop1 Train accuracy 64.86328125\tTop1 Test accuracy: 63.12255859375\tTop5 test acc: 97.35595703125\n",
"Epoch 25\tTop1 Train accuracy 65.1344223022461\tTop1 Test accuracy: 63.330078125\tTop5 test acc: 97.40478515625\n",
"Epoch 26\tTop1 Train accuracy 65.3297348022461\tTop1 Test accuracy: 63.3984375\tTop5 test acc: 97.44873046875\n",
"Epoch 27\tTop1 Train accuracy 65.4469223022461\tTop1 Test accuracy: 63.34228515625\tTop5 test acc: 97.412109375\n",
"Epoch 28\tTop1 Train accuracy 65.6227035522461\tTop1 Test accuracy: 63.48876953125\tTop5 test acc: 97.412109375\n",
"Epoch 29\tTop1 Train accuracy 65.85478210449219\tTop1 Test accuracy: 63.56201171875\tTop5 test acc: 97.42431640625\n",
"Epoch 30\tTop1 Train accuracy 66.06732940673828\tTop1 Test accuracy: 63.67431640625\tTop5 test acc: 97.4560546875\n",
"Epoch 31\tTop1 Train accuracy 66.20404815673828\tTop1 Test accuracy: 63.80859375\tTop5 test acc: 97.48046875\n",
"Epoch 32\tTop1 Train accuracy 66.24080657958984\tTop1 Test accuracy: 63.92578125\tTop5 test acc: 97.5048828125\n",
"Epoch 33\tTop1 Train accuracy 66.58777618408203\tTop1 Test accuracy: 63.9990234375\tTop5 test acc: 97.529296875\n",
"Epoch 34\tTop1 Train accuracy 66.70496368408203\tTop1 Test accuracy: 64.1455078125\tTop5 test acc: 97.51708984375\n",
"Epoch 35\tTop1 Train accuracy 66.80261993408203\tTop1 Test accuracy: 64.20654296875\tTop5 test acc: 97.529296875\n",
"Epoch 36\tTop1 Train accuracy 66.91980743408203\tTop1 Test accuracy: 64.32861328125\tTop5 test acc: 97.51708984375\n",
"Epoch 37\tTop1 Train accuracy 66.93933868408203\tTop1 Test accuracy: 64.3896484375\tTop5 test acc: 97.51708984375\n",
"Epoch 38\tTop1 Train accuracy 66.97840118408203\tTop1 Test accuracy: 64.47021484375\tTop5 test acc: 97.529296875\n",
"Epoch 39\tTop1 Train accuracy 67.11282348632812\tTop1 Test accuracy: 64.53125\tTop5 test acc: 97.56591796875\n",
"Epoch 40\tTop1 Train accuracy 67.24954223632812\tTop1 Test accuracy: 64.6044921875\tTop5 test acc: 97.6025390625\n",
"Epoch 41\tTop1 Train accuracy 67.34949493408203\tTop1 Test accuracy: 64.62890625\tTop5 test acc: 97.59033203125\n",
"Epoch 42\tTop1 Train accuracy 67.42761993408203\tTop1 Test accuracy: 64.7265625\tTop5 test acc: 97.6025390625\n",
"Epoch 43\tTop1 Train accuracy 67.52527618408203\tTop1 Test accuracy: 64.84375\tTop5 test acc: 97.61474609375\n",
"Epoch 44\tTop1 Train accuracy 67.58386993408203\tTop1 Test accuracy: 64.87548828125\tTop5 test acc: 97.61474609375\n",
"Epoch 45\tTop1 Train accuracy 67.64246368408203\tTop1 Test accuracy: 64.9365234375\tTop5 test acc: 97.626953125\n",
"Epoch 46\tTop1 Train accuracy 67.75735473632812\tTop1 Test accuracy: 65.0341796875\tTop5 test acc: 97.66357421875\n",
"Epoch 47\tTop1 Train accuracy 67.85501098632812\tTop1 Test accuracy: 65.1318359375\tTop5 test acc: 97.7001953125\n",
"Epoch 48\tTop1 Train accuracy 67.89407348632812\tTop1 Test accuracy: 65.1318359375\tTop5 test acc: 97.73681640625\n",
"Epoch 49\tTop1 Train accuracy 67.95266723632812\tTop1 Test accuracy: 65.15625\tTop5 test acc: 97.73681640625\n",
"Epoch 50\tTop1 Train accuracy 68.01126098632812\tTop1 Test accuracy: 65.21728515625\tTop5 test acc: 97.76123046875\n",
"Epoch 51\tTop1 Train accuracy 68.05032348632812\tTop1 Test accuracy: 65.29052734375\tTop5 test acc: 97.7490234375\n",
"Epoch 52\tTop1 Train accuracy 68.05032348632812\tTop1 Test accuracy: 65.3564453125\tTop5 test acc: 97.78564453125\n",
"Epoch 53\tTop1 Train accuracy 68.20657348632812\tTop1 Test accuracy: 65.3759765625\tTop5 test acc: 97.7978515625\n",
"Epoch 54\tTop1 Train accuracy 68.28469848632812\tTop1 Test accuracy: 65.45654296875\tTop5 test acc: 97.822265625\n",
"Epoch 55\tTop1 Train accuracy 68.41912078857422\tTop1 Test accuracy: 65.46875\tTop5 test acc: 97.8466796875\n",
"Epoch 56\tTop1 Train accuracy 68.45818328857422\tTop1 Test accuracy: 65.5615234375\tTop5 test acc: 97.85888671875\n",
"Epoch 57\tTop1 Train accuracy 68.61443328857422\tTop1 Test accuracy: 65.56640625\tTop5 test acc: 97.87109375\n",
"Epoch 58\tTop1 Train accuracy 68.71208953857422\tTop1 Test accuracy: 65.5859375\tTop5 test acc: 97.90771484375\n",
"Epoch 59\tTop1 Train accuracy 68.69255828857422\tTop1 Test accuracy: 65.64697265625\tTop5 test acc: 97.919921875\n",
"Epoch 60\tTop1 Train accuracy 68.80744934082031\tTop1 Test accuracy: 65.64697265625\tTop5 test acc: 97.93212890625\n",
"Epoch 61\tTop1 Train accuracy 68.94416809082031\tTop1 Test accuracy: 65.72021484375\tTop5 test acc: 97.93212890625\n",
"Epoch 62\tTop1 Train accuracy 69.04182434082031\tTop1 Test accuracy: 65.76904296875\tTop5 test acc: 97.919921875\n",
"Epoch 63\tTop1 Train accuracy 69.06135559082031\tTop1 Test accuracy: 65.84228515625\tTop5 test acc: 97.90771484375\n",
"Epoch 64\tTop1 Train accuracy 69.19807434082031\tTop1 Test accuracy: 65.93505859375\tTop5 test acc: 97.90771484375\n",
"Epoch 65\tTop1 Train accuracy 69.23713684082031\tTop1 Test accuracy: 65.95947265625\tTop5 test acc: 97.9150390625\n",
"Epoch 66\tTop1 Train accuracy 69.25666809082031\tTop1 Test accuracy: 66.0888671875\tTop5 test acc: 97.939453125\n",
"Epoch 67\tTop1 Train accuracy 69.31526184082031\tTop1 Test accuracy: 66.02783203125\tTop5 test acc: 97.939453125\n",
"Epoch 68\tTop1 Train accuracy 69.43014526367188\tTop1 Test accuracy: 66.07666015625\tTop5 test acc: 97.9638671875\n",
"Epoch 69\tTop1 Train accuracy 69.48873901367188\tTop1 Test accuracy: 66.12060546875\tTop5 test acc: 97.9638671875\n",
"Epoch 70\tTop1 Train accuracy 69.50827026367188\tTop1 Test accuracy: 66.083984375\tTop5 test acc: 97.95166015625\n",
"Epoch 71\tTop1 Train accuracy 69.60592651367188\tTop1 Test accuracy: 66.1572265625\tTop5 test acc: 97.9638671875\n",
"Epoch 72\tTop1 Train accuracy 69.68635559082031\tTop1 Test accuracy: 66.2060546875\tTop5 test acc: 97.95166015625\n",
"Epoch 73\tTop1 Train accuracy 69.78170776367188\tTop1 Test accuracy: 66.2744140625\tTop5 test acc: 97.92724609375\n",
"Epoch 74\tTop1 Train accuracy 69.84030151367188\tTop1 Test accuracy: 66.31591796875\tTop5 test acc: 97.92724609375\n",
"Epoch 75\tTop1 Train accuracy 69.89889526367188\tTop1 Test accuracy: 66.328125\tTop5 test acc: 97.9150390625\n",
"Epoch 76\tTop1 Train accuracy 69.93795776367188\tTop1 Test accuracy: 66.41357421875\tTop5 test acc: 97.92724609375\n",
"Epoch 77\tTop1 Train accuracy 69.95748901367188\tTop1 Test accuracy: 66.41357421875\tTop5 test acc: 97.9150390625\n",
"Epoch 78\tTop1 Train accuracy 70.01608276367188\tTop1 Test accuracy: 66.474609375\tTop5 test acc: 97.9150390625\n",
"Epoch 79\tTop1 Train accuracy 69.99655151367188\tTop1 Test accuracy: 66.53564453125\tTop5 test acc: 97.939453125\n",
"Epoch 80\tTop1 Train accuracy 70.01608276367188\tTop1 Test accuracy: 66.56005859375\tTop5 test acc: 97.939453125\n",
"Epoch 81\tTop1 Train accuracy 70.09420776367188\tTop1 Test accuracy: 66.56494140625\tTop5 test acc: 97.939453125\n",
"Epoch 82\tTop1 Train accuracy 70.11373901367188\tTop1 Test accuracy: 66.650390625\tTop5 test acc: 97.939453125\n",
"Epoch 83\tTop1 Train accuracy 70.19186401367188\tTop1 Test accuracy: 66.71142578125\tTop5 test acc: 97.92724609375\n",
"Epoch 84\tTop1 Train accuracy 70.26998901367188\tTop1 Test accuracy: 66.7236328125\tTop5 test acc: 97.90283203125\n",
"Epoch 85\tTop1 Train accuracy 70.32858276367188\tTop1 Test accuracy: 66.73583984375\tTop5 test acc: 97.90283203125\n",
"Epoch 86\tTop1 Train accuracy 70.32858276367188\tTop1 Test accuracy: 66.748046875\tTop5 test acc: 97.890625\n",
"Epoch 87\tTop1 Train accuracy 70.46530151367188\tTop1 Test accuracy: 66.7724609375\tTop5 test acc: 97.890625\n",
"Epoch 88\tTop1 Train accuracy 70.52389526367188\tTop1 Test accuracy: 66.78466796875\tTop5 test acc: 97.90283203125\n",
"Epoch 89\tTop1 Train accuracy 70.56295776367188\tTop1 Test accuracy: 66.78466796875\tTop5 test acc: 97.890625\n",
"Epoch 90\tTop1 Train accuracy 70.68014526367188\tTop1 Test accuracy: 66.83349609375\tTop5 test acc: 97.87841796875\n",
"Epoch 91\tTop1 Train accuracy 70.77780151367188\tTop1 Test accuracy: 66.826171875\tTop5 test acc: 97.87841796875\n",
"Epoch 92\tTop1 Train accuracy 70.81686401367188\tTop1 Test accuracy: 66.88720703125\tTop5 test acc: 97.87841796875\n",
"Epoch 93\tTop1 Train accuracy 70.85592651367188\tTop1 Test accuracy: 66.8994140625\tTop5 test acc: 97.87841796875\n",
"Epoch 94\tTop1 Train accuracy 70.91452026367188\tTop1 Test accuracy: 66.9482421875\tTop5 test acc: 97.890625\n",
"Epoch 95\tTop1 Train accuracy 71.03170776367188\tTop1 Test accuracy: 66.98486328125\tTop5 test acc: 97.890625\n",
"Epoch 96\tTop1 Train accuracy 71.09030151367188\tTop1 Test accuracy: 67.001953125\tTop5 test acc: 97.91015625\n",
"Epoch 97\tTop1 Train accuracy 71.09030151367188\tTop1 Test accuracy: 67.0263671875\tTop5 test acc: 97.91015625\n",
"Epoch 98\tTop1 Train accuracy 71.12936401367188\tTop1 Test accuracy: 67.06298828125\tTop5 test acc: 97.89794921875\n",
"Epoch 99\tTop1 Train accuracy 71.12936401367188\tTop1 Test accuracy: 67.0751953125\tTop5 test acc: 97.8857421875\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "dtYqHZirMNZk"
},
"source": [
""
],
"execution_count": 27,
"outputs": []
}
]
}