mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
Major refactor, small fixes
This commit is contained in:
parent
2c9536f731
commit
d0112ed55b
@ -2,6 +2,7 @@ from torchvision.transforms import transforms
|
|||||||
from data_aug.gaussian_blur import GaussianBlur
|
from data_aug.gaussian_blur import GaussianBlur
|
||||||
from torchvision import transforms, datasets
|
from torchvision import transforms, datasets
|
||||||
from data_aug.view_generator import ContrastiveLearningViewGenerator
|
from data_aug.view_generator import ContrastiveLearningViewGenerator
|
||||||
|
from exceptions.exceptions import InvalidDatasetSelection
|
||||||
|
|
||||||
|
|
||||||
class ContrastiveLearningDataset:
|
class ContrastiveLearningDataset:
|
||||||
@ -33,5 +34,9 @@ class ContrastiveLearningDataset:
|
|||||||
n_views),
|
n_views),
|
||||||
download=True)}
|
download=True)}
|
||||||
|
|
||||||
dataset = valid_datasets.get(name, 'Invalid dataset option.')()
|
try:
|
||||||
return dataset
|
dataset_fn = valid_datasets[name]
|
||||||
|
except KeyError:
|
||||||
|
raise InvalidDatasetSelection()
|
||||||
|
else:
|
||||||
|
return dataset_fn()
|
||||||
|
@ -4,3 +4,7 @@ class BaseSimCLRException(Exception):
|
|||||||
|
|
||||||
class InvalidBackboneError(BaseSimCLRException):
|
class InvalidBackboneError(BaseSimCLRException):
|
||||||
"""Raised when the choice of backbone Convnet is invalid."""
|
"""Raised when the choice of backbone Convnet is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidDatasetSelection(BaseSimCLRException):
|
||||||
|
"""Raised when the choice of dataset is invalid."""
|
||||||
|
@ -1,37 +1,10 @@
|
|||||||
{
|
{
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 0,
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "pytorch",
|
|
||||||
"language": "python",
|
|
||||||
"name": "pytorch"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.6.6"
|
|
||||||
},
|
|
||||||
"colab": {
|
|
||||||
"name": "mini-batch-logistic-regression-evaluator.ipynb",
|
|
||||||
"provenance": [],
|
|
||||||
"include_colab_link": true
|
|
||||||
},
|
|
||||||
"accelerator": "GPU"
|
|
||||||
},
|
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "view-in-github",
|
"colab_type": "text",
|
||||||
"colab_type": "text"
|
"id": "view-in-github"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/master/feature_eval/mini_batch_logistic_regression_evaluator.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
"<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/master/feature_eval/mini_batch_logistic_regression_evaluator.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||||
@ -39,11 +12,13 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "YUemQib7ZE4D",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "YUemQib7ZE4D"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"import sys\n",
|
"import sys\n",
|
||||||
@ -56,30 +31,30 @@
|
|||||||
"from sklearn.linear_model import LogisticRegression\n",
|
"from sklearn.linear_model import LogisticRegression\n",
|
||||||
"from sklearn import preprocessing\n",
|
"from sklearn import preprocessing\n",
|
||||||
"import importlib.util"
|
"import importlib.util"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "WSgRE1CcLqdS",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "WSgRE1CcLqdS"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"!pip install gdown"
|
"!pip install gdown"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "NOIJEui1ZziV",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "NOIJEui1ZziV"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def get_file_id_by_model(folder_name):\n",
|
"def get_file_id_by_model(folder_name):\n",
|
||||||
" file_id = {'resnet-18_40-epochs': '1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk',\n",
|
" file_id = {'resnet-18_40-epochs': '1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk',\n",
|
||||||
@ -88,92 +63,92 @@
|
|||||||
" 'resnet-50_80-epochs': '1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35',\n",
|
" 'resnet-50_80-epochs': '1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35',\n",
|
||||||
" 'resnet-18_100-epochs':'1aZ12TITXnajZ6QWmS_SDm8Sp8gXNbeCQ'}\n",
|
" 'resnet-18_100-epochs':'1aZ12TITXnajZ6QWmS_SDm8Sp8gXNbeCQ'}\n",
|
||||||
" return file_id.get(folder_name, \"Model not found.\")"
|
" return file_id.get(folder_name, \"Model not found.\")"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "G7YMxsvEZMrX",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "G7YMxsvEZMrX"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"folder_name = 'resnet-50_40-epochs'\n",
|
"folder_name = 'resnet-50_40-epochs'\n",
|
||||||
"file_id = get_file_id_by_model(folder_name)\n",
|
"file_id = get_file_id_by_model(folder_name)\n",
|
||||||
"print(folder_name, file_id)"
|
"print(folder_name, file_id)"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "PWZ8fet_YoJm",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "PWZ8fet_YoJm"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# download and extract model files\n",
|
"# download and extract model files\n",
|
||||||
"os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n",
|
"os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n",
|
||||||
"os.system('unzip {}'.format(folder_name))\n",
|
"os.system('unzip {}'.format(folder_name))\n",
|
||||||
"!ls"
|
"!ls"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "3_nypQVEv-hn",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "3_nypQVEv-hn"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from torch.utils.data import DataLoader\n",
|
"from torch.utils.data import DataLoader\n",
|
||||||
"import torchvision.transforms as transforms\n",
|
"import torchvision.transforms as transforms\n",
|
||||||
"from torchvision import datasets"
|
"from torchvision import datasets"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "lDfbL3w_Z0Od",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "lDfbL3w_Z0Od"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
||||||
"print(\"Using device:\", device)"
|
"print(\"Using device:\", device)"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "IQMIryc6LjQd",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "IQMIryc6LjQd"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
|
"checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
|
||||||
"config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n",
|
"config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n",
|
||||||
"config"
|
"config"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "BfIPl0G6_RrT",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "BfIPl0G6_RrT"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n",
|
"def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n",
|
||||||
" train_dataset = datasets.STL10('./data', split='train', download=download,\n",
|
" train_dataset = datasets.STL10('./data', split='train', download=download,\n",
|
||||||
@ -188,17 +163,17 @@
|
|||||||
" test_loader = DataLoader(test_dataset, batch_size=batch_size,\n",
|
" test_loader = DataLoader(test_dataset, batch_size=batch_size,\n",
|
||||||
" num_workers=0, drop_last=False, shuffle=shuffle)\n",
|
" num_workers=0, drop_last=False, shuffle=shuffle)\n",
|
||||||
" return train_loader, test_loader"
|
" return train_loader, test_loader"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "a18lPD-tIle6",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "a18lPD-tIle6"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def _load_resnet_model(checkpoints_folder):\n",
|
"def _load_resnet_model(checkpoints_folder):\n",
|
||||||
" # Load the neural net module\n",
|
" # Load the neural net module\n",
|
||||||
@ -213,15 +188,13 @@
|
|||||||
" model.load_state_dict(state_dict)\n",
|
" model.load_state_dict(state_dict)\n",
|
||||||
" model = model.to(device)\n",
|
" model = model.to(device)\n",
|
||||||
" return model"
|
" return model"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "5nf4rDtWLjRE",
|
"colab_type": "text",
|
||||||
"colab_type": "text"
|
"id": "5nf4rDtWLjRE"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"## Protocol #2 Logisitc Regression"
|
"## Protocol #2 Logisitc Regression"
|
||||||
@ -229,11 +202,13 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "7jjSxmDnHNQz",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "7jjSxmDnHNQz"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"class ResNetFeatureExtractor(object):\n",
|
"class ResNetFeatureExtractor(object):\n",
|
||||||
" def __init__(self, checkpoints_folder):\n",
|
" def __init__(self, checkpoints_folder):\n",
|
||||||
@ -263,43 +238,43 @@
|
|||||||
" X_test_feature, y_test = self._inference(test_loader)\n",
|
" X_test_feature, y_test = self._inference(test_loader)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return X_train_feature, y_train, X_test_feature, y_test"
|
" return X_train_feature, y_train, X_test_feature, y_test"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "kghx1govJq5_",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "kghx1govJq5_"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"resnet_feature_extractor = ResNetFeatureExtractor(checkpoints_folder)"
|
"resnet_feature_extractor = ResNetFeatureExtractor(checkpoints_folder)"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "S_JcznxVJ1Xj",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "S_JcznxVJ1Xj"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"X_train_feature, y_train, X_test_feature, y_test = resnet_feature_extractor.get_resnet_features()"
|
"X_train_feature, y_train, X_test_feature, y_test = resnet_feature_extractor.get_resnet_features()"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "oftbHXcdLjRM",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "oftbHXcdLjRM"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import torch.nn as nn\n",
|
"import torch.nn as nn\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -311,17 +286,17 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" def forward(self, x):\n",
|
" def forward(self, x):\n",
|
||||||
" return self.model(x)"
|
" return self.model(x)"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "Ks73ePLtNWeV",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "Ks73ePLtNWeV"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"class LogiticRegressionEvaluator(object):\n",
|
"class LogiticRegressionEvaluator(object):\n",
|
||||||
" def __init__(self, n_features, n_classes):\n",
|
" def __init__(self, n_features, n_classes):\n",
|
||||||
@ -408,37 +383,60 @@
|
|||||||
" print(\"--------------\")\n",
|
" print(\"--------------\")\n",
|
||||||
" print(\"Done training\")\n",
|
" print(\"Done training\")\n",
|
||||||
" print(\"Best accuracy:\", best_accuracy)"
|
" print(\"Best accuracy:\", best_accuracy)"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": 0,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "NE716m7SOkaK",
|
"colab": {},
|
||||||
"colab_type": "code",
|
"colab_type": "code",
|
||||||
"colab": {}
|
"id": "NE716m7SOkaK"
|
||||||
},
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"log_regressor_evaluator = LogiticRegressionEvaluator(n_features=X_train_feature.shape[1], n_classes=10)\n",
|
"log_regressor_evaluator = LogiticRegressionEvaluator(n_features=X_train_feature.shape[1], n_classes=10)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"log_regressor_evaluator.train(X_train_feature, y_train, X_test_feature, y_test)"
|
"log_regressor_evaluator.train(X_train_feature, y_train, X_test_feature, y_test)"
|
||||||
],
|
]
|
||||||
"execution_count": 0,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"metadata": {
|
|
||||||
"id": "_GC0a14uWRr6",
|
|
||||||
"colab_type": "code",
|
|
||||||
"colab": {}
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
""
|
|
||||||
],
|
|
||||||
"execution_count": 0,
|
"execution_count": 0,
|
||||||
"outputs": []
|
"metadata": {
|
||||||
|
"colab": {},
|
||||||
|
"colab_type": "code",
|
||||||
|
"id": "_GC0a14uWRr6"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"accelerator": "GPU",
|
||||||
|
"colab": {
|
||||||
|
"include_colab_link": true,
|
||||||
|
"name": "mini-batch-logistic-regression-evaluator.ipynb",
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python [conda env:image]",
|
||||||
|
"language": "python",
|
||||||
|
"name": "conda-env-image-py"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.6.12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 1
|
||||||
}
|
}
|
25
simclr.py
25
simclr.py
@ -34,6 +34,23 @@ def _save_config_file(model_checkpoints_folder, args):
|
|||||||
yaml.dump(args, outfile, default_flow_style=False)
|
yaml.dump(args, outfile, default_flow_style=False)
|
||||||
|
|
||||||
|
|
||||||
|
def accuracy(output, target, topk=(1,)):
|
||||||
|
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
||||||
|
with torch.no_grad():
|
||||||
|
maxk = max(topk)
|
||||||
|
batch_size = target.size(0)
|
||||||
|
|
||||||
|
_, pred = output.topk(maxk, 1, True, True)
|
||||||
|
pred = pred.t()
|
||||||
|
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for k in topk:
|
||||||
|
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
||||||
|
res.append(correct_k.mul_(100.0 / batch_size))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class SimCLR(object):
|
class SimCLR(object):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@ -97,10 +114,10 @@ class SimCLR(object):
|
|||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
if n_iter % self.args.log_every_n_steps == 0:
|
if n_iter % self.args.log_every_n_steps == 0:
|
||||||
predictions = torch.argmax(logits, dim=1)
|
top1, top5 = accuracy(logits, labels, topk=(1,5))
|
||||||
acc = 100 * (predictions == labels).float().mean()
|
|
||||||
self.writer.add_scalar('loss', loss, global_step=n_iter)
|
self.writer.add_scalar('loss', loss, global_step=n_iter)
|
||||||
self.writer.add_scalar('acc/top1', acc, global_step=n_iter)
|
self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
|
||||||
|
self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
|
||||||
self.writer.add_scalar('learning_rate', self.scheduler.get_lr()[0], global_step=n_iter)
|
self.writer.add_scalar('learning_rate', self.scheduler.get_lr()[0], global_step=n_iter)
|
||||||
|
|
||||||
n_iter += 1
|
n_iter += 1
|
||||||
@ -108,7 +125,7 @@ class SimCLR(object):
|
|||||||
# warmup for the first 10 epochs
|
# warmup for the first 10 epochs
|
||||||
if epoch_counter >= 10:
|
if epoch_counter >= 10:
|
||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {acc}")
|
logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")
|
||||||
|
|
||||||
logging.info("Training has finished.")
|
logging.info("Training has finished.")
|
||||||
# save model checkpoints
|
# save model checkpoints
|
||||||
|
Loading…
x
Reference in New Issue
Block a user