mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
Created using Colaboratory
This commit is contained in:
parent
d1d59400fe
commit
63e46b0d0e
@ -24,10 +24,7 @@
|
||||
"provenance": [],
|
||||
"include_colab_link": true
|
||||
},
|
||||
"accelerator": "GPU",
|
||||
"widgets": {
|
||||
"application/vnd.jupyter.widget-state+json": {}
|
||||
}
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
@ -228,7 +225,7 @@
|
||||
"id": "BfIPl0G6_RrT"
|
||||
},
|
||||
"source": [
|
||||
"def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n",
|
||||
"def get_stl10_data_loaders(download, shuffle=False, batch_size=256):\n",
|
||||
" train_dataset = datasets.STL10('./data', split='train', download=download,\n",
|
||||
" transform=transforms.ToTensor())\n",
|
||||
"\n",
|
||||
@ -242,7 +239,7 @@
|
||||
" num_workers=10, drop_last=False, shuffle=shuffle)\n",
|
||||
" return train_loader, test_loader\n",
|
||||
"\n",
|
||||
"def get_cifar10_data_loaders(download, shuffle=False, batch_size=128):\n",
|
||||
"def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):\n",
|
||||
" train_dataset = datasets.CIFAR10('./data', train=True, download=download,\n",
|
||||
" transform=transforms.ToTensor())\n",
|
||||
"\n",
|
||||
@ -282,7 +279,7 @@
|
||||
"elif config.arch == 'resnet50':\n",
|
||||
" model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)"
|
||||
],
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
@ -302,7 +299,7 @@
|
||||
" state_dict[k[len(\"backbone.\"):]] = state_dict[k]\n",
|
||||
" del state_dict[k]"
|
||||
],
|
||||
"execution_count": 12,
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
@ -314,7 +311,7 @@
|
||||
"log = model.load_state_dict(state_dict, strict=False)\n",
|
||||
"assert log.missing_keys == ['fc.weight', 'fc.bias']"
|
||||
],
|
||||
"execution_count": 13,
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
@ -337,7 +334,7 @@
|
||||
" train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
|
||||
"print(\"Dataset:\", config.dataset_name)"
|
||||
],
|
||||
"execution_count": 14,
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -387,7 +384,7 @@
|
||||
"parameters = list(filter(lambda p: p.requires_grad, model.parameters()))\n",
|
||||
"assert len(parameters) == 2 # fc.weight, fc.bias"
|
||||
],
|
||||
"execution_count": 15,
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
@ -399,7 +396,7 @@
|
||||
"optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)\n",
|
||||
"criterion = torch.nn.CrossEntropyLoss().to(device)"
|
||||
],
|
||||
"execution_count": 16,
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
@ -424,7 +421,7 @@
|
||||
" res.append(correct_k.mul_(100.0 / batch_size))\n",
|
||||
" return res"
|
||||
],
|
||||
"execution_count": 17,
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
@ -470,7 +467,7 @@
|
||||
" top5_accuracy /= (counter + 1)\n",
|
||||
" print(f\"Epoch {epoch}\\tTop1 Train accuracy {top1_train_accuracy.item()}\\tTop1 Test accuracy: {top1_accuracy.item()}\\tTop5 test acc: {top5_accuracy.item()}\")"
|
||||
],
|
||||
"execution_count": 18,
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -588,7 +585,7 @@
|
||||
"source": [
|
||||
""
|
||||
],
|
||||
"execution_count": 18,
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user