{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from model import Encoder, ResNet18\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 256\n",
    "out_dim = 64  # shared dimensionality: PCA components and the SimCLR projection head output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training images\n",
      "(5000, 96, 96, 3)\n",
      "(5000,)\n"
     ]
    }
   ],
   "source": [
    "# STL-10 binaries store images as column-major (N, 3, 96, 96) uint8 arrays\n",
    "X_train = np.fromfile('data/stl10_binary/train_X.bin', dtype=np.uint8)\n",
    "y_train = np.fromfile('data/stl10_binary/train_y.bin', dtype=np.uint8)\n",
    "\n",
    "X_train = np.reshape(X_train, (-1, 3, 96, 96))\n",
    "# transpose to (N, 96, 96, 3); the (0, 3, 2, 1) order also undoes the column-major storage\n",
    "X_train = np.transpose(X_train, (0, 3, 2, 1))\n",
    "print(\"Training images\")\n",
    "print(X_train.shape)\n",
    "print(y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test images\n",
      "(8000, 96, 96, 3)\n",
      "(8000,)\n"
     ]
    }
   ],
   "source": [
    "X_test = np.fromfile('data/stl10_binary/test_X.bin', dtype=np.uint8)\n",
    "y_test = np.fromfile('data/stl10_binary/test_y.bin', dtype=np.uint8)\n",
    "\n",
    "X_test = np.reshape(X_test, (-1, 3, 96, 96))\n",
    "X_test = np.transpose(X_test, (0, 3, 2, 1))\n",
    "print(\"Test images\")\n",
    "print(X_test.shape)\n",
    "print(y_test.shape)"
   ]
  },
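  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `(0, 3, 2, 1)` transpose above converts the column-major `(N, 3, 96, 96)` binary layout into standard `(N, H, W, C)` images. A quick visual sanity check (a sketch; it assumes `matplotlib` is available in the environment):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# if the transpose is correct, this renders a recognizable STL-10 image\n",
    "plt.imshow(X_train[0])\n",
    "plt.title(\"label: {}\".format(y_train[0]))\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },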
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Protocol #1: PCA features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "from sklearn.linear_model import LogisticRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PCA features\n",
      "(5000, 64)\n",
      "(8000, 64)\n"
     ]
    }
   ],
   "source": [
    "pca = PCA(n_components=out_dim)\n",
    "X_train_pca = pca.fit_transform(X_train.reshape((X_train.shape[0], -1)))\n",
    "X_test_pca = pca.transform(X_test.reshape((X_test.shape[0], -1)))\n",
    "\n",
    "print(\"PCA features\")\n",
    "print(X_train_pca.shape)\n",
    "print(X_test_pca.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PCA feature evaluation\n",
      "Train score: 0.3984\n",
      "Test score: 0.353125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/thalles/anaconda3/envs/pytorch/lib/python3.6/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"
     ]
    }
   ],
   "source": [
    "clf = LogisticRegression(random_state=0).fit(X_train_pca, y_train)\n",
    "print(\"PCA feature evaluation\")\n",
    "print(\"Train score:\", clf.score(X_train_pca, y_train))\n",
    "print(\"Test score:\", clf.score(X_test_pca, y_test))"
   ]
  },
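  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `ConvergenceWarning` above suggests scaling the data or raising `max_iter`. A minimal sketch of that fix (scores may differ slightly from the recorded run):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# standardizing the PCA features helps lbfgs converge within its iteration budget\n",
    "clf_scaled = make_pipeline(StandardScaler(),\n",
    "                           LogisticRegression(random_state=0, max_iter=1000))\n",
    "clf_scaled.fit(X_train_pca, y_train)\n",
    "print(\"Scaled train score:\", clf_scaled.score(X_train_pca, y_train))\n",
    "print(\"Scaled test score:\", clf_scaled.score(X_test_pca, y_test))"
   ]
  },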
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# note: RandomResizedCrop is stochastic, so the extracted features depend on the sampled crops\n",
    "data_augment = transforms.Compose([transforms.RandomResizedCrop(96),\n",
    "                                   transforms.ToTensor()])\n",
    "\n",
    "train_dataset = datasets.STL10('data', split='train', download=True, transform=data_augment)\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=1, drop_last=False, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "test_dataset = datasets.STL10('data', split='test', download=True, transform=data_augment)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=1, drop_last=False, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ResNet18(\n",
      "  (features): Sequential(\n",
      "    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (2): ReLU(inplace=True)\n",
      "    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "    (4): Sequential(\n",
      "      (0): BasicBlock(\n",
      "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "      (1): BasicBlock(\n",
      "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (5): Sequential(\n",
      "      (0): BasicBlock(\n",
      "        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (downsample): Sequential(\n",
      "          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "      (1): BasicBlock(\n",
      "        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (6): Sequential(\n",
      "      (0): BasicBlock(\n",
      "        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (downsample): Sequential(\n",
      "          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "      (1): BasicBlock(\n",
      "        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (7): Sequential(\n",
      "      (0): BasicBlock(\n",
      "        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (downsample): Sequential(\n",
      "          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "      (1): BasicBlock(\n",
      "        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (8): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  )\n",
      "  (l1): Linear(in_features=512, out_features=512, bias=True)\n",
      "  (l2): Linear(in_features=512, out_features=64, bias=True)\n",
      ")\n",
      "odict_keys(['features.0.weight', 'features.1.weight', 'features.1.bias', 'features.1.running_mean', 'features.1.running_var', 'features.1.num_batches_tracked', 'features.4.0.conv1.weight', 'features.4.0.bn1.weight', 'features.4.0.bn1.bias', 'features.4.0.bn1.running_mean', 'features.4.0.bn1.running_var', 'features.4.0.bn1.num_batches_tracked', 'features.4.0.conv2.weight', 'features.4.0.bn2.weight', 'features.4.0.bn2.bias', 'features.4.0.bn2.running_mean', 'features.4.0.bn2.running_var', 'features.4.0.bn2.num_batches_tracked', 'features.4.1.conv1.weight', 'features.4.1.bn1.weight', 'features.4.1.bn1.bias', 'features.4.1.bn1.running_mean', 'features.4.1.bn1.running_var', 'features.4.1.bn1.num_batches_tracked', 'features.4.1.conv2.weight', 'features.4.1.bn2.weight', 'features.4.1.bn2.bias', 'features.4.1.bn2.running_mean', 'features.4.1.bn2.running_var', 'features.4.1.bn2.num_batches_tracked', 'features.5.0.conv1.weight', 'features.5.0.bn1.weight', 'features.5.0.bn1.bias', 'features.5.0.bn1.running_mean', 'features.5.0.bn1.running_var', 'features.5.0.bn1.num_batches_tracked', 'features.5.0.conv2.weight', 'features.5.0.bn2.weight', 'features.5.0.bn2.bias', 'features.5.0.bn2.running_mean', 'features.5.0.bn2.running_var', 'features.5.0.bn2.num_batches_tracked', 'features.5.0.downsample.0.weight', 'features.5.0.downsample.1.weight', 'features.5.0.downsample.1.bias', 'features.5.0.downsample.1.running_mean', 'features.5.0.downsample.1.running_var', 'features.5.0.downsample.1.num_batches_tracked', 'features.5.1.conv1.weight', 'features.5.1.bn1.weight', 'features.5.1.bn1.bias', 'features.5.1.bn1.running_mean', 'features.5.1.bn1.running_var', 'features.5.1.bn1.num_batches_tracked', 'features.5.1.conv2.weight', 'features.5.1.bn2.weight', 'features.5.1.bn2.bias', 'features.5.1.bn2.running_mean', 'features.5.1.bn2.running_var', 'features.5.1.bn2.num_batches_tracked', 'features.6.0.conv1.weight', 'features.6.0.bn1.weight', 'features.6.0.bn1.bias', 'features.6.0.bn1.running_mean', 'features.6.0.bn1.running_var', 'features.6.0.bn1.num_batches_tracked', 'features.6.0.conv2.weight', 'features.6.0.bn2.weight', 'features.6.0.bn2.bias', 'features.6.0.bn2.running_mean', 'features.6.0.bn2.running_var', 'features.6.0.bn2.num_batches_tracked', 'features.6.0.downsample.0.weight', 'features.6.0.downsample.1.weight', 'features.6.0.downsample.1.bias', 'features.6.0.downsample.1.running_mean', 'features.6.0.downsample.1.running_var', 'features.6.0.downsample.1.num_batches_tracked', 'features.6.1.conv1.weight', 'features.6.1.bn1.weight', 'features.6.1.bn1.bias', 'features.6.1.bn1.running_mean', 'features.6.1.bn1.running_var', 'features.6.1.bn1.num_batches_tracked', 'features.6.1.conv2.weight', 'features.6.1.bn2.weight', 'features.6.1.bn2.bias', 'features.6.1.bn2.running_mean', 'features.6.1.bn2.running_var', 'features.6.1.bn2.num_batches_tracked', 'features.7.0.conv1.weight', 'features.7.0.bn1.weight', 'features.7.0.bn1.bias', 'features.7.0.bn1.running_mean', 'features.7.0.bn1.running_var', 'features.7.0.bn1.num_batches_tracked', 'features.7.0.conv2.weight', 'features.7.0.bn2.weight', 'features.7.0.bn2.bias', 'features.7.0.bn2.running_mean', 'features.7.0.bn2.running_var', 'features.7.0.bn2.num_batches_tracked', 'features.7.0.downsample.0.weight', 'features.7.0.downsample.1.weight', 'features.7.0.downsample.1.bias', 'features.7.0.downsample.1.running_mean', 'features.7.0.downsample.1.running_var', 'features.7.0.downsample.1.num_batches_tracked', 'features.7.1.conv1.weight', 'features.7.1.bn1.weight', 'features.7.1.bn1.bias', 'features.7.1.bn1.running_mean', 'features.7.1.bn1.running_var', 'features.7.1.bn1.num_batches_tracked', 'features.7.1.conv2.weight', 'features.7.1.bn2.weight', 'features.7.1.bn2.bias', 'features.7.1.bn2.running_mean', 'features.7.1.bn2.running_var', 'features.7.1.bn2.num_batches_tracked', 'l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = ResNet18(out_dim=out_dim)\n",
    "model.eval()  # inference mode: use the stored BatchNorm statistics\n",
    "print(model)\n",
    "\n",
    "state_dict = torch.load('model/checkpoint.pth')\n",
    "print(state_dict.keys())\n",
    "\n",
    "model.load_state_dict(state_dict)"
   ]
  },
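  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`torch.load` restores tensors to the device they were saved from, so if `model/checkpoint.pth` was written during GPU training, loading it on a CPU-only machine needs `map_location`. A hedged sketch, assuming such a checkpoint:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# map_location remaps GPU-saved tensors onto the CPU at load time\n",
    "state_dict = torch.load('model/checkpoint.pth', map_location=torch.device('cpu'))\n",
    "model.load_state_dict(state_dict)"
   ]
  },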
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Protocol #2: Linear separability evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train features\n",
      "(5000, 512)\n"
     ]
    }
   ],
   "source": [
    "X_train_feature = []\n",
    "\n",
    "for step, (batch_x, batch_y) in enumerate(train_loader):\n",
    "    # the model returns (h, z): the 512-d encoder features and the 64-d projection\n",
    "    features, _ = model(batch_x)\n",
    "    X_train_feature.extend(features.detach().numpy())\n",
    "\n",
    "X_train_feature = np.array(X_train_feature)\n",
    "\n",
    "print(\"Train features\")\n",
    "print(X_train_feature.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test features\n",
      "(8000, 512)\n"
     ]
    }
   ],
   "source": [
    "X_test_feature = []\n",
    "\n",
    "for step, (batch_x, batch_y) in enumerate(test_loader):\n",
    "    features, _ = model(batch_x)\n",
    "    X_test_feature.extend(features.detach().numpy())\n",
    "\n",
    "X_test_feature = np.array(X_test_feature)\n",
    "\n",
    "print(\"Test features\")\n",
    "print(X_test_feature.shape)"
   ]
  },
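  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two extraction loops above are identical up to the loader. A minimal refactor (a sketch, not part of the recorded run) that also wraps the forward pass in `torch.no_grad()`, so no autograd graph is built in the first place:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_features(loader, model):\n",
    "    # run the frozen encoder over a loader and stack the 512-d features\n",
    "    feats = []\n",
    "    with torch.no_grad():\n",
    "        for batch_x, _ in loader:\n",
    "            features, _ = model(batch_x)\n",
    "            feats.append(features.numpy())\n",
    "    return np.concatenate(feats)\n",
    "\n",
    "# equivalent to the loops above:\n",
    "# X_train_feature = extract_features(train_loader, model)\n",
    "# X_test_feature = extract_features(test_loader, model)"
   ]
  },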
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/thalles/anaconda3/envs/pytorch/lib/python3.6/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "clf = LogisticRegression(random_state=0).fit(X_train_feature, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SimCLR feature evaluation\n",
      "Train score: 0.7444\n",
      "Test score: 0.62625\n"
     ]
    }
   ],
   "source": [
    "print(\"SimCLR feature evaluation\")\n",
    "print(\"Train score:\", clf.score(X_train_feature, y_train))\n",
    "print(\"Test score:\", clf.score(X_test_feature, y_test))\n",
    "# SimCLR feature evaluation\n",
    "# Train score: 0.5298\n",
    "# Test score: 0.52075"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}