mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
Created using Colaboratory
This commit is contained in:
parent
d1d59400fe
commit
63e46b0d0e
@ -24,10 +24,7 @@
|
|||||||
"provenance": [],
|
"provenance": [],
|
||||||
"include_colab_link": true
|
"include_colab_link": true
|
||||||
},
|
},
|
||||||
"accelerator": "GPU",
|
"accelerator": "GPU"
|
||||||
"widgets": {
|
|
||||||
"application/vnd.jupyter.widget-state+json": {}
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
@ -228,7 +225,7 @@
|
|||||||
"id": "BfIPl0G6_RrT"
|
"id": "BfIPl0G6_RrT"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n",
|
"def get_stl10_data_loaders(download, shuffle=False, batch_size=256):\n",
|
||||||
" train_dataset = datasets.STL10('./data', split='train', download=download,\n",
|
" train_dataset = datasets.STL10('./data', split='train', download=download,\n",
|
||||||
" transform=transforms.ToTensor())\n",
|
" transform=transforms.ToTensor())\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -242,7 +239,7 @@
|
|||||||
" num_workers=10, drop_last=False, shuffle=shuffle)\n",
|
" num_workers=10, drop_last=False, shuffle=shuffle)\n",
|
||||||
" return train_loader, test_loader\n",
|
" return train_loader, test_loader\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def get_cifar10_data_loaders(download, shuffle=False, batch_size=128):\n",
|
"def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):\n",
|
||||||
" train_dataset = datasets.CIFAR10('./data', train=True, download=download,\n",
|
" train_dataset = datasets.CIFAR10('./data', train=True, download=download,\n",
|
||||||
" transform=transforms.ToTensor())\n",
|
" transform=transforms.ToTensor())\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -282,7 +279,7 @@
|
|||||||
"elif config.arch == 'resnet50':\n",
|
"elif config.arch == 'resnet50':\n",
|
||||||
" model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)"
|
" model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)"
|
||||||
],
|
],
|
||||||
"execution_count": 11,
|
"execution_count": null,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -302,7 +299,7 @@
|
|||||||
" state_dict[k[len(\"backbone.\"):]] = state_dict[k]\n",
|
" state_dict[k[len(\"backbone.\"):]] = state_dict[k]\n",
|
||||||
" del state_dict[k]"
|
" del state_dict[k]"
|
||||||
],
|
],
|
||||||
"execution_count": 12,
|
"execution_count": null,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -314,7 +311,7 @@
|
|||||||
"log = model.load_state_dict(state_dict, strict=False)\n",
|
"log = model.load_state_dict(state_dict, strict=False)\n",
|
||||||
"assert log.missing_keys == ['fc.weight', 'fc.bias']"
|
"assert log.missing_keys == ['fc.weight', 'fc.bias']"
|
||||||
],
|
],
|
||||||
"execution_count": 13,
|
"execution_count": null,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -337,7 +334,7 @@
|
|||||||
" train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
|
" train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
|
||||||
"print(\"Dataset:\", config.dataset_name)"
|
"print(\"Dataset:\", config.dataset_name)"
|
||||||
],
|
],
|
||||||
"execution_count": 14,
|
"execution_count": null,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
@ -387,7 +384,7 @@
|
|||||||
"parameters = list(filter(lambda p: p.requires_grad, model.parameters()))\n",
|
"parameters = list(filter(lambda p: p.requires_grad, model.parameters()))\n",
|
||||||
"assert len(parameters) == 2 # fc.weight, fc.bias"
|
"assert len(parameters) == 2 # fc.weight, fc.bias"
|
||||||
],
|
],
|
||||||
"execution_count": 15,
|
"execution_count": null,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -399,7 +396,7 @@
|
|||||||
"optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)\n",
|
"optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)\n",
|
||||||
"criterion = torch.nn.CrossEntropyLoss().to(device)"
|
"criterion = torch.nn.CrossEntropyLoss().to(device)"
|
||||||
],
|
],
|
||||||
"execution_count": 16,
|
"execution_count": null,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -424,7 +421,7 @@
|
|||||||
" res.append(correct_k.mul_(100.0 / batch_size))\n",
|
" res.append(correct_k.mul_(100.0 / batch_size))\n",
|
||||||
" return res"
|
" return res"
|
||||||
],
|
],
|
||||||
"execution_count": 17,
|
"execution_count": null,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -470,7 +467,7 @@
|
|||||||
" top5_accuracy /= (counter + 1)\n",
|
" top5_accuracy /= (counter + 1)\n",
|
||||||
" print(f\"Epoch {epoch}\\tTop1 Train accuracy {top1_train_accuracy.item()}\\tTop1 Test accuracy: {top1_accuracy.item()}\\tTop5 test acc: {top5_accuracy.item()}\")"
|
" print(f\"Epoch {epoch}\\tTop1 Train accuracy {top1_train_accuracy.item()}\\tTop1 Test accuracy: {top1_accuracy.item()}\\tTop5 test acc: {top5_accuracy.item()}\")"
|
||||||
],
|
],
|
||||||
"execution_count": 18,
|
"execution_count": null,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
@ -588,7 +585,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
""
|
""
|
||||||
],
|
],
|
||||||
"execution_count": 18,
|
"execution_count": null,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user