Created using Colaboratory

Thalles Silva 2021-01-17 20:28:37 -03:00
parent d1d59400fe
commit 63e46b0d0e


@@ -24,10 +24,7 @@
 "provenance": [],
 "include_colab_link": true
 },
-"accelerator": "GPU",
-"widgets": {
-"application/vnd.jupyter.widget-state+json": {}
-}
+"accelerator": "GPU"
 },
 "cells": [
 {
@@ -228,7 +225,7 @@
 "id": "BfIPl0G6_RrT"
 },
 "source": [
-"def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n",
+"def get_stl10_data_loaders(download, shuffle=False, batch_size=256):\n",
 " train_dataset = datasets.STL10('./data', split='train', download=download,\n",
 " transform=transforms.ToTensor())\n",
 "\n",
@@ -242,7 +239,7 @@
 " num_workers=10, drop_last=False, shuffle=shuffle)\n",
 " return train_loader, test_loader\n",
 "\n",
-"def get_cifar10_data_loaders(download, shuffle=False, batch_size=128):\n",
+"def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):\n",
 " train_dataset = datasets.CIFAR10('./data', train=True, download=download,\n",
 " transform=transforms.ToTensor())\n",
 "\n",
@@ -282,7 +279,7 @@
 "elif config.arch == 'resnet50':\n",
 " model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)"
 ],
-"execution_count": 11,
+"execution_count": null,
 "outputs": []
 },
 {
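
The cell whose execution count is reset here builds the evaluation encoder from scratch. A sketch of how that selection plausibly looks in full, assuming a `resnet18` branch mirrors the visible `resnet50` one and that `config.arch` comes from the checkpoint's configuration:

```python
import torch
import torchvision
from types import SimpleNamespace

device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = SimpleNamespace(arch='resnet50')  # stand-in; the notebook reads this from its config

if config.arch == 'resnet18':
    model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
elif config.arch == 'resnet50':
    model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)
```
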
@@ -302,7 +299,7 @@
 " state_dict[k[len(\"backbone.\"):]] = state_dict[k]\n",
 " del state_dict[k]"
 ],
-"execution_count": 12,
+"execution_count": null,
 "outputs": []
 },
 {
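
The hunk above is the tail of the key-renaming loop that maps SimCLR's `backbone.*` parameter names onto a plain torchvision ResNet. A hedged reconstruction of the whole cell; the checkpoint filename, the `'state_dict'` key, and the guard that skips the projection-head keys are assumptions:

```python
import torch

# Filename and map_location are illustrative, not taken from this diff.
checkpoint = torch.load('checkpoint_0100.pth.tar', map_location='cpu')
state_dict = checkpoint['state_dict']

for k in list(state_dict.keys()):
    if k.startswith('backbone.') and not k.startswith('backbone.fc'):
        # "backbone.conv1.weight" -> "conv1.weight", matching torchvision's ResNet keys.
        state_dict[k[len("backbone."):]] = state_dict[k]
    # Drop every original key; only the renamed encoder weights remain.
    del state_dict[k]
```
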
@@ -314,7 +311,7 @@
 "log = model.load_state_dict(state_dict, strict=False)\n",
 "assert log.missing_keys == ['fc.weight', 'fc.bias']"
 ],
-"execution_count": 13,
+"execution_count": null,
 "outputs": []
 },
 {
@@ -337,7 +334,7 @@
 " train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
 "print(\"Dataset:\", config.dataset_name)"
 ],
-"execution_count": 14,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "stream",
@@ -387,7 +384,7 @@
 "parameters = list(filter(lambda p: p.requires_grad, model.parameters()))\n",
 "assert len(parameters) == 2 # fc.weight, fc.bias"
 ],
-"execution_count": 15,
+"execution_count": null,
 "outputs": []
 },
 {
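
The assert in this hunk only holds if every parameter except the new linear head has already been frozen. A minimal sketch of that freezing step, assuming it is done by clearing requires_grad on everything but fc (the loop itself is not shown in this diff):

```python
# Freeze the pretrained encoder; only the freshly initialised fc head stays trainable.
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias
```
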
@@ -399,7 +396,7 @@
 "optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)\n",
 "criterion = torch.nn.CrossEntropyLoss().to(device)"
 ],
-"execution_count": 16,
+"execution_count": null,
 "outputs": []
 },
 {
@@ -424,7 +421,7 @@
 " res.append(correct_k.mul_(100.0 / batch_size))\n",
 " return res"
 ],
-"execution_count": 17,
+"execution_count": null,
 "outputs": []
 },
 {
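
Only the closing lines of the accuracy helper are visible above. A self-contained sketch of a standard top-k accuracy function consistent with those lines; the notebook's exact body may differ slightly:

```python
import torch

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions, as a percentage."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)  # indices of the top-k classes
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
```
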
@@ -470,7 +467,7 @@
 " top5_accuracy /= (counter + 1)\n",
 " print(f\"Epoch {epoch}\\tTop1 Train accuracy {top1_train_accuracy.item()}\\tTop1 Test accuracy: {top1_accuracy.item()}\\tTop5 test acc: {top5_accuracy.item()}\")"
 ],
-"execution_count": 18,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "stream",
@@ -588,7 +585,7 @@
 "source": [
 ""
 ],
-"execution_count": 18,
+"execution_count": null,
 "outputs": []
 }
 ]