mirror of https://github.com/sthalles/SimCLR.git
Merge remote-tracking branch 'origin/master'
commit
6597622376
|
@ -83,7 +83,6 @@ Check the [ | SimCLR | [ResNet-18](https://drive.google.com/open?id=1L0yoeY9i2mzDcj69P4slTWb-cfr3PyoT) | 512 | 256 | 80 | 72.9% |
|
||||
| KNN | SimCLR | ResNet-18 | 512 | 256 | 80 | 69.8% |
|
||||
| Logistic Regression (Adam) | SimCLR | [ResNet-18](https://drive.google.com/open?id=1SgMCbzp1fXoqUFDJcnlb7hmwqjUvGusd) | 512 | 512 | 80 | 75.4% |
|
||||
| Logistic Regression (Adam) | SimCLR | [ResNet-50](https://drive.google.com/open?id=1TZqBNTFCsO-mxAiR-zJeyupY-J2gA27Q) | 2048 | 128 | 40 | 74.6% |
|
||||
| Logistic Regression (Adam) | SimCLR | [ResNet-50](https://drive.google.com/open?id=1is1wkBRccHdhSKQnPUTQoaFkVNSaCb35) | 2048 | 128 | 80 | 77.3% |
|
||||
|
||||
|
||||
|
|
|
@ -65,32 +65,13 @@
|
|||
"metadata": {
|
||||
"id": "WSgRE1CcLqdS",
|
||||
"colab_type": "code",
|
||||
"outputId": "3bd80a41-005c-416d-9476-a0dc14921ab0",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 163
|
||||
}
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"!pip install gdown"
|
||||
],
|
||||
"execution_count": 25,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Requirement already satisfied: gdown in /usr/local/lib/python3.6/dist-packages (3.6.4)\n",
|
||||
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gdown) (4.28.1)\n",
|
||||
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gdown) (2.21.0)\n",
|
||||
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from gdown) (1.12.0)\n",
|
||||
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (3.0.4)\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2019.11.28)\n",
|
||||
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2.8)\n",
|
||||
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (1.24.3)\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
@ -115,38 +96,22 @@
|
|||
"metadata": {
|
||||
"id": "G7YMxsvEZMrX",
|
||||
"colab_type": "code",
|
||||
"outputId": "430bc8d7-6e3c-44c5-eb8f-7f3ca24c4172",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
}
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"folder_name = 'resnet-50_40-epochs'\n",
|
||||
"file_id = get_file_id_by_model(folder_name)\n",
|
||||
"print(folder_name, file_id)"
|
||||
],
|
||||
"execution_count": 27,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"resnet-50_40-epochs 1TZqBNTFCsO-mxAiR-zJeyupY-J2gA27Q\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "PWZ8fet_YoJm",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 72
|
||||
},
|
||||
"outputId": "2871e598-a429-4cfa-cd96-40850003e638"
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"# download and extract model files\n",
|
||||
|
@ -154,18 +119,8 @@
|
|||
"os.system('unzip {}'.format(folder_name))\n",
|
||||
"!ls"
|
||||
],
|
||||
"execution_count": 28,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"data\t\t resnet-50_40-epochs.zip sample_data\n",
|
||||
"log_regression.pth resnet-50_80-epochs\n",
|
||||
"resnet-50_40-epochs resnet-50_80-epochs.zip\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
@ -187,68 +142,29 @@
|
|||
"metadata": {
|
||||
"id": "lDfbL3w_Z0Od",
|
||||
"colab_type": "code",
|
||||
"outputId": "d148eb48-8e56-4af5-8c2b-c7821e2c7149",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
}
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
||||
"print(\"Using device:\", device)"
|
||||
],
|
||||
"execution_count": 30,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using device: cuda\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "IQMIryc6LjQd",
|
||||
"colab_type": "code",
|
||||
"outputId": "9020c91b-d9ad-4d46-c181-cd394061df0d",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 217
|
||||
}
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"checkpoints_folder = os.path.join(folder_name, 'checkpoints')\n",
|
||||
"config = yaml.load(open(os.path.join(checkpoints_folder, \"config.yaml\"), \"r\"))\n",
|
||||
"config"
|
||||
],
|
||||
"execution_count": 31,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'batch_size': 256,\n",
|
||||
" 'dataset': {'input_shape': '(96,96,3)',\n",
|
||||
" 'num_workers': 0,\n",
|
||||
" 's': 1,\n",
|
||||
" 'valid_size': 0.05},\n",
|
||||
" 'epochs': 40,\n",
|
||||
" 'eval_every_n_epochs': 1,\n",
|
||||
" 'fine_tune_from': 'None',\n",
|
||||
" 'log_every_n_steps': 50,\n",
|
||||
" 'loss': {'temperature': 0.5, 'use_cosine_similarity': True},\n",
|
||||
" 'model': {'base_model': 'resnet50', 'out_dim': 128}}"
|
||||
]
|
||||
},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"execution_count": 31
|
||||
}
|
||||
]
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
@ -355,53 +271,26 @@
|
|||
"metadata": {
|
||||
"id": "kghx1govJq5_",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"outputId": "36040306-9730-4781-eaef-dc9018e75176"
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"resnet_feature_extractor = ResNetFeatureExtractor(checkpoints_folder)"
|
||||
],
|
||||
"execution_count": 35,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Feature extractor: resnet50\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "S_JcznxVJ1Xj",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 90
|
||||
},
|
||||
"outputId": "aea5aa1c-d78c-4df2-86da-97efa484e093"
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"X_train_feature, y_train, X_test_feature, y_test = resnet_feature_extractor.get_resnet_features()"
|
||||
],
|
||||
"execution_count": 36,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Files already downloaded and verified\n",
|
||||
"Files already downloaded and verified\n",
|
||||
"Features shape (5000, 2048)\n",
|
||||
"Features shape (8000, 2048)\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
@ -527,31 +416,15 @@
|
|||
"metadata": {
|
||||
"id": "NE716m7SOkaK",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 108
|
||||
},
|
||||
"outputId": "87de8f71-4312-4d76-a1f3-7af5bbe0e9ba"
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"log_regressor_evaluator = LogiticRegressionEvaluator(n_features=X_train_feature.shape[1], n_classes=10)\n",
|
||||
"\n",
|
||||
"log_regressor_evaluator.train(X_train_feature, y_train, X_test_feature, y_test)"
|
||||
],
|
||||
"execution_count": 41,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Standard Scaling Normalizer\n",
|
||||
"Sampled weight decay: 0.00017782794100389227\n",
|
||||
"--------------\n",
|
||||
"Done training\n",
|
||||
"Best accuracy: 73.6625\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
|
Loading…
Reference in New Issue