mirror of
https://github.com/sthalles/SimCLR.git
synced 2025-06-03 15:03:00 +08:00
Major refactor, small fixes
This commit is contained in:
parent
2c9536f731
commit
d0112ed55b
@ -2,6 +2,7 @@ from torchvision.transforms import transforms
|
||||
from data_aug.gaussian_blur import GaussianBlur
|
||||
from torchvision import transforms, datasets
|
||||
from data_aug.view_generator import ContrastiveLearningViewGenerator
|
||||
from exceptions.exceptions import InvalidDatasetSelection
|
||||
|
||||
|
||||
class ContrastiveLearningDataset:
|
||||
@ -33,5 +34,9 @@ class ContrastiveLearningDataset:
|
||||
n_views),
|
||||
download=True)}
|
||||
|
||||
dataset = valid_datasets.get(name, 'Invalid dataset option.')()
|
||||
return dataset
|
||||
try:
|
||||
dataset_fn = valid_datasets[name]
|
||||
except KeyError:
|
||||
raise InvalidDatasetSelection()
|
||||
else:
|
||||
return dataset_fn()
|
||||
|
@ -4,3 +4,7 @@ class BaseSimCLRException(Exception):
|
||||
|
||||
class InvalidBackboneError(BaseSimCLRException):
|
||||
"""Raised when the choice of backbone Convnet is invalid."""
|
||||
|
||||
|
||||
class InvalidDatasetSelection(BaseSimCLRException):
|
||||
"""Raised when the choice of dataset is invalid."""
|
||||
|
@ -1,37 +1,10 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "pytorch",
|
||||
"language": "python",
|
||||
"name": "pytorch"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.6"
|
||||
},
|
||||
"colab": {
|
||||
"name": "mini-batch-logistic-regression-evaluator.ipynb",
|
||||
"provenance": [],
|
||||
"include_colab_link": true
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github",
|
||||
"colab_type": "text"
|
||||
"colab_type": "text",
|
||||
"id": "view-in-github"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/master/feature_eval/mini_batch_logistic_regression_evaluator.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
@ -39,11 +12,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "YUemQib7ZE4D",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "YUemQib7ZE4D"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import sys\n",
|
||||
@ -56,30 +31,30 @@
|
||||
"from sklearn.linear_model import LogisticRegression\n",
|
||||
"from sklearn import preprocessing\n",
|
||||
"import importlib.util"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "WSgRE1CcLqdS",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "WSgRE1CcLqdS"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install gdown"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "NOIJEui1ZziV",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "NOIJEui1ZziV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_file_id_by_model(folder_name):\n",
|
||||
" file_id = {'resnet-18_40-epochs': '1c4eVon0sUd-ChVhH6XMpF6nCngNJsAPk',\n",
|
||||
@ -88,92 +63,92 @@
|
||||
" 'resnet-50_80-epochs': '1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35',\n",
|
||||
" 'resnet-18_100-epochs':'1aZ12TITXnajZ6QWmS_SDm8Sp8gXNbeCQ'}\n",
|
||||
" return file_id.get(folder_name, \"Model not found.\")"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "G7YMxsvEZMrX",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "G7YMxsvEZMrX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"folder_name = 'resnet-50_40-epochs'\n",
|
||||
"file_id = get_file_id_by_model(folder_name)\n",
|
||||
"print(folder_name, file_id)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "PWZ8fet_YoJm",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "PWZ8fet_YoJm"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# download and extract model files\n",
|
||||
"os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n",
|
||||
"os.system('unzip {}'.format(folder_name))\n",
|
||||
"!ls"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "3_nypQVEv-hn",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "3_nypQVEv-hn"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import torchvision.transforms as transforms\n",
|
||||
"from torchvision import datasets"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "lDfbL3w_Z0Od",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "lDfbL3w_Z0Od"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
||||
"print(\"Using device:\", device)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "IQMIryc6LjQd",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "IQMIryc6LjQd"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
|
||||
"config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n",
|
||||
"config"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "BfIPl0G6_RrT",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "BfIPl0G6_RrT"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_stl10_data_loaders(download, shuffle=False, batch_size=128):\n",
|
||||
" train_dataset = datasets.STL10('./data', split='train', download=download,\n",
|
||||
@ -188,17 +163,17 @@
|
||||
" test_loader = DataLoader(test_dataset, batch_size=batch_size,\n",
|
||||
" num_workers=0, drop_last=False, shuffle=shuffle)\n",
|
||||
" return train_loader, test_loader"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "a18lPD-tIle6",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "a18lPD-tIle6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _load_resnet_model(checkpoints_folder):\n",
|
||||
" # Load the neural net module\n",
|
||||
@ -213,15 +188,13 @@
|
||||
" model.load_state_dict(state_dict)\n",
|
||||
" model = model.to(device)\n",
|
||||
" return model"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "5nf4rDtWLjRE",
|
||||
"colab_type": "text"
|
||||
"colab_type": "text",
|
||||
"id": "5nf4rDtWLjRE"
|
||||
},
|
||||
"source": [
|
||||
"## Protocol #2 Logisitc Regression"
|
||||
@ -229,11 +202,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "7jjSxmDnHNQz",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "7jjSxmDnHNQz"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ResNetFeatureExtractor(object):\n",
|
||||
" def __init__(self, checkpoints_folder):\n",
|
||||
@ -263,43 +238,43 @@
|
||||
" X_test_feature, y_test = self._inference(test_loader)\n",
|
||||
"\n",
|
||||
" return X_train_feature, y_train, X_test_feature, y_test"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "kghx1govJq5_",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "kghx1govJq5_"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"resnet_feature_extractor = ResNetFeatureExtractor(checkpoints_folder)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "S_JcznxVJ1Xj",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "S_JcznxVJ1Xj"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_train_feature, y_train, X_test_feature, y_test = resnet_feature_extractor.get_resnet_features()"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "oftbHXcdLjRM",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "oftbHXcdLjRM"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch.nn as nn\n",
|
||||
"\n",
|
||||
@ -311,17 +286,17 @@
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" return self.model(x)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "Ks73ePLtNWeV",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "Ks73ePLtNWeV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class LogiticRegressionEvaluator(object):\n",
|
||||
" def __init__(self, n_features, n_classes):\n",
|
||||
@ -408,37 +383,60 @@
|
||||
" print(\"--------------\")\n",
|
||||
" print(\"Done training\")\n",
|
||||
" print(\"Best accuracy:\", best_accuracy)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"id": "NE716m7SOkaK",
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
"id": "NE716m7SOkaK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"log_regressor_evaluator = LogiticRegressionEvaluator(n_features=X_train_feature.shape[1], n_classes=10)\n",
|
||||
"\n",
|
||||
"log_regressor_evaluator.train(X_train_feature, y_train, X_test_feature, y_test)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "_GC0a14uWRr6",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
""
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "_GC0a14uWRr6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
]
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"include_colab_link": true,
|
||||
"name": "mini-batch-logistic-regression-evaluator.ipynb",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python [conda env:image]",
|
||||
"language": "python",
|
||||
"name": "conda-env-image-py"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
25
simclr.py
25
simclr.py
@ -34,6 +34,23 @@ def _save_config_file(model_checkpoints_folder, args):
|
||||
yaml.dump(args, outfile, default_flow_style=False)
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
||||
with torch.no_grad():
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
class SimCLR(object):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@ -97,10 +114,10 @@ class SimCLR(object):
|
||||
self.optimizer.step()
|
||||
|
||||
if n_iter % self.args.log_every_n_steps == 0:
|
||||
predictions = torch.argmax(logits, dim=1)
|
||||
acc = 100 * (predictions == labels).float().mean()
|
||||
top1, top5 = accuracy(logits, labels, topk=(1,5))
|
||||
self.writer.add_scalar('loss', loss, global_step=n_iter)
|
||||
self.writer.add_scalar('acc/top1', acc, global_step=n_iter)
|
||||
self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
|
||||
self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
|
||||
self.writer.add_scalar('learning_rate', self.scheduler.get_lr()[0], global_step=n_iter)
|
||||
|
||||
n_iter += 1
|
||||
@ -108,7 +125,7 @@ class SimCLR(object):
|
||||
# warmup for the first 10 epochs
|
||||
if epoch_counter >= 10:
|
||||
self.scheduler.step()
|
||||
logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {acc}")
|
||||
logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")
|
||||
|
||||
logging.info("Training has finished.")
|
||||
# save model checkpoints
|
||||
|
Loading…
x
Reference in New Issue
Block a user