Created using Colaboratory

This commit is contained in:
Thalles Silva 2021-01-17 20:28:37 -03:00
parent d1d59400fe
commit 63e46b0d0e

View File

@ -24,10 +24,7 @@
"provenance": [],
"include_colab_link": true
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {}
}
"accelerator": "GPU"
},
"cells": [
{
@ -228,7 +225,7 @@
"id": "BfIPl0G6_RrT"
},
"source": [
"def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n",
"def get_stl10_data_loaders(download, shuffle=False, batch_size=256):\n",
" train_dataset = datasets.STL10('./data', split='train', download=download,\n",
" transform=transforms.ToTensor())\n",
"\n",
@ -242,7 +239,7 @@
" num_workers=10, drop_last=False, shuffle=shuffle)\n",
" return train_loader, test_loader\n",
"\n",
"def get_cifar10_data_loaders(download, shuffle=False, batch_size=128):\n",
"def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):\n",
" train_dataset = datasets.CIFAR10('./data', train=True, download=download,\n",
" transform=transforms.ToTensor())\n",
"\n",
@ -282,7 +279,7 @@
"elif config.arch == 'resnet50':\n",
" model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)"
],
"execution_count": 11,
"execution_count": null,
"outputs": []
},
{
@ -302,7 +299,7 @@
" state_dict[k[len(\"backbone.\"):]] = state_dict[k]\n",
" del state_dict[k]"
],
"execution_count": 12,
"execution_count": null,
"outputs": []
},
{
@ -314,7 +311,7 @@
"log = model.load_state_dict(state_dict, strict=False)\n",
"assert log.missing_keys == ['fc.weight', 'fc.bias']"
],
"execution_count": 13,
"execution_count": null,
"outputs": []
},
{
@ -337,7 +334,7 @@
" train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
"print(\"Dataset:\", config.dataset_name)"
],
"execution_count": 14,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
@ -387,7 +384,7 @@
"parameters = list(filter(lambda p: p.requires_grad, model.parameters()))\n",
"assert len(parameters) == 2 # fc.weight, fc.bias"
],
"execution_count": 15,
"execution_count": null,
"outputs": []
},
{
@ -399,7 +396,7 @@
"optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)\n",
"criterion = torch.nn.CrossEntropyLoss().to(device)"
],
"execution_count": 16,
"execution_count": null,
"outputs": []
},
{
@ -424,7 +421,7 @@
" res.append(correct_k.mul_(100.0 / batch_size))\n",
" return res"
],
"execution_count": 17,
"execution_count": null,
"outputs": []
},
{
@ -470,7 +467,7 @@
" top5_accuracy /= (counter + 1)\n",
" print(f\"Epoch {epoch}\\tTop1 Train accuracy {top1_train_accuracy.item()}\\tTop1 Test accuracy: {top1_accuracy.item()}\\tTop5 test acc: {top5_accuracy.item()}\")"
],
"execution_count": 18,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
@ -588,7 +585,7 @@
"source": [
""
],
"execution_count": 18,
"execution_count": null,
"outputs": []
}
]