Major refactor, small fixes

This commit is contained in:
Thalles Silva 2021-01-17 20:07:59 -03:00
parent 2c9536f731
commit d0112ed55b
4 changed files with 473 additions and 449 deletions

View File

@ -2,6 +2,7 @@ from torchvision.transforms import transforms
from data_aug.gaussian_blur import GaussianBlur from data_aug.gaussian_blur import GaussianBlur
from torchvision import transforms, datasets from torchvision import transforms, datasets
from data_aug.view_generator import ContrastiveLearningViewGenerator from data_aug.view_generator import ContrastiveLearningViewGenerator
from exceptions.exceptions import InvalidDatasetSelection
class ContrastiveLearningDataset: class ContrastiveLearningDataset:
@ -33,5 +34,9 @@ class ContrastiveLearningDataset:
n_views), n_views),
download=True)} download=True)}
dataset = valid_datasets.get(name, 'Invalid dataset option.')() try:
return dataset dataset_fn = valid_datasets[name]
except KeyError:
raise InvalidDatasetSelection()
else:
return dataset_fn()

View File

@ -4,3 +4,7 @@ class BaseSimCLRException(Exception):
class InvalidBackboneError(BaseSimCLRException): class InvalidBackboneError(BaseSimCLRException):
"""Raised when the choice of backbone Convnet is invalid.""" """Raised when the choice of backbone Convnet is invalid."""
class InvalidDatasetSelection(BaseSimCLRException):
    """Signals that the requested dataset name is not a supported option."""

View File

@ -1,37 +1,10 @@
{ {
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
},
"colab": {
"name": "mini-batch-logistic-regression-evaluator.ipynb",
"provenance": [],
"include_colab_link": true
},
"accelerator": "GPU"
},
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "view-in-github", "colab_type": "text",
"colab_type": "text" "id": "view-in-github"
}, },
"source": [ "source": [
"<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/master/feature_eval/mini_batch_logistic_regression_evaluator.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" "<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/master/feature_eval/mini_batch_logistic_regression_evaluator.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
@ -39,11 +12,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "YUemQib7ZE4D", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "YUemQib7ZE4D"
}, },
"outputs": [],
"source": [ "source": [
"import torch\n", "import torch\n",
"import sys\n", "import sys\n",
@ -56,30 +31,30 @@
"from sklearn.linear_model import LogisticRegression\n", "from sklearn.linear_model import LogisticRegression\n",
"from sklearn import preprocessing\n", "from sklearn import preprocessing\n",
"import importlib.util" "import importlib.util"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "WSgRE1CcLqdS", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "WSgRE1CcLqdS"
}, },
"outputs": [],
"source": [ "source": [
"!pip install gdown" "!pip install gdown"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "NOIJEui1ZziV", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "NOIJEui1ZziV"
}, },
"outputs": [],
"source": [ "source": [
"def get_file_id_by_model(folder_name):\n", "def get_file_id_by_model(folder_name):\n",
" file_id = {'resnet-18_40-epochs': '1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk',\n", " file_id = {'resnet-18_40-epochs': '1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk',\n",
@ -88,92 +63,92 @@
" 'resnet-50_80-epochs': '1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35',\n", " 'resnet-50_80-epochs': '1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35',\n",
" 'resnet-18_100-epochs':'1aZ12TITXnajZ6QWmS_SDm8Sp8gXNbeCQ'}\n", " 'resnet-18_100-epochs':'1aZ12TITXnajZ6QWmS_SDm8Sp8gXNbeCQ'}\n",
" return file_id.get(folder_name, \"Model not found.\")" " return file_id.get(folder_name, \"Model not found.\")"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "G7YMxsvEZMrX", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "G7YMxsvEZMrX"
}, },
"outputs": [],
"source": [ "source": [
"folder_name = 'resnet-50_40-epochs'\n", "folder_name = 'resnet-50_40-epochs'\n",
"file_id = get_file_id_by_model(folder_name)\n", "file_id = get_file_id_by_model(folder_name)\n",
"print(folder_name, file_id)" "print(folder_name, file_id)"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "PWZ8fet_YoJm", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "PWZ8fet_YoJm"
}, },
"outputs": [],
"source": [ "source": [
"# download and extract model files\n", "# download and extract model files\n",
"os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n", "os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n",
"os.system('unzip {}'.format(folder_name))\n", "os.system('unzip {}'.format(folder_name))\n",
"!ls" "!ls"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "3_nypQVEv-hn", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "3_nypQVEv-hn"
}, },
"outputs": [],
"source": [ "source": [
"from torch.utils.data import DataLoader\n", "from torch.utils.data import DataLoader\n",
"import torchvision.transforms as transforms\n", "import torchvision.transforms as transforms\n",
"from torchvision import datasets" "from torchvision import datasets"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "lDfbL3w_Z0Od", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "lDfbL3w_Z0Od"
}, },
"outputs": [],
"source": [ "source": [
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"print(\"Using device:\", device)" "print(\"Using device:\", device)"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "IQMIryc6LjQd", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "IQMIryc6LjQd"
}, },
"outputs": [],
"source": [ "source": [
"checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n", "checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
"config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n", "config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n",
"config" "config"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "BfIPl0G6_RrT", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "BfIPl0G6_RrT"
}, },
"outputs": [],
"source": [ "source": [
"def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n", "def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n",
" train_dataset = datasets.STL10('./data', split='train', download=download,\n", " train_dataset = datasets.STL10('./data', split='train', download=download,\n",
@ -188,17 +163,17 @@
" test_loader = DataLoader(test_dataset, batch_size=batch_size,\n", " test_loader = DataLoader(test_dataset, batch_size=batch_size,\n",
" num_workers=0, drop_last=False, shuffle=shuffle)\n", " num_workers=0, drop_last=False, shuffle=shuffle)\n",
" return train_loader, test_loader" " return train_loader, test_loader"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "a18lPD-tIle6", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "a18lPD-tIle6"
}, },
"outputs": [],
"source": [ "source": [
"def _load_resnet_model(checkpoints_folder):\n", "def _load_resnet_model(checkpoints_folder):\n",
" # Load the neural net module\n", " # Load the neural net module\n",
@ -213,15 +188,13 @@
" model.load_state_dict(state_dict)\n", " model.load_state_dict(state_dict)\n",
" model = model.to(device)\n", " model = model.to(device)\n",
" return model" " return model"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "5nf4rDtWLjRE", "colab_type": "text",
"colab_type": "text" "id": "5nf4rDtWLjRE"
}, },
"source": [ "source": [
"## Protocol #2 Logistic Regression" "## Protocol #2 Logistic Regression"
@ -229,11 +202,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "7jjSxmDnHNQz", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "7jjSxmDnHNQz"
}, },
"outputs": [],
"source": [ "source": [
"class ResNetFeatureExtractor(object):\n", "class ResNetFeatureExtractor(object):\n",
" def __init__(self, checkpoints_folder):\n", " def __init__(self, checkpoints_folder):\n",
@ -263,43 +238,43 @@
" X_test_feature, y_test = self._inference(test_loader)\n", " X_test_feature, y_test = self._inference(test_loader)\n",
"\n", "\n",
" return X_train_feature, y_train, X_test_feature, y_test" " return X_train_feature, y_train, X_test_feature, y_test"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "kghx1govJq5_", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "kghx1govJq5_"
}, },
"outputs": [],
"source": [ "source": [
"resnet_feature_extractor = ResNetFeatureExtractor(checkpoints_folder)" "resnet_feature_extractor = ResNetFeatureExtractor(checkpoints_folder)"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "S_JcznxVJ1Xj", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "S_JcznxVJ1Xj"
}, },
"outputs": [],
"source": [ "source": [
"X_train_feature, y_train, X_test_feature, y_test = resnet_feature_extractor.get_resnet_features()" "X_train_feature, y_train, X_test_feature, y_test = resnet_feature_extractor.get_resnet_features()"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "oftbHXcdLjRM", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "oftbHXcdLjRM"
}, },
"outputs": [],
"source": [ "source": [
"import torch.nn as nn\n", "import torch.nn as nn\n",
"\n", "\n",
@ -311,17 +286,17 @@
"\n", "\n",
" def forward(self, x):\n", " def forward(self, x):\n",
" return self.model(x)" " return self.model(x)"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "Ks73ePLtNWeV", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "Ks73ePLtNWeV"
}, },
"outputs": [],
"source": [ "source": [
"class LogiticRegressionEvaluator(object):\n", "class LogiticRegressionEvaluator(object):\n",
" def __init__(self, n_features, n_classes):\n", " def __init__(self, n_features, n_classes):\n",
@ -408,37 +383,60 @@
" print(\"--------------\")\n", " print(\"--------------\")\n",
" print(\"Done training\")\n", " print(\"Done training\")\n",
" print(\"Best accuracy:\", best_accuracy)" " print(\"Best accuracy:\", best_accuracy)"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0,
"metadata": { "metadata": {
"id": "NE716m7SOkaK", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "NE716m7SOkaK"
}, },
"outputs": [],
"source": [ "source": [
"log_regressor_evaluator = LogiticRegressionEvaluator(n_features=X_train_feature.shape[1], n_classes=10)\n", "log_regressor_evaluator = LogiticRegressionEvaluator(n_features=X_train_feature.shape[1], n_classes=10)\n",
"\n", "\n",
"log_regressor_evaluator.train(X_train_feature, y_train, X_test_feature, y_test)" "log_regressor_evaluator.train(X_train_feature, y_train, X_test_feature, y_test)"
], ]
"execution_count": 0,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"metadata": {
"id": "_GC0a14uWRr6",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0, "execution_count": 0,
"outputs": [] "metadata": {
"colab": {},
"colab_type": "code",
"id": "_GC0a14uWRr6"
},
"outputs": [],
"source": []
} }
] ],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"name": "mini-batch-logistic-regression-evaluator.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python [conda env:image]",
"language": "python",
"name": "conda-env-image-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.12"
}
},
"nbformat": 4,
"nbformat_minor": 1
} }

View File

@ -34,6 +34,23 @@ def _save_config_file(model_checkpoints_folder, args):
yaml.dump(args, outfile, default_flow_style=False) yaml.dump(args, outfile, default_flow_style=False)
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: 2-D tensor of per-class scores, one row per sample; the top
            entries are selected along dim 1.
        target: 1-D tensor of true class indices; its first dimension is
            used as the batch size.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``,
        each holding the corresponding top-k accuracy as a percentage.
    """
    with torch.no_grad():
        # Clamp to the number of score columns: ``Tensor.topk`` raises when
        # asked for more entries than exist along the dimension. For any
        # k >= number of columns every class is among the top k, so the
        # clamped result (all columns considered) is the correct answer.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        # pred has shape (maxk, batch): row i holds each sample's i-th best class.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # Slicing past the end is harmless, so k > maxk simply uses all rows.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
class SimCLR(object): class SimCLR(object):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -97,10 +114,10 @@ class SimCLR(object):
self.optimizer.step() self.optimizer.step()
if n_iter % self.args.log_every_n_steps == 0: if n_iter % self.args.log_every_n_steps == 0:
predictions = torch.argmax(logits, dim=1) top1, top5 = accuracy(logits, labels, topk=(1,5))
acc = 100 * (predictions == labels).float().mean()
self.writer.add_scalar('loss', loss, global_step=n_iter) self.writer.add_scalar('loss', loss, global_step=n_iter)
self.writer.add_scalar('acc/top1', acc, global_step=n_iter) self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
self.writer.add_scalar('learning_rate', self.scheduler.get_lr()[0], global_step=n_iter) self.writer.add_scalar('learning_rate', self.scheduler.get_lr()[0], global_step=n_iter)
n_iter += 1 n_iter += 1
@ -108,7 +125,7 @@ class SimCLR(object):
# warmup for the first 10 epochs # warmup for the first 10 epochs
if epoch_counter >= 10: if epoch_counter >= 10:
self.scheduler.step() self.scheduler.step()
logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {acc}") logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")
logging.info("Training has finished.") logging.info("Training has finished.")
# save model checkpoints # save model checkpoints