mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
592 lines
28 KiB
Plaintext
592 lines
28 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "pytorch",
|
|
"language": "python",
|
|
"name": "pytorch"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.6"
|
|
},
|
|
"colab": {
|
|
"name": "Copy of mini-batch-logistic-regression-evaluator.ipynb",
|
|
"provenance": [],
|
|
"include_colab_link": true
|
|
},
|
|
"accelerator": "GPU"
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "view-in-github",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "YUemQib7ZE4D"
|
|
},
|
|
"source": [
|
|
"import torch\n",
|
|
"import sys\n",
|
|
"import numpy as np\n",
|
|
"import os\n",
|
|
"import yaml\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import torchvision"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "WSgRE1CcLqdS",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"outputId": "e44ac358-6480-4a5f-a358-6eb6ace26c8b"
|
|
},
|
|
"source": [
|
|
"!pip install gdown"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Requirement already satisfied: gdown in /usr/local/lib/python3.6/dist-packages (3.6.4)\n",
|
|
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from gdown) (1.15.0)\n",
|
|
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gdown) (2.23.0)\n",
|
|
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gdown) (4.41.1)\n",
|
|
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (1.24.3)\n",
|
|
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (3.0.4)\n",
|
|
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2020.12.5)\n",
|
|
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2.10)\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "NOIJEui1ZziV"
|
|
},
|
|
"source": [
|
|
"def get_file_id_by_model(folder_name):\n",
|
|
" file_id = {'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',\n",
|
|
" 'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C'}\n",
|
|
" return file_id.get(folder_name, \"Model not found.\")"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "G7YMxsvEZMrX",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"outputId": "36932a7d-c7e5-492a-f37d-8be6b18f787a"
|
|
},
|
|
"source": [
|
|
"folder_name = 'resnet18_100-epochs_stl10'\n",
|
|
"file_id = get_file_id_by_model(folder_name)\n",
|
|
"print(folder_name, file_id)"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"resnet18_100-epochs_stl10 14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "PWZ8fet_YoJm",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"outputId": "8d52756d-707b-4a3f-9e8c-0d191408deab"
|
|
},
|
|
"source": [
|
|
"# download and extract model files\n",
|
|
"os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n",
|
|
"os.system('unzip {}'.format(folder_name))\n",
|
|
"!ls"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"checkpoint_0100.pth.tar\n",
|
|
"config.yml\n",
|
|
"events.out.tfevents.1610901470.4cb2c837708d.2683858.0\n",
|
|
"resnet18_100-epochs_stl10.zip\n",
|
|
"sample_data\n",
|
|
"training.log\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "ooyhd8piZ1w1",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"outputId": "6ffb73aa-35c5-4df2-bd1f-6de6a235a9e5"
|
|
},
|
|
"source": [
|
|
"!unzip resnet18_100-epochs_stl10"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Archive: resnet18_100-epochs_stl10.zip\n",
|
|
"replace checkpoint_0100.pth.tar? [y]es, [n]o, [A]ll, [N]one, [r]ename: A\n",
|
|
" inflating: checkpoint_0100.pth.tar \n",
|
|
" inflating: config.yml \n",
|
|
" inflating: events.out.tfevents.1610901470.4cb2c837708d.2683858.0 \n",
|
|
" inflating: training.log \n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "3_nypQVEv-hn"
|
|
},
|
|
"source": [
|
|
"from torch.utils.data import DataLoader\n",
|
|
"import torchvision.transforms as transforms\n",
|
|
"from torchvision import datasets"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "lDfbL3w_Z0Od",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"outputId": "5f58bd9b-4428-4b8c-e271-b47ca6694f34"
|
|
},
|
|
"source": [
|
|
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
|
"print(\"Using device:\", device)"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Using device: cuda\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "BfIPl0G6_RrT"
|
|
},
|
|
"source": [
|
|
"def get_stl10_data_loaders(download, shuffle=False, batch_size=256):\n",
|
|
" train_dataset = datasets.STL10('./data', split='train', download=download,\n",
|
|
" transform=transforms.ToTensor())\n",
|
|
"\n",
|
|
" train_loader = DataLoader(train_dataset, batch_size=batch_size,\n",
|
|
" num_workers=0, drop_last=False, shuffle=shuffle)\n",
|
|
" \n",
|
|
" test_dataset = datasets.STL10('./data', split='test', download=download,\n",
|
|
" transform=transforms.ToTensor())\n",
|
|
"\n",
|
|
" test_loader = DataLoader(test_dataset, batch_size=2*batch_size,\n",
|
|
" num_workers=10, drop_last=False, shuffle=shuffle)\n",
|
|
" return train_loader, test_loader\n",
|
|
"\n",
|
|
"def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):\n",
|
|
" train_dataset = datasets.CIFAR10('./data', train=True, download=download,\n",
|
|
" transform=transforms.ToTensor())\n",
|
|
"\n",
|
|
" train_loader = DataLoader(train_dataset, batch_size=batch_size,\n",
|
|
" num_workers=0, drop_last=False, shuffle=shuffle)\n",
|
|
" \n",
|
|
" test_dataset = datasets.CIFAR10('./data', train=False, download=download,\n",
|
|
" transform=transforms.ToTensor())\n",
|
|
"\n",
|
|
" test_loader = DataLoader(test_dataset, batch_size=2*batch_size,\n",
|
|
" num_workers=10, drop_last=False, shuffle=shuffle)\n",
|
|
" return train_loader, test_loader"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "6N8lYkbmDTaK"
|
|
},
|
|
"source": [
|
|
"with open(os.path.join('./config.yml')) as file:\n",
|
|
" config = yaml.load(file)"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "a18lPD-tIle6"
|
|
},
|
|
"source": [
|
|
"if config.arch == 'resnet18':\n",
|
|
" model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)\n",
|
|
"elif config.arch == 'resnet50':\n",
|
|
" model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "4AIfgq41GuTT"
|
|
},
|
|
"source": [
|
|
"checkpoint = torch.load('checkpoint_0100.pth.tar', map_location=device)\n",
|
|
"state_dict = checkpoint['state_dict']\n",
|
|
"\n",
|
|
"for k in list(state_dict.keys()):\n",
|
|
"\n",
|
|
" if k.startswith('backbone.'):\n",
|
|
" if k.startswith('backbone') and not k.startswith('backbone.fc'):\n",
|
|
" # remove prefix\n",
|
|
" state_dict[k[len(\"backbone.\"):]] = state_dict[k]\n",
|
|
" del state_dict[k]"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "VVjA83PPJYWl"
|
|
},
|
|
"source": [
|
|
"log = model.load_state_dict(state_dict, strict=False)\n",
|
|
"assert log.missing_keys == ['fc.weight', 'fc.bias']"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "_GC0a14uWRr6",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 102,
|
|
"referenced_widgets": [
|
|
"48ebf2f69d1f4f5a9208cd2923eb5eac"
|
|
]
|
|
},
|
|
"outputId": "6c3b86ad-b568-4c68-c1fb-1f7b2abbb6aa"
|
|
},
|
|
"source": [
|
|
"if config.dataset_name == 'cifar10':\n",
|
|
" train_loader, test_loader = get_cifar10_data_loaders(download=True)\n",
|
|
"elif config.dataset_name == 'stl10':\n",
|
|
" train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
|
|
"print(\"Dataset:\", config.dataset_name)"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz\n"
|
|
],
|
|
"name": "stdout"
|
|
},
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "48ebf2f69d1f4f5a9208cd2923eb5eac",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
}
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Extracting ./data/stl10_binary.tar.gz to ./data\n",
|
|
"Files already downloaded and verified\n",
|
|
"Dataset: stl10\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "pYT_KsM0Mnnr"
|
|
},
|
|
"source": [
|
|
"# freeze all layers but the last fc\n",
|
|
"for name, param in model.named_parameters():\n",
|
|
" if name not in ['fc.weight', 'fc.bias']:\n",
|
|
" param.requires_grad = False\n",
|
|
"\n",
|
|
"parameters = list(filter(lambda p: p.requires_grad, model.parameters()))\n",
|
|
"assert len(parameters) == 2 # fc.weight, fc.bias"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "aPVh1S_eMRDU"
|
|
},
|
|
"source": [
|
|
"optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)\n",
|
|
"criterion = torch.nn.CrossEntropyLoss().to(device)"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "edr6RhP2PdVq"
|
|
},
|
|
"source": [
|
|
"def accuracy(output, target, topk=(1,)):\n",
|
|
" \"\"\"Computes the accuracy over the k top predictions for the specified values of k\"\"\"\n",
|
|
" with torch.no_grad():\n",
|
|
" maxk = max(topk)\n",
|
|
" batch_size = target.size(0)\n",
|
|
"\n",
|
|
" _, pred = output.topk(maxk, 1, True, True)\n",
|
|
" pred = pred.t()\n",
|
|
" correct = pred.eq(target.view(1, -1).expand_as(pred))\n",
|
|
"\n",
|
|
" res = []\n",
|
|
" for k in topk:\n",
|
|
" correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n",
|
|
" res.append(correct_k.mul_(100.0 / batch_size))\n",
|
|
" return res"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "qOder0dAMI7X",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"outputId": "d6127d8e-836f-4e69-a344-fee7e836d63a"
|
|
},
|
|
"source": [
|
|
"epochs = 100\n",
|
|
"for epoch in range(epochs):\n",
|
|
" top1_train_accuracy = 0\n",
|
|
" for counter, (x_batch, y_batch) in enumerate(train_loader):\n",
|
|
" x_batch = x_batch.to(device)\n",
|
|
" y_batch = y_batch.to(device)\n",
|
|
"\n",
|
|
" logits = model(x_batch)\n",
|
|
" loss = criterion(logits, y_batch)\n",
|
|
" top1 = accuracy(logits, y_batch, topk=(1,))\n",
|
|
" top1_train_accuracy += top1[0]\n",
|
|
"\n",
|
|
" optimizer.zero_grad()\n",
|
|
" loss.backward()\n",
|
|
" optimizer.step()\n",
|
|
"\n",
|
|
" top1_train_accuracy /= (counter + 1)\n",
|
|
" top1_accuracy = 0\n",
|
|
" top5_accuracy = 0\n",
|
|
" for counter, (x_batch, y_batch) in enumerate(test_loader):\n",
|
|
" x_batch = x_batch.to(device)\n",
|
|
" y_batch = y_batch.to(device)\n",
|
|
"\n",
|
|
" logits = model(x_batch)\n",
|
|
" \n",
|
|
" top1, top5 = accuracy(logits, y_batch, topk=(1,5))\n",
|
|
" top1_accuracy += top1[0]\n",
|
|
" top5_accuracy += top5[0]\n",
|
|
" \n",
|
|
" top1_accuracy /= (counter + 1)\n",
|
|
" top5_accuracy /= (counter + 1)\n",
|
|
" print(f\"Epoch {epoch}\\tTop1 Train accuracy {top1_train_accuracy.item()}\\tTop1 Test accuracy: {top1_accuracy.item()}\\tTop5 test acc: {top5_accuracy.item()}\")"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 0\tTop1 Train accuracy 27.890625\tTop1 Test accuracy: 42.05322265625\tTop5 test acc: 93.29833984375\n",
|
|
"Epoch 1\tTop1 Train accuracy 49.921875\tTop1 Test accuracy: 54.45556640625\tTop5 test acc: 96.1181640625\n",
|
|
"Epoch 2\tTop1 Train accuracy 57.3828125\tTop1 Test accuracy: 58.9599609375\tTop5 test acc: 96.9482421875\n",
|
|
"Epoch 3\tTop1 Train accuracy 60.01953125\tTop1 Test accuracy: 60.38818359375\tTop5 test acc: 97.03369140625\n",
|
|
"Epoch 4\tTop1 Train accuracy 61.7578125\tTop1 Test accuracy: 61.572265625\tTop5 test acc: 97.1923828125\n",
|
|
"Epoch 5\tTop1 Train accuracy 62.91015625\tTop1 Test accuracy: 62.21923828125\tTop5 test acc: 97.30224609375\n",
|
|
"Epoch 6\tTop1 Train accuracy 63.57421875\tTop1 Test accuracy: 62.6220703125\tTop5 test acc: 97.4365234375\n",
|
|
"Epoch 7\tTop1 Train accuracy 64.12109375\tTop1 Test accuracy: 63.18359375\tTop5 test acc: 97.55859375\n",
|
|
"Epoch 8\tTop1 Train accuracy 64.82421875\tTop1 Test accuracy: 63.51318359375\tTop5 test acc: 97.57080078125\n",
|
|
"Epoch 9\tTop1 Train accuracy 65.17578125\tTop1 Test accuracy: 63.80615234375\tTop5 test acc: 97.59521484375\n",
|
|
"Epoch 10\tTop1 Train accuracy 65.5859375\tTop1 Test accuracy: 64.14794921875\tTop5 test acc: 97.6318359375\n",
|
|
"Epoch 11\tTop1 Train accuracy 65.80078125\tTop1 Test accuracy: 64.51416015625\tTop5 test acc: 97.61962890625\n",
|
|
"Epoch 12\tTop1 Train accuracy 66.03515625\tTop1 Test accuracy: 64.70947265625\tTop5 test acc: 97.69287109375\n",
|
|
"Epoch 13\tTop1 Train accuracy 66.42578125\tTop1 Test accuracy: 64.88037109375\tTop5 test acc: 97.705078125\n",
|
|
"Epoch 14\tTop1 Train accuracy 66.9140625\tTop1 Test accuracy: 65.07568359375\tTop5 test acc: 97.76611328125\n",
|
|
"Epoch 15\tTop1 Train accuracy 67.265625\tTop1 Test accuracy: 65.24658203125\tTop5 test acc: 97.81494140625\n",
|
|
"Epoch 16\tTop1 Train accuracy 67.48046875\tTop1 Test accuracy: 65.46630859375\tTop5 test acc: 97.8515625\n",
|
|
"Epoch 17\tTop1 Train accuracy 67.6171875\tTop1 Test accuracy: 65.71044921875\tTop5 test acc: 97.86376953125\n",
|
|
"Epoch 18\tTop1 Train accuracy 67.83203125\tTop1 Test accuracy: 65.966796875\tTop5 test acc: 97.8759765625\n",
|
|
"Epoch 19\tTop1 Train accuracy 68.0078125\tTop1 Test accuracy: 66.05224609375\tTop5 test acc: 97.88818359375\n",
|
|
"Epoch 20\tTop1 Train accuracy 68.1640625\tTop1 Test accuracy: 66.17431640625\tTop5 test acc: 97.88818359375\n",
|
|
"Epoch 21\tTop1 Train accuracy 68.37890625\tTop1 Test accuracy: 66.30859375\tTop5 test acc: 97.900390625\n",
|
|
"Epoch 22\tTop1 Train accuracy 68.49609375\tTop1 Test accuracy: 66.50390625\tTop5 test acc: 97.88818359375\n",
|
|
"Epoch 23\tTop1 Train accuracy 68.75\tTop1 Test accuracy: 66.6259765625\tTop5 test acc: 97.91259765625\n",
|
|
"Epoch 24\tTop1 Train accuracy 68.90625\tTop1 Test accuracy: 66.68701171875\tTop5 test acc: 97.96142578125\n",
|
|
"Epoch 25\tTop1 Train accuracy 68.984375\tTop1 Test accuracy: 66.8212890625\tTop5 test acc: 97.998046875\n",
|
|
"Epoch 26\tTop1 Train accuracy 69.39453125\tTop1 Test accuracy: 66.9677734375\tTop5 test acc: 98.0224609375\n",
|
|
"Epoch 27\tTop1 Train accuracy 69.4921875\tTop1 Test accuracy: 67.1142578125\tTop5 test acc: 98.01025390625\n",
|
|
"Epoch 28\tTop1 Train accuracy 69.6484375\tTop1 Test accuracy: 67.1630859375\tTop5 test acc: 98.0224609375\n",
|
|
"Epoch 29\tTop1 Train accuracy 69.7265625\tTop1 Test accuracy: 67.19970703125\tTop5 test acc: 98.03466796875\n",
|
|
"Epoch 30\tTop1 Train accuracy 69.74609375\tTop1 Test accuracy: 67.24853515625\tTop5 test acc: 98.05908203125\n",
|
|
"Epoch 31\tTop1 Train accuracy 69.921875\tTop1 Test accuracy: 67.37060546875\tTop5 test acc: 98.03466796875\n",
|
|
"Epoch 32\tTop1 Train accuracy 70.078125\tTop1 Test accuracy: 67.46826171875\tTop5 test acc: 98.03466796875\n",
|
|
"Epoch 33\tTop1 Train accuracy 70.25390625\tTop1 Test accuracy: 67.5048828125\tTop5 test acc: 98.0712890625\n",
|
|
"Epoch 34\tTop1 Train accuracy 70.33203125\tTop1 Test accuracy: 67.59033203125\tTop5 test acc: 98.095703125\n",
|
|
"Epoch 35\tTop1 Train accuracy 70.48828125\tTop1 Test accuracy: 67.73681640625\tTop5 test acc: 98.13232421875\n",
|
|
"Epoch 36\tTop1 Train accuracy 70.5859375\tTop1 Test accuracy: 67.83447265625\tTop5 test acc: 98.1201171875\n",
|
|
"Epoch 37\tTop1 Train accuracy 70.625\tTop1 Test accuracy: 67.85888671875\tTop5 test acc: 98.13232421875\n",
|
|
"Epoch 38\tTop1 Train accuracy 70.78125\tTop1 Test accuracy: 67.88330078125\tTop5 test acc: 98.13232421875\n",
|
|
"Epoch 39\tTop1 Train accuracy 70.91796875\tTop1 Test accuracy: 67.919921875\tTop5 test acc: 98.10791015625\n",
|
|
"Epoch 40\tTop1 Train accuracy 70.95703125\tTop1 Test accuracy: 67.95654296875\tTop5 test acc: 98.10791015625\n",
|
|
"Epoch 41\tTop1 Train accuracy 71.03515625\tTop1 Test accuracy: 68.00537109375\tTop5 test acc: 98.1201171875\n",
|
|
"Epoch 42\tTop1 Train accuracy 71.07421875\tTop1 Test accuracy: 68.06640625\tTop5 test acc: 98.15673828125\n",
|
|
"Epoch 43\tTop1 Train accuracy 71.15234375\tTop1 Test accuracy: 68.12744140625\tTop5 test acc: 98.15673828125\n",
|
|
"Epoch 44\tTop1 Train accuracy 71.2109375\tTop1 Test accuracy: 68.1396484375\tTop5 test acc: 98.1689453125\n",
|
|
"Epoch 45\tTop1 Train accuracy 71.25\tTop1 Test accuracy: 68.1396484375\tTop5 test acc: 98.1689453125\n",
|
|
"Epoch 46\tTop1 Train accuracy 71.46484375\tTop1 Test accuracy: 68.15185546875\tTop5 test acc: 98.193359375\n",
|
|
"Epoch 47\tTop1 Train accuracy 71.58203125\tTop1 Test accuracy: 68.22509765625\tTop5 test acc: 98.2177734375\n",
|
|
"Epoch 48\tTop1 Train accuracy 71.6796875\tTop1 Test accuracy: 68.27392578125\tTop5 test acc: 98.22998046875\n",
|
|
"Epoch 49\tTop1 Train accuracy 71.8359375\tTop1 Test accuracy: 68.3349609375\tTop5 test acc: 98.22998046875\n",
|
|
"Epoch 50\tTop1 Train accuracy 71.93359375\tTop1 Test accuracy: 68.44482421875\tTop5 test acc: 98.2421875\n",
|
|
"Epoch 51\tTop1 Train accuracy 72.01171875\tTop1 Test accuracy: 68.4814453125\tTop5 test acc: 98.2177734375\n",
|
|
"Epoch 52\tTop1 Train accuracy 72.0703125\tTop1 Test accuracy: 68.505859375\tTop5 test acc: 98.2177734375\n",
|
|
"Epoch 53\tTop1 Train accuracy 72.2265625\tTop1 Test accuracy: 68.54248046875\tTop5 test acc: 98.22998046875\n",
|
|
"Epoch 54\tTop1 Train accuracy 72.24609375\tTop1 Test accuracy: 68.5791015625\tTop5 test acc: 98.22998046875\n",
|
|
"Epoch 55\tTop1 Train accuracy 72.34375\tTop1 Test accuracy: 68.65234375\tTop5 test acc: 98.25439453125\n",
|
|
"Epoch 56\tTop1 Train accuracy 72.421875\tTop1 Test accuracy: 68.71337890625\tTop5 test acc: 98.3154296875\n",
|
|
"Epoch 57\tTop1 Train accuracy 72.51953125\tTop1 Test accuracy: 68.71337890625\tTop5 test acc: 98.3154296875\n",
|
|
"Epoch 58\tTop1 Train accuracy 72.94921875\tTop1 Test accuracy: 68.76220703125\tTop5 test acc: 98.3154296875\n",
|
|
"Epoch 59\tTop1 Train accuracy 72.98828125\tTop1 Test accuracy: 68.83544921875\tTop5 test acc: 98.3154296875\n",
|
|
"Epoch 60\tTop1 Train accuracy 73.0859375\tTop1 Test accuracy: 68.88427734375\tTop5 test acc: 98.30322265625\n",
|
|
"Epoch 61\tTop1 Train accuracy 73.18359375\tTop1 Test accuracy: 68.896484375\tTop5 test acc: 98.32763671875\n",
|
|
"Epoch 62\tTop1 Train accuracy 73.3984375\tTop1 Test accuracy: 68.88427734375\tTop5 test acc: 98.33984375\n",
|
|
"Epoch 63\tTop1 Train accuracy 73.4375\tTop1 Test accuracy: 68.95751953125\tTop5 test acc: 98.33984375\n",
|
|
"Epoch 64\tTop1 Train accuracy 73.515625\tTop1 Test accuracy: 68.994140625\tTop5 test acc: 98.32763671875\n",
|
|
"Epoch 65\tTop1 Train accuracy 73.57421875\tTop1 Test accuracy: 68.9697265625\tTop5 test acc: 98.3154296875\n",
|
|
"Epoch 66\tTop1 Train accuracy 73.61328125\tTop1 Test accuracy: 69.03076171875\tTop5 test acc: 98.32763671875\n",
|
|
"Epoch 67\tTop1 Train accuracy 73.671875\tTop1 Test accuracy: 69.07958984375\tTop5 test acc: 98.3154296875\n",
|
|
"Epoch 68\tTop1 Train accuracy 73.7109375\tTop1 Test accuracy: 69.12841796875\tTop5 test acc: 98.3154296875\n",
|
|
"Epoch 69\tTop1 Train accuracy 73.8671875\tTop1 Test accuracy: 69.20166015625\tTop5 test acc: 98.3154296875\n",
|
|
"Epoch 70\tTop1 Train accuracy 73.984375\tTop1 Test accuracy: 69.25048828125\tTop5 test acc: 98.33984375\n",
|
|
"Epoch 71\tTop1 Train accuracy 74.00390625\tTop1 Test accuracy: 69.2626953125\tTop5 test acc: 98.35205078125\n",
|
|
"Epoch 72\tTop1 Train accuracy 74.00390625\tTop1 Test accuracy: 69.3115234375\tTop5 test acc: 98.33984375\n",
|
|
"Epoch 73\tTop1 Train accuracy 74.0234375\tTop1 Test accuracy: 69.34814453125\tTop5 test acc: 98.35205078125\n",
|
|
"Epoch 74\tTop1 Train accuracy 74.140625\tTop1 Test accuracy: 69.37255859375\tTop5 test acc: 98.33984375\n",
|
|
"Epoch 75\tTop1 Train accuracy 74.23828125\tTop1 Test accuracy: 69.4091796875\tTop5 test acc: 98.35205078125\n",
|
|
"Epoch 76\tTop1 Train accuracy 74.31640625\tTop1 Test accuracy: 69.4091796875\tTop5 test acc: 98.37646484375\n",
|
|
"Epoch 77\tTop1 Train accuracy 74.43359375\tTop1 Test accuracy: 69.4091796875\tTop5 test acc: 98.3642578125\n",
|
|
"Epoch 78\tTop1 Train accuracy 74.55078125\tTop1 Test accuracy: 69.3603515625\tTop5 test acc: 98.3642578125\n",
|
|
"Epoch 79\tTop1 Train accuracy 74.58984375\tTop1 Test accuracy: 69.37255859375\tTop5 test acc: 98.3642578125\n",
|
|
"Epoch 80\tTop1 Train accuracy 74.609375\tTop1 Test accuracy: 69.42138671875\tTop5 test acc: 98.3642578125\n",
|
|
"Epoch 81\tTop1 Train accuracy 74.6484375\tTop1 Test accuracy: 69.49462890625\tTop5 test acc: 98.3642578125\n",
|
|
"Epoch 82\tTop1 Train accuracy 74.6875\tTop1 Test accuracy: 69.47021484375\tTop5 test acc: 98.35205078125\n",
|
|
"Epoch 83\tTop1 Train accuracy 74.7265625\tTop1 Test accuracy: 69.5556640625\tTop5 test acc: 98.35205078125\n",
|
|
"Epoch 84\tTop1 Train accuracy 74.78515625\tTop1 Test accuracy: 69.59228515625\tTop5 test acc: 98.35205078125\n",
|
|
"Epoch 85\tTop1 Train accuracy 74.8828125\tTop1 Test accuracy: 69.6533203125\tTop5 test acc: 98.35205078125\n",
|
|
"Epoch 86\tTop1 Train accuracy 74.94140625\tTop1 Test accuracy: 69.677734375\tTop5 test acc: 98.3642578125\n",
|
|
"Epoch 87\tTop1 Train accuracy 75.0390625\tTop1 Test accuracy: 69.7509765625\tTop5 test acc: 98.35205078125\n",
|
|
"Epoch 88\tTop1 Train accuracy 75.0390625\tTop1 Test accuracy: 69.71435546875\tTop5 test acc: 98.35205078125\n",
|
|
"Epoch 89\tTop1 Train accuracy 75.1171875\tTop1 Test accuracy: 69.775390625\tTop5 test acc: 98.33984375\n",
|
|
"Epoch 90\tTop1 Train accuracy 75.21484375\tTop1 Test accuracy: 69.7509765625\tTop5 test acc: 98.33984375\n",
|
|
"Epoch 91\tTop1 Train accuracy 75.25390625\tTop1 Test accuracy: 69.82421875\tTop5 test acc: 98.32763671875\n",
|
|
"Epoch 92\tTop1 Train accuracy 75.29296875\tTop1 Test accuracy: 69.86083984375\tTop5 test acc: 98.33984375\n",
|
|
"Epoch 93\tTop1 Train accuracy 75.33203125\tTop1 Test accuracy: 69.88525390625\tTop5 test acc: 98.35205078125\n",
|
|
"Epoch 94\tTop1 Train accuracy 75.37109375\tTop1 Test accuracy: 69.81201171875\tTop5 test acc: 98.3642578125\n",
|
|
"Epoch 95\tTop1 Train accuracy 75.37109375\tTop1 Test accuracy: 69.83642578125\tTop5 test acc: 98.37646484375\n",
|
|
"Epoch 96\tTop1 Train accuracy 75.37109375\tTop1 Test accuracy: 69.83642578125\tTop5 test acc: 98.37646484375\n",
|
|
"Epoch 97\tTop1 Train accuracy 75.41015625\tTop1 Test accuracy: 69.86083984375\tTop5 test acc: 98.37646484375\n",
|
|
"Epoch 98\tTop1 Train accuracy 75.41015625\tTop1 Test accuracy: 69.90966796875\tTop5 test acc: 98.37646484375\n",
|
|
"Epoch 99\tTop1 Train accuracy 75.46875\tTop1 Test accuracy: 69.921875\tTop5 test acc: 98.37646484375\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "dtYqHZirMNZk"
|
|
},
|
|
"source": [
|
|
""
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
}
|
|
]
|
|
} |