{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "MMClassification_tools.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "XjQxmm04iTx4",
"tags": []
},
"source": [
"<a href=\"https://colab.research.google.com/github/open-mmlab/mmclassification/blob/master/docs/tutorials/MMClassification_tools.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4z0JDgisPRr-"
},
"source": [
"# MMClassification tools tutorial on Colab\n",
"\n",
"In this tutorial, we will introduce the following content:\n",
"\n",
"* How to install MMCls\n",
"* Prepare data\n",
"* Prepare the config file\n",
"* Train and test model with shell command"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "inm7Ciy5PXrU"
},
"source": [
"## Install MMClassification\n",
"\n",
"Before using MMClassification, we need to prepare the environment with the following steps:\n",
"\n",
"1. Install Python, CUDA, C/C++ compiler and git\n",
"2. Install PyTorch (CUDA version)\n",
"3. Install mmcv\n",
"4. Clone mmcls source code from GitHub and install it\n",
"\n",
"Because this tutorial is on Google Colab, and the basic environment has been completed, we can skip the first two steps."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TDOxbcDvPbNk"
},
"source": [
"### Check environment"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c6MbAw10iUJI",
"outputId": "8d3d6b53-c69b-4425-ce0c-bfb8d31ab971"
},
"source": [
"%cd /content"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4IyFL3MaiYRu",
"outputId": "c46dc718-27de-418b-da17-9d5a717e8424"
},
"source": [
"!pwd"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DMw7QwvpiiUO",
"outputId": "0d852285-07c4-48d3-e537-4a51dea04d10"
},
"source": [
"# Check nvcc version\n",
"!nvcc -V"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"nvcc: NVIDIA (R) Cuda compiler driver\n",
"Copyright (c) 2005-2020 NVIDIA Corporation\n",
"Built on Mon_Oct_12_20:09:46_PDT_2020\n",
"Cuda compilation tools, release 11.1, V11.1.105\n",
"Build cuda_11.1.TC455_06.29190527_0\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4VIBU7Fain4D",
"outputId": "fb34a7b6-8eda-4180-e706-1bf67d1a6fd4"
},
"source": [
"# Check GCC version\n",
"!gcc --version"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0\n",
"Copyright (C) 2017 Free Software Foundation, Inc.\n",
"This is free software; see the source for copying conditions. There is NO\n",
"warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "24lDLCqFisZ9",
"outputId": "304ad2f7-a9bb-4441-d25b-09b5516ccd74"
},
"source": [
"# Check PyTorch installation\n",
"import torch, torchvision\n",
"print(torch.__version__)\n",
"print(torch.cuda.is_available())"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1.9.0+cu111\n",
"True\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R2aZNLUwizBs"
},
"source": [
"### Install MMCV\n",
"\n",
"MMCV is the basic package of all OpenMMLab packages. We have pre-built wheels on Linux, so we can download and install them directly.\n",
"\n",
"Please pay attention to PyTorch and CUDA versions to match the wheel.\n",
"\n",
"In the above steps, we have checked the version of PyTorch and CUDA, and they are 1.9.0 and 11.1 respectively, so we need to choose the corresponding wheel.\n",
"\n",
"In addition, we can also install the full version of mmcv (mmcv-full). It includes full features and various CUDA ops out of the box, but needs a longer time to build."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nla40LrLi7oo",
"outputId": "a17d50d6-05b7-45d6-c3fb-6a2507415cf5"
},
"source": [
"# Install mmcv\n",
"!pip install mmcv -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html\n",
"# !pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in links: https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html\n",
"Collecting mmcv\n",
" Downloading mmcv-1.3.15.tar.gz (352 kB)\n",
"\u001b[K |████████████████████████████████| 352 kB 12.8 MB/s \n",
"\u001b[?25hCollecting addict\n",
" Downloading addict-2.4.0-py3-none-any.whl (3.8 kB)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mmcv) (1.19.5)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from mmcv) (21.0)\n",
"Requirement already satisfied: Pillow in /usr/local/lib/python3.7/dist-packages (from mmcv) (7.1.2)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from mmcv) (3.13)\n",
"Collecting yapf\n",
" Downloading yapf-0.31.0-py2.py3-none-any.whl (185 kB)\n",
"\u001b[K |████████████████████████████████| 185 kB 49.3 MB/s \n",
"\u001b[?25hRequirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->mmcv) (2.4.7)\n",
"Building wheels for collected packages: mmcv\n",
" Building wheel for mmcv (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for mmcv: filename=mmcv-1.3.15-py2.py3-none-any.whl size=509835 sha256=13b8c5d70c29029916f661f2dc9b773b74a9ea4e0758491a7b5c15c798efaa61\n",
" Stored in directory: /root/.cache/pip/wheels/b2/f4/4e/8f6d2dd2bef6b7eb8c89aa0e5d61acd7bff60aaf3d4d4b29b0\n",
"Successfully built mmcv\n",
"Installing collected packages: yapf, addict, mmcv\n",
"Successfully installed addict-2.4.0 mmcv-1.3.15 yapf-0.31.0\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GDTUrYvXjlRb"
},
"source": [
"### Clone and install MMClassification\n",
"\n",
"Next, we clone the latest mmcls repository from GitHub and install it."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Bwme6tWHjl5s",
"outputId": "7e2d54c8-b134-405a-b014-194da1708776"
},
"source": [
"# Clone mmcls repository\n",
"!git clone https://github.com/open-mmlab/mmclassification.git\n",
"%cd mmclassification/\n",
"\n",
"# Install MMClassification from source\n",
"!pip install -e . "
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into 'mmclassification'...\n",
"remote: Enumerating objects: 4152, done.\u001b[K\n",
"remote: Counting objects: 100% (994/994), done.\u001b[K\n",
"remote: Compressing objects: 100% (579/579), done.\u001b[K\n",
"remote: Total 4152 (delta 476), reused 761 (delta 398), pack-reused 3158\u001b[K\n",
"Receiving objects: 100% (4152/4152), 8.21 MiB | 19.02 MiB/s, done.\n",
"Resolving deltas: 100% (2518/2518), done.\n",
"/content/mmclassification\n",
"Obtaining file:///content/mmclassification\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from mmcls==0.16.0) (3.2.2)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mmcls==0.16.0) (1.19.5)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from mmcls==0.16.0) (21.0)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.16.0) (0.10.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.16.0) (1.3.2)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.16.0) (2.4.7)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.16.0) (2.8.2)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from cycler>=0.10->matplotlib->mmcls==0.16.0) (1.15.0)\n",
"Installing collected packages: mmcls\n",
" Running setup.py develop for mmcls\n",
"Successfully installed mmcls-0.16.0\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hFg_oSG4j3zB",
"outputId": "1cc74bac-f918-4f0e-bf56-9f13447dfce1"
},
"source": [
"# Check MMClassification installation\n",
"import mmcls\n",
"print(mmcls.__version__)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0.16.0\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HCOHRp3iV5Xk"
},
"source": [
"## Prepare data"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XHCHnKb_Qd3P",
"outputId": "35496010-ee57-4e72-af00-2af55dc80f47"
},
"source": [
"# Download the dataset (cats & dogs dataset)\n",
"!wget https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0 -O cats_dogs_dataset.zip\n",
"!mkdir -p data\n",
"!unzip -q cats_dogs_dataset.zip -d ./data/"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"--2021-10-21 02:47:54-- https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0\n",
"Resolving www.dropbox.com (www.dropbox.com)... 162.125.67.18, 2620:100:6020:18::a27d:4012\n",
"Connecting to www.dropbox.com (www.dropbox.com)|162.125.67.18|:443... connected.\n",
"HTTP request sent, awaiting response... 301 Moved Permanently\n",
"Location: /s/raw/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip [following]\n",
"--2021-10-21 02:47:54-- https://www.dropbox.com/s/raw/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip\n",
"Reusing existing connection to www.dropbox.com:443.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://uc88da1070f63f9a78ee48b59098.dl.dropboxusercontent.com/cd/0/inline/BYb26ayxWasysNPC1wSer1N9YqdOShCMIBzSIQ5NKaIoKQQ47lxZ3y7DkjKNLrYiSHkA_KgTE47_9jUHaHW79JqDtcSNEAO3unPfo8bPwsxaQUHqo97L_RjsSBhWg4HZStWRbLIJUl5WUOtpETbSQtvD/file# [following]\n",
"--2021-10-21 02:47:54-- https://uc88da1070f63f9a78ee48b59098.dl.dropboxusercontent.com/cd/0/inline/BYb26ayxWasysNPC1wSer1N9YqdOShCMIBzSIQ5NKaIoKQQ47lxZ3y7DkjKNLrYiSHkA_KgTE47_9jUHaHW79JqDtcSNEAO3unPfo8bPwsxaQUHqo97L_RjsSBhWg4HZStWRbLIJUl5WUOtpETbSQtvD/file\n",
"Resolving uc88da1070f63f9a78ee48b59098.dl.dropboxusercontent.com (uc88da1070f63f9a78ee48b59098.dl.dropboxusercontent.com)... 162.125.67.15, 2620:100:6020:15::a27d:400f\n",
"Connecting to uc88da1070f63f9a78ee48b59098.dl.dropboxusercontent.com (uc88da1070f63f9a78ee48b59098.dl.dropboxusercontent.com)|162.125.67.15|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: /cd/0/inline2/BYbEOCLrcXNg9qXvYXbyZZ0cgv3fSQ1vs-iqDCz24_84Fgz_2Z5SkserjAUpmYgty-eQkchAlzxQPbgzayZnie5yCipe42WVTChJJiIQ6m5x7GxgWJOn6_5QP3eRbFuYyrc1yV61BKlYuCJDHH0eyNaN8paR6bjevwMJ7Alip-gvf3c9JfjJmMgZrzcpknENyaI62FSgxFkX-Kc-FS41RYQadnMfUmhZCfMrFDSzTcmRprDiC9hQ-zJkcW_kbjI0whA1ZLQ-OG9-8Qf7jn8qd4g_tQLneL8X44qOUX4hRs2LE23g4n0jz8DeNt8KZ48WhGs8_20rBIgHH0dut3OjHF5DZMI8dVyHFAiJGyxOknZ5aCfImtz6MGgHDwbiipkICxk/file [following]\n",
"--2021-10-21 02:47:55-- https://uc88da1070f63f9a78ee48b59098.dl.dropboxusercontent.com/cd/0/inline2/BYbEOCLrcXNg9qXvYXbyZZ0cgv3fSQ1vs-iqDCz24_84Fgz_2Z5SkserjAUpmYgty-eQkchAlzxQPbgzayZnie5yCipe42WVTChJJiIQ6m5x7GxgWJOn6_5QP3eRbFuYyrc1yV61BKlYuCJDHH0eyNaN8paR6bjevwMJ7Alip-gvf3c9JfjJmMgZrzcpknENyaI62FSgxFkX-Kc-FS41RYQadnMfUmhZCfMrFDSzTcmRprDiC9hQ-zJkcW_kbjI0whA1ZLQ-OG9-8Qf7jn8qd4g_tQLneL8X44qOUX4hRs2LE23g4n0jz8DeNt8KZ48WhGs8_20rBIgHH0dut3OjHF5DZMI8dVyHFAiJGyxOknZ5aCfImtz6MGgHDwbiipkICxk/file\n",
"Reusing existing connection to uc88da1070f63f9a78ee48b59098.dl.dropboxusercontent.com:443.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 228802825 (218M) [application/zip]\n",
"Saving to: cats_dogs_dataset.zip\n",
"\n",
"cats_dogs_dataset.z 100%[===================>] 218.20M 16.9MB/s in 13s \n",
"\n",
"2021-10-21 02:48:08 (16.9 MB/s) - cats_dogs_dataset.zip saved [228802825/228802825]\n",
"\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e4t2P2aTQokX"
},
"source": [
"**After downloading and extraction,** we get \"Cats and Dogs Dataset\" and the file structure is as below:\n",
"```\n",
"data/cats_dogs_dataset\n",
"├── classes.txt\n",
"├── test.txt\n",
"├── val.txt\n",
"├── training_set\n",
"│ ├── training_set\n",
"│ │ ├── cats\n",
"│ │ │ ├── cat.1.jpg\n",
"│ │ │ ├── cat.2.jpg\n",
"│ │ │ ├── ...\n",
"│ │ ├── dogs\n",
"│ │ │ ├── dog.2.jpg\n",
"│ │ │ ├── dog.3.jpg\n",
"│ │ │ ├── ...\n",
"├── val_set\n",
"│ ├── val_set\n",
"│ │ ├── cats\n",
"│ │ │ ├── cat.3.jpg\n",
"│ │ │ ├── cat.5.jpg\n",
"│ │ │ ├── ...\n",
"│ │ ├── dogs\n",
"│ │ │ ├── dog.1.jpg\n",
"│ │ │ ├── dog.6.jpg\n",
"│ │ │ ├── ...\n",
"├── test_set\n",
"│ ├── test_set\n",
"│ │ ├── cats\n",
"│ │ │ ├── cat.4001.jpg\n",
"│ │ │ ├── cat.4002.jpg\n",
"│ │ │ ├── ...\n",
"│ │ ├── dogs\n",
"│ │ │ ├── dog.4001.jpg\n",
"│ │ │ ├── dog.4002.jpg\n",
"│ │ │ ├── ...\n",
"```\n",
"\n",
"You can use shell command `tree data/cats_dogs_dataset` to check the structure."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 297
},
"id": "46tyHTdtQy_Z",
"outputId": "6124a89e-03eb-4917-a0bf-df6a391eb280"
},
"source": [
"# Pick an image and visualize it\n",
"from PIL import Image\n",
"Image.open('data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg')"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAASwAAAEYCAIAAABp9FyZAAEAAElEQVR4nJT96ZMkWXIfCKrqe88uv+KOvCuzju7qrmo00A00QZAcoZAznPmwIkP+p/NhZWVnZTAHSBALAmig+u6ursrKMzIyDj/teofqflB3S89qckTWpCTKM8Lc3ezZ0+unP1XF737+KQAgGEQDQCAoIgD04MGDul7f3Nx0fTsalbPZJCa/XC5dOQ4hGGMAoGka55y1tu/7LMu89865zWbz+PHj+Xy+2WxGo1G3rokIAJgZEa3N9L3C4Jwzxngfu64TkbIsq6pqu40xJqXU922MkYiMMSKS56Uh61zubGmtM5QjGkTMDZRl2fd9CKGoSkTx3gthSgkRBCGllFIKnAAAEbGXPM/1q7vOM3OM0ZBLKcWYUkrM/PHH33n58mWe5yklnlW3t7e5M6fHJ7PRaLOY3769zAxNsvz06PDJwwe5ofnV1eNHD2+uru7du3O7aeu6Nha97/7n//l/fvHimQ8dIi4Wi/V6vVk34/H0+voa0Xzy8XdfvXrVtt39+/c//vjjt2/fPn32TZ7nDBKZY4x37997+/btn/zpj9++fftX//E//vCHP1yv1yWS98FaK2i+//3vv3nztutDluWcQMjEwNfXN/cfPHrz5o0IVlVlQvfpp5+mlH72s5+NZ9MY48nZ6Xy5XK1Wn37/e3fu3Hnx6uVvf/vbPoQ7d+6cn5/317ezg8lqteq6phqVV1eXm83q0aNHTdOklFKSlASBrM2yLHM2ny9eiAgB5nleFBWhjZG7kMAUybibTXdxs1yHaMspOtf7aAiMMdZaRERhRHTOOUNd13FKIfQxRokJUay1xpibxTzGKCLWWiJCkizL8jy/vb1JKT169Kiu133fHx8f397etm07nU67rqubdex9nueTycQY5JjeXF4cTmfHx8fG4Pz2drVaTSaTO3furNb1YrHw3mdZllKq6xoRx+NxWZbM7JyLMW42G+dcVVUppc1mlee5916/6/T0tGmaZ8+e3bt3DxHLsgSAvu+dcyIyn8/J8PHxIVFs13Mi/z/+9/9qOrWzcX7n7lEKTQi9QLKICAAg8K2j7/uUEiKKSAih67rEIYRQ94uiKEII0+n06urq5OQkhAAAMUa9iJQSAFRVZa1dr9ciwszGGN36RDal5L1HIGYGABHRaxARXe5vXwrAVoaBmVl2B4CIiE/RGOO9DyGQNYMQhhCICA3q240xiGj0f4jW2rqui6Lquq4sS9/HoiiOjo5fvXo1mUx+/OMf379//2/+5m8ePnz4y9fPfvSjH/3t3/y1M/aPP//8F/+0qKpqeXuTCSBi0zTZeDSbzbquy/P86uqqmB5WVeUyU9dS1/XV1VVMvqqqs7Oz5XK5XC6rajydTlerze3tbUppPB4bY4qi0KUmIu9749xqtdKlOz8/J6I/+7M/CyE450Lb6foj4dOnT7OsCCGEEIu8stYSwsHBASBba5m5qoo337yq6/WdO3em0/Hd+3frun79+uXZ2Vni8PTpVwcH0zx3JydHTdN88+zrsspzgL7vjTHT6bQo87atQ+ibpgkhWGuzLBPBGFIIqWkakM7lGcfEzMzsvUdIPiYfOJ9UQGSMIWsNCxElkARCAqr4AACFRcR7j8LMzCnF6FNKFinLbFEUWZYdnZ40TdN1XQghhOBD1/e9bjYVifV6HUIYjUZElBeubVuBlGWZI6OLEH3o+x4AyrIUkfl8EUI4Pj4moouLCxa01qrCreu6russy0RkPB7Xda17O8syIkopxRiZue97773esu7YsixVjJ1zRDRsY2stS4wxWgPWWmt1JwMiJh9i4pQSS6Rhl8v7h/ceEVVy9r9YHwYijkYjABiNRnoPw14XEbWBjx8/zvNcf6NXP0ia3ttwV2ru9CnuiyUi6tbUW8L3D/3GuDtCivsHMydhERARRgAAIkLEBNLH4FNcbtYhJUHMigKtuV0uNm2zbuou+Murq+9//rnNsuvb2zzPz8/PP//88+l0aoxxzumNq6LRVTo6Ouq67ujoaL1e3717dzwep5SyLBuNRrPZjJlXq9VqtXLO6T7I81y1qd71crns+17dAedc13UoYoy5vLw8OTl5/vz5D37wg/Ozs6PDwzzPmVkfZmbp5cuXKSUU0MtAxLars8xJ4uOTQ/VN7t69+0//9E+67Z49e6ZKJ8Z4cHDw4ZMnP//5z5tNrdvx0f0Hq/miqde31zd92xhCjoEADdJ6ueqaNoVokMo8K/LMEkYfurZGREEQER9DCMHHwMyM0HVt27ad72OMMaXOt13X6T3GGL33IXh1Urz3XddlWWYt7R7rdqsw87Nnz54/f/7ixYu3b99u6lVKyVqbZS7LsizLdFPpNiADRJQ4OOemo/F0Oi2KAoD188fjMQC0bRtDyLJMDd1yuby+vo4xZllmjCGi8Xh8dnZ2584dAPDe13UdY9RtrIrDWqsbTLdTCAERp9Op6ik9Qfe8un4JpO9DCMHaLM9LvVRh8N7HGEUQwZCezcxpdwwibq0djUZlWaoa0C/QbeScCyGMx2MRybJMnQoR6ft+vV5fXV0h4ocffnh6ejoajVSRqIx1Xafr+L6PalWwU0qDJRzEbBBLRCQilVh98U4g3700QoiI+pCstcZYEBSGFDlF7r0XgN57Y20fvMszm7nRZFxUZe/94dFRluf/5e//7uPvfPKTP/9n63pzfHz8/Pnzf/2v//WDBw9+9atf6dqpsiyKIqWkznnTNKenp9PpVNVn0zSbzeby8lK90BDCb37zG2vt2dlZ3/d1XRdFoU+rbdurq6vlcpmEp9Pp0dGRcw4AiqLo2857//vf/q5erfURTMbjsiy7rh3WIYTgMiPCSGJAdspbPnz85OT0qOub2XTcdvXzF98kDin0PnQI7H2XGTo7OyGU2/n1Yn5jCI4OZ02z6ft+uVyq8VksFl3XuMxYR71v62a92azUThZFUZalc86nkFIMKaaUYkqIaDNXlCUQCoIxJitcVhY2z4xzLs8QkZm93+p0IrKWdD/oHTnnNMbRM+/cuXN0dKSO1aCRdds4ZxCxKAqXGR+6EEJd11VejMtK/UlriQCNMS6zR7ODvu/bth6NRqPRqGmaul7nuauqSpcxxjgej+/fv3/37t3pdHp9fb1er9u2TSk553RzWmuLotDtp0ZPhVDV7mCo9F7UzAij97Hvg6AGVsAhMbP3UZIYpNwVpHel4qcirnKoFl9XRKVFhTDP89VqhYjz+Xwymcznc0TUB6OGmJnrum6aRgVYTZBzTkVCpUtvYJCbQa50lQch3JdD3h3f0hr6sIczhRCAEI2xmTHGGAN7QktEYMgVeQKZHh4kEJtnPsWsKNQe5lXJCJfXV6t644ocDLVtu1gs1Dm/vr6u63qz2YiIOkjOOdU7RHRwcPDJJ58Q0ZMnTz755JOyLL/88stvvvkGAL773e/OZjO9AFVn6l8YY5yzIty2zXg8un//3tnZ6cHBjAhR+PDo4JunX2eZ+9//978cFcXLZ8/KLJtOp23b6
uecHh2v18uUUgpxvV7H5HNrEofet5PJ+N6dO9Px+M2b13/ywz/65puvr64u//zPf3JyeHB5eVGV+Xqz/Onf/5c//qPPLSGBfP797/3iZ19Yg9aStSTAXdfU9TrGOBqNzs/PsyyLMa5Wq+Vy7n3nMluWeVnmiCiEYBAIhRhIyFrnXDmqqqoqRuVkNj04OJhOp7OD6eHxgbWGCNWl6vs+hF631mJ5u9lsvPcAbK3VHQIAR0dHBwcHk+moKDPdIQCit8/MbdsmDjHGuq5VIZZlqVYhRh99UIMxLisAcM6URWGtTSl436FAmRcff/zxeDxWpVMUxWQySSldXl7e3Nzoe3UXqQEnIjUY+vtB8FRpMrPamGG3ExGA+rEphMBJhJETxMgxcIoCQIho33dHh+BQ+r4vikz3SkoJgMkYRHRZpjAMEZ2cnCyXyyzLxuNxjHEw6HVdLxYL9bJ0rVXwjDEpSQhBRPKsGELBwdskIpb0LXdUzwkhgKAxKaWESAxRBEXEUhRJzFFEEggyi0hiF
"text/plain": [
"<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x280 at 0x7EFEAE2EE8D0>"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "My5Z6p7pQ3UC"
},
"source": [
"### Support new dataset\n",
"\n",
"We have two methods to support a new dataset in MMClassification.\n",
"\n",
"The simplest method is to re-organize the new dataset as the format of a dataset supported officially (like ImageNet). And the other method is to create a new dataset class, and more details are in [the docs](https://mmclassification.readthedocs.io/en/latest/tutorials/new_dataset.html#an-example-of-customized-dataset).\n",
"\n",
"In this tutorial, for convenience, we have re-organized the cats & dogs dataset as the format of ImageNet."
]
},
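{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the second method (implementing a new dataset class) roughly follows the pattern below. This is only a minimal sketch based on the dataset tutorial linked above: it assumes the `mmcls.datasets` registry API and a space-separated annotation file, and the class name `CatsDogsFilelist` is just an example. The rest of this tutorial sticks to the ImageNet-format approach.\n",
"\n",
"```python\n",
"# Sketch of a custom dataset class (not used in this tutorial).\n",
"import numpy as np\n",
"\n",
"from mmcls.datasets.base_dataset import BaseDataset\n",
"from mmcls.datasets.builder import DATASETS\n",
"\n",
"\n",
"@DATASETS.register_module()\n",
"class CatsDogsFilelist(BaseDataset):\n",
"\n",
"    def load_annotations(self):\n",
"        # Parse an annotation file with lines like 'cats/cat.1.jpg 0'.\n",
"        data_infos = []\n",
"        with open(self.ann_file) as f:\n",
"            samples = [x.strip().rsplit(' ', 1) for x in f if x.strip()]\n",
"        for filename, gt_label in samples:\n",
"            info = {'img_prefix': self.data_prefix}\n",
"            info['img_info'] = {'filename': filename}\n",
"            info['gt_label'] = np.array(gt_label, dtype=np.int64)\n",
"            data_infos.append(info)\n",
"        return data_infos\n",
"```"
]
},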
{
"cell_type": "markdown",
"metadata": {
"id": "P335gKt9Q5U-"
},
"source": [
"Besides image files, it also includes the following files:\n",
"\n",
"1. A class list file, and every line is a class.\n",
" ```\n",
" cats\n",
" dogs\n",
" ```\n",
"2. Training / Validation / Test annotation files. And every line includes an file path and the corresponding label.\n",
"\n",
" ```\n",
" ...\n",
" cats/cat.3769.jpg 0\n",
" cats/cat.882.jpg 0\n",
" ...\n",
" dogs/dog.3881.jpg 1\n",
" dogs/dog.3377.jpg 1\n",
" ...\n",
" ```"
]
},
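{
"cell_type": "markdown",
"metadata": {},
"source": [
"These annotation files are plain text, so they are easy to generate yourself. Below is a minimal sketch (standard library only) of how such a file could be produced from an ImageNet-style folder; the output name `my_ann.txt` is only for illustration, since the cats & dogs dataset already ships with `val.txt` and `test.txt`.\n",
"\n",
"```python\n",
"# Sketch: write an annotation file with one '<relative path> <label>' per line,\n",
"# scanning an ImageNet-style folder such as val_set/val_set/{cats,dogs}.\n",
"from pathlib import Path\n",
"\n",
"def write_ann_file(data_prefix, classes, out_file):\n",
"    with open(out_file, 'w') as f:\n",
"        for label, cls in enumerate(classes):\n",
"            for img in sorted((Path(data_prefix) / cls).glob('*.jpg')):\n",
"                # Paths are stored relative to data_prefix, e.g. 'cats/cat.3.jpg 0'.\n",
"                f.write(f'{cls}/{img.name} {label}\\n')\n",
"\n",
"write_ann_file('data/cats_dogs_dataset/val_set/val_set',\n",
"               ['cats', 'dogs'], 'my_ann.txt')\n",
"```"
]
},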
{
"cell_type": "markdown",
"metadata": {
"id": "BafQ7ijBQ8N_"
},
"source": [
"## Train and test model with shell commands\n",
"\n",
"You can use shell commands provided by MMClassification to do the following task:\n",
"\n",
"1. Train a model\n",
"2. Fine-tune a model\n",
"3. Test a model\n",
"4. Inference with a model\n",
"\n",
"The procedure to train and fine-tune a model is almost the same. And we have introduced how to do these tasks with Python API. In the following, we will introduce how to do them with shell commands. More details are in [the docs](https://mmclassification.readthedocs.io/en/latest/getting_started.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Aj5cGMihURrZ"
},
"source": [
"### Fine-tune a model\n",
"\n",
"The steps to fine-tune a model are as below:\n",
"\n",
"1. Prepare the custom dataset.\n",
"2. Create a new config file of the task.\n",
"3. Start training task by shell commands.\n",
"\n",
"We have finished the first step, and then we will introduce the next two steps.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WBBV3aG79ZH5"
},
"source": [
"#### Create a new config file\n",
"\n",
"To reuse the common parts of different config files, we support inheriting multiple base config files. For example, to fine-tune a MobileNetV2 model, the new config file can create the model's basic structure by inheriting `configs/_base_/models/mobilenet_v2_1x.py`.\n",
"\n",
"According to the common practice, we usually split whole configs into four parts: model, dataset, learning rate schedule, and runtime. Configs of each part are saved into one file in the `configs/_base_` folder. \n",
"\n",
"And then, when creating a new config file, we can select some parts to inherit and only override some different configs.\n",
"\n",
"The head of the final config file should look like:\n",
"\n",
"```python\n",
"_base_ = [\n",
" '../_base_/models/mobilenet_v2_1x.py',\n",
" '../_base_/schedules/imagenet_bs256_epochstep.py',\n",
" '../_base_/default_runtime.py'\n",
"]\n",
"```\n",
"\n",
"Here, because the dataset configs are almost brand new, we don't need to inherit any dataset config file.\n",
"\n",
"Of course, you can also create an entire config file without inheritance, like `configs/mnist/lenet5.py`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_UV3oBhLRG8B"
},
"source": [
"After that, we only need to set the part of configs we want to modify, because the inherited configs will be merged to the final configs."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8QfM4qBeWIQh",
"outputId": "a826f0cf-2633-4a9a-e49b-4be7eca5e3a0"
},
"source": [
"%%writefile configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py\n",
"_base_ = [\n",
" '../_base_/models/mobilenet_v2_1x.py',\n",
" '../_base_/schedules/imagenet_bs256_epochstep.py',\n",
" '../_base_/default_runtime.py'\n",
"]\n",
"\n",
"# ---- Model configs ----\n",
"# Here we use init_cfg to load pre-trained model.\n",
"# In this way, only the weights of backbone will be loaded.\n",
"# And modify the num_classes to match our dataset.\n",
"\n",
"model = dict(\n",
" backbone=dict(\n",
" init_cfg = dict(\n",
" type='Pretrained', \n",
" checkpoint='https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth', \n",
" prefix='backbone')\n",
" ),\n",
" head=dict(\n",
" num_classes=2,\n",
" topk = (1, )\n",
" ))\n",
"\n",
"# ---- Dataset configs ----\n",
"# We re-organized the dataset as ImageNet format.\n",
"dataset_type = 'ImageNet'\n",
"img_norm_cfg = dict(\n",
" mean=[124.508, 116.050, 106.438],\n",
" std=[58.577, 57.310, 57.437],\n",
" to_rgb=True)\n",
"train_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224, backend='pillow'),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(type='Normalize', **img_norm_cfg),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
"]\n",
"test_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1), backend='pillow'),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(type='Normalize', **img_norm_cfg),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
"]\n",
"data = dict(\n",
" # Specify the batch size and number of workers in each GPU.\n",
" # Please configure it according to your hardware.\n",
" samples_per_gpu=32,\n",
" workers_per_gpu=2,\n",
" # Specify the training dataset type and path\n",
" train=dict(\n",
" type=dataset_type,\n",
" data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
" classes='data/cats_dogs_dataset/classes.txt',\n",
" pipeline=train_pipeline),\n",
" # Specify the validation dataset type and path\n",
" val=dict(\n",
" type=dataset_type,\n",
" data_prefix='data/cats_dogs_dataset/val_set/val_set',\n",
" ann_file='data/cats_dogs_dataset/val.txt',\n",
" classes='data/cats_dogs_dataset/classes.txt',\n",
" pipeline=test_pipeline),\n",
" # Specify the test dataset type and path\n",
" test=dict(\n",
" type=dataset_type,\n",
" data_prefix='data/cats_dogs_dataset/test_set/test_set',\n",
" ann_file='data/cats_dogs_dataset/test.txt',\n",
" classes='data/cats_dogs_dataset/classes.txt',\n",
" pipeline=test_pipeline))\n",
"\n",
"# Specify evaluation metric\n",
"evaluation = dict(metric='accuracy', metric_options={'topk': (1, )})\n",
"\n",
"# ---- Schedule configs ----\n",
"# Usually in fine-tuning, we need a smaller learning rate and less training epochs.\n",
"# Specify the learning rate\n",
"optimizer = dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=0.0001)\n",
"optimizer_config = dict(grad_clip=None)\n",
"# Set the learning rate scheduler\n",
"lr_config = dict(policy='step', step=1, gamma=0.1)\n",
"runner = dict(type='EpochBasedRunner', max_epochs=2)\n",
"\n",
"# ---- Runtime configs ----\n",
"# Output training log every 10 iterations.\n",
"log_config = dict(interval=10)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Writing configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py\n"
]
}
]
},
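{
"cell_type": "markdown",
"metadata": {},
"source": [
"Thanks to the `_base_` inheritance, the short file we just wrote expands into a full config when it is loaded. As an optional sanity check, we can load it with `mmcv.Config` and print the merged result (the training log below prints the same merged config):\n",
"\n",
"```python\n",
"# Load the new config and inspect the result of merging the _base_ files\n",
"# with our overrides. `pretty_text` holds the formatted, merged config.\n",
"from mmcv import Config\n",
"\n",
"cfg = Config.fromfile('configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py')\n",
"print(cfg.pretty_text)\n",
"```"
]
},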
{
"cell_type": "markdown",
"metadata": {
"id": "chLX7bL3RP2F"
},
"source": [
"#### Use shell command to start fine-tuning\n",
"\n",
"We use `tools/train.py` to fine-tune a model:\n",
"\n",
"```shell\n",
"python tools/train.py ${CONFIG_FILE} [optional arguments]\n",
"```\n",
"\n",
"And if you want to specify another folder to save log files and checkpoints, use the argument `--work_dir ${YOUR_WORK_DIR}`.\n",
"\n",
"If you want to ensure reproducibility, use the argument `--seed ${SEED}` to set a random seed. And the argument `--deterministic` can enable the deterministic option in cuDNN to further ensure reproducibility, but it may reduce the training speed.\n",
"\n",
"Here we use the `MobileNetV2` model and cats & dogs dataset as an example:\n"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gbFGR4SBRUYN",
"outputId": "3412752c-433f-43c5-82a9-3495d1cd797a"
},
"source": [
"!python tools/train.py \\\n",
" configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py \\\n",
" --work-dir work_dirs/mobilenet_v2_1x_cats_dogs \\\n",
" --seed 0 \\\n",
" --deterministic"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:28: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/local/lib/python3.7/dist-packages/yaml/constructor.py:126: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n",
" if not isinstance(key, collections.Hashable):\n",
"2021-10-21 02:48:20,030 - mmcls - INFO - Environment info:\n",
"------------------------------------------------------------\n",
"sys.platform: linux\n",
"Python: 3.7.12 (default, Sep 10 2021, 00:21:48) [GCC 7.5.0]\n",
"CUDA available: True\n",
"GPU 0: Tesla K80\n",
"CUDA_HOME: /usr/local/cuda\n",
"NVCC: Build cuda_11.1.TC455_06.29190527_0\n",
"GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0\n",
"PyTorch: 1.9.0+cu111\n",
"PyTorch compiling details: PyTorch built with:\n",
" - GCC 7.3\n",
" - C++ Version: 201402\n",
" - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications\n",
" - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)\n",
" - OpenMP 201511 (a.k.a. OpenMP 4.5)\n",
" - NNPACK is enabled\n",
" - CPU capability usage: AVX2\n",
" - CUDA Runtime 11.1\n",
" - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86\n",
" - CuDNN 8.0.5\n",
" - Magma 2.5.2\n",
" - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, \n",
"\n",
"TorchVision: 0.10.0+cu111\n",
"OpenCV: 4.1.2\n",
"MMCV: 1.3.15\n",
"MMCV Compiler: n/a\n",
"MMCV CUDA Compiler: n/a\n",
"MMClassification: 0.16.0+77a3834\n",
"------------------------------------------------------------\n",
"\n",
"2021-10-21 02:48:20,030 - mmcls - INFO - Distributed training: False\n",
"2021-10-21 02:48:20,688 - mmcls - INFO - Config:\n",
"model = dict(\n",
" type='ImageClassifier',\n",
" backbone=dict(\n",
" type='MobileNetV2',\n",
" widen_factor=1.0,\n",
" init_cfg=dict(\n",
" type='Pretrained',\n",
" checkpoint=\n",
" 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth',\n",
" prefix='backbone')),\n",
" neck=dict(type='GlobalAveragePooling'),\n",
" head=dict(\n",
" type='LinearClsHead',\n",
" num_classes=2,\n",
" in_channels=1280,\n",
" loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n",
" topk=(1, )))\n",
"optimizer = dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=0.0001)\n",
"optimizer_config = dict(grad_clip=None)\n",
"lr_config = dict(policy='step', gamma=0.1, step=1)\n",
"runner = dict(type='EpochBasedRunner', max_epochs=2)\n",
"checkpoint_config = dict(interval=1)\n",
"log_config = dict(interval=10, hooks=[dict(type='TextLoggerHook')])\n",
"dist_params = dict(backend='nccl')\n",
"log_level = 'INFO'\n",
"load_from = None\n",
"resume_from = None\n",
"workflow = [('train', 1)]\n",
"dataset_type = 'ImageNet'\n",
"img_norm_cfg = dict(\n",
" mean=[124.508, 116.05, 106.438], std=[58.577, 57.31, 57.437], to_rgb=True)\n",
"train_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224, backend='pillow'),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[124.508, 116.05, 106.438],\n",
" std=[58.577, 57.31, 57.437],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
"]\n",
"test_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1), backend='pillow'),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[124.508, 116.05, 106.438],\n",
" std=[58.577, 57.31, 57.437],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
"]\n",
"data = dict(\n",
" samples_per_gpu=32,\n",
" workers_per_gpu=2,\n",
" train=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
" classes='data/cats_dogs_dataset/classes.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224, backend='pillow'),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[124.508, 116.05, 106.438],\n",
" std=[58.577, 57.31, 57.437],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
" ]),\n",
" val=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/val_set/val_set',\n",
" ann_file='data/cats_dogs_dataset/val.txt',\n",
" classes='data/cats_dogs_dataset/classes.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1), backend='pillow'),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[124.508, 116.05, 106.438],\n",
" std=[58.577, 57.31, 57.437],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
" ]),\n",
" test=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/test_set/test_set',\n",
" ann_file='data/cats_dogs_dataset/test.txt',\n",
" classes='data/cats_dogs_dataset/classes.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1), backend='pillow'),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[124.508, 116.05, 106.438],\n",
" std=[58.577, 57.31, 57.437],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
" ]))\n",
"evaluation = dict(metric='accuracy', metric_options=dict(topk=(1, )))\n",
"work_dir = 'work_dirs/mobilenet_v2_1x_cats_dogs'\n",
"gpu_ids = range(0, 1)\n",
"\n",
"2021-10-21 02:48:20,689 - mmcls - INFO - Set random seed to 0, deterministic: True\n",
"2021-10-21 02:48:20,854 - mmcls - INFO - initialize MobileNetV2 with init_cfg {'type': 'Pretrained', 'checkpoint': 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth', 'prefix': 'backbone'}\n",
"2021-10-21 02:48:20,855 - mmcv - INFO - load backbone in model from: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\n",
"Use load_from_http loader\n",
"Downloading: \"https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\" to /root/.cache/torch/hub/checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\n",
"100% 13.5M/13.5M [00:01<00:00, 9.54MB/s]\n",
"2021-10-21 02:48:23,564 - mmcls - INFO - initialize LinearClsHead with init_cfg {'type': 'Normal', 'layer': 'Linear', 'std': 0.01}\n",
"2021-10-21 02:48:38,767 - mmcls - INFO - Start running, host: root@992cc7e7be60, work_dir: /content/mmclassification/work_dirs/mobilenet_v2_1x_cats_dogs\n",
"2021-10-21 02:48:38,767 - mmcls - INFO - Hooks will be executed in the following order:\n",
"before_run:\n",
"(VERY_HIGH ) StepLrUpdaterHook \n",
"(NORMAL ) CheckpointHook \n",
"(LOW ) EvalHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_train_epoch:\n",
"(VERY_HIGH ) StepLrUpdaterHook \n",
"(LOW ) IterTimerHook \n",
"(LOW ) EvalHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_train_iter:\n",
"(VERY_HIGH ) StepLrUpdaterHook \n",
"(LOW ) IterTimerHook \n",
"(LOW ) EvalHook \n",
" -------------------- \n",
"after_train_iter:\n",
"(ABOVE_NORMAL) OptimizerHook \n",
"(NORMAL ) CheckpointHook \n",
"(LOW ) IterTimerHook \n",
"(LOW ) EvalHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"after_train_epoch:\n",
"(NORMAL ) CheckpointHook \n",
"(LOW ) EvalHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_val_epoch:\n",
"(LOW ) IterTimerHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_val_iter:\n",
"(LOW ) IterTimerHook \n",
" -------------------- \n",
"after_val_iter:\n",
"(LOW ) IterTimerHook \n",
" -------------------- \n",
"after_val_epoch:\n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"2021-10-21 02:48:38,768 - mmcls - INFO - workflow: [('train', 1)], max: 2 epochs\n",
"2021-10-21 02:48:44,261 - mmcls - INFO - Epoch [1][10/201]\tlr: 5.000e-03, eta: 0:03:29, time: 0.533, data_time: 0.257, memory: 1709, loss: 0.3917\n",
"2021-10-21 02:48:46,950 - mmcls - INFO - Epoch [1][20/201]\tlr: 5.000e-03, eta: 0:02:33, time: 0.269, data_time: 0.019, memory: 1709, loss: 0.3508\n",
"2021-10-21 02:48:49,618 - mmcls - INFO - Epoch [1][30/201]\tlr: 5.000e-03, eta: 0:02:12, time: 0.266, data_time: 0.021, memory: 1709, loss: 0.3955\n",
"2021-10-21 02:48:52,271 - mmcls - INFO - Epoch [1][40/201]\tlr: 5.000e-03, eta: 0:02:00, time: 0.266, data_time: 0.018, memory: 1709, loss: 0.2485\n",
"2021-10-21 02:48:54,984 - mmcls - INFO - Epoch [1][50/201]\tlr: 5.000e-03, eta: 0:01:53, time: 0.272, data_time: 0.019, memory: 1709, loss: 0.4196\n",
"2021-10-21 02:48:57,661 - mmcls - INFO - Epoch [1][60/201]\tlr: 5.000e-03, eta: 0:01:46, time: 0.266, data_time: 0.019, memory: 1709, loss: 0.4994\n",
"2021-10-21 02:49:00,341 - mmcls - INFO - Epoch [1][70/201]\tlr: 5.000e-03, eta: 0:01:41, time: 0.268, data_time: 0.018, memory: 1709, loss: 0.4372\n",
"2021-10-21 02:49:03,035 - mmcls - INFO - Epoch [1][80/201]\tlr: 5.000e-03, eta: 0:01:37, time: 0.270, data_time: 0.019, memory: 1709, loss: 0.3179\n",
"2021-10-21 02:49:05,731 - mmcls - INFO - Epoch [1][90/201]\tlr: 5.000e-03, eta: 0:01:32, time: 0.269, data_time: 0.020, memory: 1709, loss: 0.3175\n",
"2021-10-21 02:49:08,404 - mmcls - INFO - Epoch [1][100/201]\tlr: 5.000e-03, eta: 0:01:29, time: 0.268, data_time: 0.019, memory: 1709, loss: 0.3412\n",
"2021-10-21 02:49:11,106 - mmcls - INFO - Epoch [1][110/201]\tlr: 5.000e-03, eta: 0:01:25, time: 0.270, data_time: 0.016, memory: 1709, loss: 0.2985\n",
"2021-10-21 02:49:13,776 - mmcls - INFO - Epoch [1][120/201]\tlr: 5.000e-03, eta: 0:01:21, time: 0.267, data_time: 0.018, memory: 1709, loss: 0.2778\n",
"2021-10-21 02:49:16,478 - mmcls - INFO - Epoch [1][130/201]\tlr: 5.000e-03, eta: 0:01:18, time: 0.270, data_time: 0.021, memory: 1709, loss: 0.2229\n",
"2021-10-21 02:49:19,130 - mmcls - INFO - Epoch [1][140/201]\tlr: 5.000e-03, eta: 0:01:15, time: 0.266, data_time: 0.018, memory: 1709, loss: 0.2318\n",
"2021-10-21 02:49:21,812 - mmcls - INFO - Epoch [1][150/201]\tlr: 5.000e-03, eta: 0:01:12, time: 0.268, data_time: 0.019, memory: 1709, loss: 0.2333\n",
"2021-10-21 02:49:24,514 - mmcls - INFO - Epoch [1][160/201]\tlr: 5.000e-03, eta: 0:01:08, time: 0.270, data_time: 0.017, memory: 1709, loss: 0.2783\n",
"2021-10-21 02:49:27,184 - mmcls - INFO - Epoch [1][170/201]\tlr: 5.000e-03, eta: 0:01:05, time: 0.267, data_time: 0.017, memory: 1709, loss: 0.2132\n",
"2021-10-21 02:49:29,875 - mmcls - INFO - Epoch [1][180/201]\tlr: 5.000e-03, eta: 0:01:02, time: 0.269, data_time: 0.021, memory: 1709, loss: 0.2096\n",
"2021-10-21 02:49:32,546 - mmcls - INFO - Epoch [1][190/201]\tlr: 5.000e-03, eta: 0:00:59, time: 0.267, data_time: 0.019, memory: 1709, loss: 0.1729\n",
"2021-10-21 02:49:35,200 - mmcls - INFO - Epoch [1][200/201]\tlr: 5.000e-03, eta: 0:00:56, time: 0.265, data_time: 0.017, memory: 1709, loss: 0.1969\n",
"2021-10-21 02:49:35,247 - mmcls - INFO - Saving checkpoint at 1 epochs\n",
"[ ] 0/1601, elapsed: 0s, ETA:[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)\n",
"[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)\n",
"[>>] 1601/1601, 173.2 task/s, elapsed: 9s, ETA: 0s2021-10-21 02:49:44,587 - mmcls - INFO - Epoch(val) [1][51]\taccuracy_top-1: 95.6277\n",
"2021-10-21 02:49:49,625 - mmcls - INFO - Epoch [2][10/201]\tlr: 5.000e-04, eta: 0:00:55, time: 0.488, data_time: 0.237, memory: 1709, loss: 0.1764\n",
"2021-10-21 02:49:52,305 - mmcls - INFO - Epoch [2][20/201]\tlr: 5.000e-04, eta: 0:00:52, time: 0.270, data_time: 0.018, memory: 1709, loss: 0.1514\n",
"2021-10-21 02:49:55,060 - mmcls - INFO - Epoch [2][30/201]\tlr: 5.000e-04, eta: 0:00:49, time: 0.275, data_time: 0.016, memory: 1709, loss: 0.1395\n",
"2021-10-21 02:49:57,696 - mmcls - INFO - Epoch [2][40/201]\tlr: 5.000e-04, eta: 0:00:46, time: 0.262, data_time: 0.016, memory: 1709, loss: 0.1508\n",
"2021-10-21 02:50:00,430 - mmcls - INFO - Epoch [2][50/201]\tlr: 5.000e-04, eta: 0:00:43, time: 0.273, data_time: 0.018, memory: 1709, loss: 0.1771\n",
"2021-10-21 02:50:03,099 - mmcls - INFO - Epoch [2][60/201]\tlr: 5.000e-04, eta: 0:00:40, time: 0.268, data_time: 0.020, memory: 1709, loss: 0.1438\n",
"2021-10-21 02:50:05,745 - mmcls - INFO - Epoch [2][70/201]\tlr: 5.000e-04, eta: 0:00:37, time: 0.264, data_time: 0.018, memory: 1709, loss: 0.1321\n",
"2021-10-21 02:50:08,385 - mmcls - INFO - Epoch [2][80/201]\tlr: 5.000e-04, eta: 0:00:34, time: 0.264, data_time: 0.020, memory: 1709, loss: 0.1629\n",
"2021-10-21 02:50:11,025 - mmcls - INFO - Epoch [2][90/201]\tlr: 5.000e-04, eta: 0:00:31, time: 0.264, data_time: 0.019, memory: 1709, loss: 0.1574\n",
"2021-10-21 02:50:13,685 - mmcls - INFO - Epoch [2][100/201]\tlr: 5.000e-04, eta: 0:00:28, time: 0.266, data_time: 0.019, memory: 1709, loss: 0.1220\n",
"2021-10-21 02:50:16,329 - mmcls - INFO - Epoch [2][110/201]\tlr: 5.000e-04, eta: 0:00:25, time: 0.264, data_time: 0.021, memory: 1709, loss: 0.2550\n",
"2021-10-21 02:50:19,007 - mmcls - INFO - Epoch [2][120/201]\tlr: 5.000e-04, eta: 0:00:22, time: 0.268, data_time: 0.020, memory: 1709, loss: 0.1528\n",
"2021-10-21 02:50:21,750 - mmcls - INFO - Epoch [2][130/201]\tlr: 5.000e-04, eta: 0:00:20, time: 0.275, data_time: 0.021, memory: 1709, loss: 0.1223\n",
"2021-10-21 02:50:24,392 - mmcls - INFO - Epoch [2][140/201]\tlr: 5.000e-04, eta: 0:00:17, time: 0.264, data_time: 0.017, memory: 1709, loss: 0.1734\n",
"2021-10-21 02:50:27,049 - mmcls - INFO - Epoch [2][150/201]\tlr: 5.000e-04, eta: 0:00:14, time: 0.265, data_time: 0.020, memory: 1709, loss: 0.1527\n",
"2021-10-21 02:50:29,681 - mmcls - INFO - Epoch [2][160/201]\tlr: 5.000e-04, eta: 0:00:11, time: 0.265, data_time: 0.019, memory: 1709, loss: 0.1910\n",
"2021-10-21 02:50:32,318 - mmcls - INFO - Epoch [2][170/201]\tlr: 5.000e-04, eta: 0:00:08, time: 0.262, data_time: 0.017, memory: 1709, loss: 0.1922\n",
"2021-10-21 02:50:34,955 - mmcls - INFO - Epoch [2][180/201]\tlr: 5.000e-04, eta: 0:00:05, time: 0.264, data_time: 0.021, memory: 1709, loss: 0.1760\n",
"2021-10-21 02:50:37,681 - mmcls - INFO - Epoch [2][190/201]\tlr: 5.000e-04, eta: 0:00:03, time: 0.273, data_time: 0.019, memory: 1709, loss: 0.1739\n",
"2021-10-21 02:50:40,408 - mmcls - INFO - Epoch [2][200/201]\tlr: 5.000e-04, eta: 0:00:00, time: 0.272, data_time: 0.018, memory: 1709, loss: 0.1654\n",
"2021-10-21 02:50:40,443 - mmcls - INFO - Saving checkpoint at 2 epochs\n",
"[>>] 1601/1601, 170.9 task/s, elapsed: 9s, ETA: 0s2021-10-21 02:50:49,905 - mmcls - INFO - Epoch(val) [2][51]\taccuracy_top-1: 97.5016\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m_ZSkwB5Rflb"
},
"source": [
"### Test a model\n",
"\n",
"We use `tools/test.py` to test a model:\n",
"\n",
"```\n",
"python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments]\n",
"```\n",
"\n",
"Here are some optional arguments:\n",
"\n",
"- `--metrics`: The evaluation metrics. The available choices are defined in the dataset class. Usually, you can specify \"accuracy\" to metric a single-label classification task.\n",
"- `--metric-options`: The extra options passed to metrics. For example, by specifying \"topk=1\", the \"accuracy\" metric will calculate top-1 accuracy.\n",
"\n",
"More details are in the help docs of `tools/test.py`.\n",
"\n",
"Here we still use the `MobileNetV2` model we fine-tuned."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Zd4EM00QRtyc",
"outputId": "8788264f-83df-4419-9748-822c20538aa7"
},
"source": [
"!python tools/test.py configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py work_dirs/mobilenet_v2_1x_cats_dogs/latest.pth --metrics accuracy --metric-options topk=1"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:28: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/local/lib/python3.7/dist-packages/yaml/constructor.py:126: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n",
" if not isinstance(key, collections.Hashable):\n",
"Use load_from_local loader\n",
"[>>] 2023/2023, 168.4 task/s, elapsed: 12s, ETA: 0s\n",
"accuracy : 97.38\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IwThQkjaRwF7"
},
"source": [
"### Inference with a model\n",
"\n",
"Sometimes we want to save the inference results on a dataset, just use the command below.\n",
"\n",
"```shell\n",
"python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}]\n",
"```\n",
"\n",
"Arguments\n",
"\n",
"- `--out`: The output filename. If not specified, the inference results won't be saved. It supports json, pkl and yml.\n",
"- `--out-items`: What items will be saved. You can choose some of \"class_scores\", \"pred_score\", \"pred_label\" and \"pred_class\", or use \"all\" to select all of them.\n",
"\n",
"These items mean:\n",
"- `class_scores`: The score of every class for each sample.\n",
"- `pred_score`: The score of predict class for each sample.\n",
"- `pred_label`: The label of predict class for each sample. It will read the label string of each class from the model, if the label strings are not saved, it will use ImageNet labels.\n",
"- `pred_class`: The id of predict class for each sample. It's a group of integers. \n",
"- `all`: Save all items above.\n",
"- `none`: Don't save any items above. Because the output file will save the metric besides inference results. If you want to save only metrics, you can use this option to reduce the output file size.\n"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6GVKloPHR0Fn",
"outputId": "4f4cd414-1be6-4e17-985f-6449b8a3d9e8"
},
"source": [
"!python tools/test.py configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py work_dirs/mobilenet_v2_1x_cats_dogs/latest.pth --out results.json --out-items all"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:28: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/local/lib/python3.7/dist-packages/yaml/constructor.py:126: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n",
" if not isinstance(key, collections.Hashable):\n",
"Use load_from_local loader\n",
"[>>] 2023/2023, 170.6 task/s, elapsed: 12s, ETA: 0s\n",
"dumping results to results.json\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G0NJI1s6e3FD"
},
"source": [
"All inference results are saved in the output json file, and you can read it."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 370
},
"id": "HJdJeLUafFhX",
"outputId": "7614d546-7c2f-4bfd-ce63-4c2a5228620f"
},
"source": [
"import json\n",
"\n",
"with open(\"./results.json\", 'r') as f:\n",
" results = json.load(f)\n",
"\n",
"# Show the inference result of the first image.\n",
"print('class_scores:', results['class_scores'][0])\n",
"print('pred_class:', results['pred_class'][0])\n",
"print('pred_label:', results['pred_label'][0])\n",
"print('pred_score:', results['pred_score'][0])\n",
"Image.open('data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"class_scores: [1.0, 5.184615757547473e-13]\n",
"pred_class: cats\n",
"pred_label: 0\n",
"pred_score: 1.0\n"
]
},
{
"output_type": "execute_result",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAASwAAAEYCAIAAABp9FyZAAEAAElEQVR4nJT96ZMkWXIfCKrqe88uv+KOvCuzju7qrmo00A00QZAcoZAznPmwIkP+p/NhZWVnZTAHSBALAmig+u6ursrKMzIyDj/teofqflB3S89qckTWpCTKM8Lc3ezZ0+unP1XF737+KQAgGEQDQCAoIgD04MGDul7f3Nx0fTsalbPZJCa/XC5dOQ4hGGMAoGka55y1tu/7LMu89865zWbz+PHj+Xy+2WxGo1G3rokIAJgZEa3N9L3C4Jwzxngfu64TkbIsq6pqu40xJqXU922MkYiMMSKS56Uh61zubGmtM5QjGkTMDZRl2fd9CKGoSkTx3gthSgkRBCGllFIKnAAAEbGXPM/1q7vOM3OM0ZBLKcWYUkrM/PHH33n58mWe5yklnlW3t7e5M6fHJ7PRaLOY3769zAxNsvz06PDJwwe5ofnV1eNHD2+uru7du3O7aeu6Nha97/7n//l/fvHimQ8dIi4Wi/V6vVk34/H0+voa0Xzy8XdfvXrVtt39+/c//vjjt2/fPn32TZ7nDBKZY4x37997+/btn/zpj9++fftX//E//vCHP1yv1yWS98FaK2i+//3vv3nztutDluWcQMjEwNfXN/cfPHrz5o0IVlVlQvfpp5+mlH72s5+NZ9MY48nZ6Xy5XK1Wn37/e3fu3Hnx6uVvf/vbPoQ7d+6cn5/317ezg8lqteq6phqVV1eXm83q0aNHTdOklFKSlASBrM2yLHM2ny9eiAgB5nleFBWhjZG7kMAUybibTXdxs1yHaMspOtf7aAiMMdZaRERhRHTOOUNd13FKIfQxRokJUay1xpibxTzGKCLWWiJCkizL8jy/vb1JKT169Kiu133fHx8f397etm07nU67rqubdex9nueTycQY5JjeXF4cTmfHx8fG4Pz2drVaTSaTO3furNb1YrHw3mdZllKq6xoRx+NxWZbM7JyLMW42G+dcVVUppc1mlee5916/6/T0tGmaZ8+e3bt3DxHLsgSAvu+dcyIyn8/J8PHxIVFs13Mi/z/+9/9qOrWzcX7n7lEKTQi9QLKICAAg8K2j7/uUEiKKSAih67rEIYRQ94uiKEII0+n06urq5OQkhAAAMUa9iJQSAFRVZa1dr9ciwszGGN36RDal5L1HIGYGABHRaxARXe5vXwrAVoaBmVl2B4CIiE/RGOO9DyGQNYMQhhCICA3q240xiGj0f4jW2rqui6Lquq4sS9/HoiiOjo5fvXo1mUx+/OMf379//2/+5m8ePnz4y9fPfvSjH/3t3/y1M/aPP//8F/+0qKpqeXuTCSBi0zTZeDSbzbquy/P86uqqmB5WVeUyU9dS1/XV1VVMvqqqs7Oz5XK5XC6rajydTlerze3tbUppPB4bY4qi0KUmIu9749xqtdKlOz8/J6I/+7M/CyE450Lb6foj4dOnT7OsCCGEEIu8stYSwsHBASBba5m5qoo337yq6/WdO3em0/Hd+3frun79+uXZ2Vni8PTpVwcH0zx3JydHTdN88+zrsspzgL7vjTHT6bQo87atQ+ibpgkhWGuzLBPBGFIIqWkakM7lGcfEzMzsvUdIPiYfOJ9UQGSMIWsNCxElkARCAqr4AACFRcR7j8LMzCnF6FNKFinLbFEUWZYdnZ40TdN1XQghhOBD1/e9bjYVifV6HUIYjUZElBeubVuBlGWZI6OLEH3o+x4AyrIUkfl8EUI4Pj4moouLCxa01qrCreu6russy0RkPB7Xda17O8syIkopxRiZue97773esu7YsixVjJ1zRDRsY2stS4wxWgPWWmt1JwMiJh9i4pQSS6Rhl8v7h/ceEVVy9r9YHwYijkYjABiNRnoPw14XEbWBjx8/zvNcf6NXP0ia3ttwV2ru9CnuiyUi6tbUW8L3D/3GuDtCivsHMydhERARRgAAIkLEBNLH4FNcbtYhJUHMigKtuV0uNm2zbuou+Murq+9//rnNsuvb2zzPz8/PP//88+l0aoxxzumNq6LRVTo6Ouq67ujoaL1e3717dzwep5SyLBuNRrPZjJlXq9VqtXLO6T7I81y1qd71crns+17dAedc13UoYoy5vLw8OTl5/vz5D37wg/Ozs6PDwzzPmVkfZmbp5cuXKSUU0MtAxLars8xJ4uOTQ/VN7t69+0//9E+67Z49e6ZKJ8Z4cHDw4ZMnP//5z5tNrdvx0f0Hq/miqde31zd92xhCjoEADdJ6ueqaNoVokMo8K/LMEkYfurZGREEQER9DCMHHwMyM0HVt27ad72OMMaXOt13X6T3GGL33IXh1Urz3XddlWWYt7R7rdqsw87Nnz54/f/7ixYu3b99u6lVKyVqbZS7LsizLdFPpNiADRJQ4OOemo/F0Oi2KAoD188fjMQC0bRtDyLJMDd1yuby+vo4xZllmjCGi8Xh8dnZ2584dAPDe13UdY9RtrIrDWqsbTLdTCAERp9Op6ik9Qfe8un4JpO9DCMHaLM9LvVRh8N7HGEUQwZCezcxpdwwibq0djUZlWaoa0C/QbeScCyGMx2MRybJMnQoR6ft+vV5fXV0h4ocffnh6ejoajVSRqIx1Xafr+L6PalWwU0qDJRzEbBBLRCQilVh98U4g3700QoiI+pCstcZYEBSGFDlF7r0XgN57Y20fvMszm7nRZFxUZe/94dFRluf/5e//7uPvfPKTP/9n63pzfHz8/Pnzf/2v//WDBw9+9atf6dqpsiyKIqWkznnTNKenp9PpVNVn0zSbzeby8lK90BDCb37zG2vt2dlZ3/d1XRdFoU+rbdurq6vlcpmEp9Pp0dGRcw4AiqLo2857//vf/q5erfURTMbjsiy7rh3WIYTgMiPCSGJAdspbPnz85OT0qOub2XTcdvXzF98kDin0PnQI7H2XGTo7OyGU2/n1Yn5jCI4OZ02z6ft+uVyq8VksFl3XuMxYR71v62a92azUThZFUZalc86nkFIMKaaUYkqIaDNXlCUQCoIxJitcVhY2z4xzLs8QkZm93+p0IrKWdD/oHTnnNMbRM+/cuXN0dKSO1aCRdds4ZxCxKAqXGR+6EEJd11VejMtK/UlriQCNMS6zR7ODvu/bth6NRqPRqGmaul7nuauqSpcxxjgej+/fv3/37t3pdHp9fb1er9u2TSk553RzWmuLotDtp0ZPhVDV7mCo9F7UzAij97Hvg6AGVsAhMbP3UZIYpNwVpHel4qcirnKoFl9XRKVFhTDP89VqhYjz+Xwymcznc0TUB6OGmJnrum6aRgVYTZBzTkVCpUtvYJCbQa50lQch3JdD3h3f0hr6sIczhRCAEI2xmTHGGAN7QktEYMgVeQKZHh4kEJtnPsWsKNQe5lXJCJfXV6t644ocDLVtu1gs1Dm/vr6u63qz2YiIOkjOOdU7RHRwcPDJJ58Q0ZMnTz755JOyLL/88stvvvkGAL773e/OZjO9AFVn6l8YY5yzIty2zXg8un//3tnZ6cHBjAhR+PDo4JunX2eZ+9//978cFcXLZ8/KLJtOp23b6
uecHh2v18uUUgpxvV7H5HNrEofet5PJ+N6dO9Px+M2b13/ywz/65puvr64u//zPf3JyeHB5eVGV+Xqz/Onf/5c//qPPLSGBfP797/3iZ19Yg9aStSTAXdfU9TrGOBqNzs/PsyyLMa5Wq+Vy7n3nMluWeVnmiCiEYBAIhRhIyFrnXDmqqqoqRuVkNj04OJhOp7OD6eHxgbWGCNWl6vs+hF631mJ5u9lsvPcAbK3VHQIAR0dHBwcHk+moKDPdIQCit8/MbdsmDjHGuq5VIZZlqVYhRh99UIMxLisAcM6URWGtTSl436FAmRcff/zxeDxWpVMUxWQySSldXl7e3Nzoe3UXqQEnIjUY+vtB8FRpMrPamGG3ExGA+rEphMBJhJETxMgxcIoCQIho33dHh+BQ+r4vikz3SkoJgMkYRHRZpjAMEZ2cnCyXyyzLxuNxjHEw6HVdLxYL9bJ0rVXwjDEpSQhBRPKsGELBwdskIpb0LXdUzwkhgKAxKaWESAxRBEXEUhRJzFFEEggyi0hiF
"text/plain": [
"<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x280 at 0x7EFEADA54C90>"
]
},
"metadata": {},
"execution_count": 15
}
]
},
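{
"cell_type": "markdown",
"metadata": {},
"source": [
"The saved file covers the whole test set, not only the first sample. For example, here is one (illustrative) way to tally the predicted classes from the `results` dict we just loaded:\n",
"\n",
"```python\n",
"# Count how many test images were predicted as each class.\n",
"from collections import Counter\n",
"\n",
"print(Counter(results['pred_class']))\n",
"```"
]
},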
{
"cell_type": "markdown",
"metadata": {
"id": "1bEUwwzcVG8o"
},
"source": [
"You can also use the visualization API provided by MMClassification to show the inference result."
]
},
{
"cell_type": "code",
"metadata": {
"id": "BcSNyvAWRx20",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 304
},
"outputId": "0d68077f-2ec8-4f3d-8aaa-18d4021ca77b"
},
"source": [
"from mmcls.core.visualization import imshow_infos\n",
"\n",
"filepath = 'data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg'\n",
"\n",
"result = {\n",
" 'pred_class': results['pred_class'][0],\n",
" 'pred_label': results['pred_label'][0],\n",
" 'pred_score': results['pred_score'][0],\n",
"}\n",
"\n",
"img = imshow_infos(filepath, result)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAATMAAAEfCAYAAAAtNiETAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAFiQAABYkBbWid+gAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9d7Rk2VXm+TvmmjDPv8yXmZWmSt4hUDWSCiEE3QgBYnCSMMPMYgYaqUFIgxHQTTdDQ4uhEaIb0EA3vqGBxkiAJOQwLYMTKpUk5MtlVnr/TNhrjps/zr03IrMk/pq1ZvWsilqxXuWLeBFx7z3n29/+9rd3iBACj90euz12e+z2P/pN/n/9AR67PXZ77PbY7f+N22Ng9tjtsdtjt/9f3PTtv3jK5zw1AAgkIBBCA4LgBSDY3t6mrktGoxHzYkKapgyHfXywjCcTnIMkzwkhsLKywuXLl9na2sJai/ceIQQAw+GQ2WzG4cOHqeuauq6ZTCaE2gKglEJrTZIkSCnxHuq6BgRaaaTUOOeo65oQAmmakmUJxhiUlnhvqaoKay1CCJRS8YB1ipKaJMlIkz5aJyiVIESCICAD9PKUuq4xxpDlGUJKjKkJImCa1xNKIIQghEAIAaRASwVeIIMgTVPKsiRJsviZpMZah1KKtbU19vb2SZKE5zznHqy1fPjDHyYf9Lk22uf4Xac498gjiOD4J8+6m73rV9m/cYPx/h49pTh2aJudQ9tsrqwggUwr8A5nHVmvjweEEtR1yZOe9ESqqmQ02iNJE1ZXV9nd3eXChUvkWU6a5sznJVmax+scYHV1lVOnThFC4Oy5sxRlCQICUFQVw5UVhBTc+bi7OHT4MBcvXWJ/f5+qqqiLEhVACIEPgTTNyLIeztMcf0Ka9ZBSU9cWhMA6z2w6J0sznKnItOSJj388ZVlyMB7RHwzY39+ntoY0y5gVczY3N9k5eoTdvT12d3cZjUYIIbhj5ygJgkQLQnDoROO9Y3f3BlVVkmUZ3nu0jmvI+4A1Hmsdcd17pHQ4bxEBtNYoleCDpDYOJxRpb5U6SHYnBdfHEwoLMuvhUBjrUBKklEgpIQQIHojMIQSP9x7nLM4avPckUpGmKXmekmUZOkuZz+dUVYUxBmOrbp1LKanruK43NtaYTqd479ne3qKsirjPHIQQcD6+fqo0Wmu8NVRVybyYs7OzQz/LmUwmWFPR7/cRQlAUJYG4tqWUOOcYjUZMJnGvb2xscOjQIWazGd57QghYa5FSopTCWktRzBBCUNc1zjlWV1fZ3Nykqipu3rzJYDAgTVPW1tYQQjCfz7u9NJlMCDjW1gZoGXBmgtaOf/qC5zIcajbWemxvDLGuwNqaEBzP+6L/XXxGMGvBhkZKazW1VlprASXLMowtCSFQVRUBh/ceYzy9ocZay2AwAGAwGDCbzQCw1hJCQClFCIHpdMr6+jpHjx7loYcewriAcw7nXFwMzWew1mGMIdEpzjm8p7u4IQS899R13Z2U9jiEWIDO8jE+6t485qzFWolxFuMs0iqkjP8OxPdBClSI4OjjCyKlbDYw1M6Ck0zmMwb9eAxZluFFzWg8RiaayXxGr9djd3+Ppz/9GTz48MPsjvYRQrC1tUWepuzv3iBNE5IkodfrMRvH98zynOAjiK6urlIWM4b9AXt7e2yvriCkYjIdo5Si1+uRZSnT2Zj5fN6dzzzLkFKidUKWBYJfnLP5fN4tMGstSilqYwgEpBBMp1PW1tcoy5LhcMhdd96JkpLxaMRBVSFCQIkI6tPJhKKoGA5XcdYhUEgEzroGVBJ6WYIUgro25L2cYjLmwoUL3HHHHXjnuXHjBkqrLsDt7OxgnePc2XNkWUY5LzDGsLmxwXg8IhUCLQS9Xk6aJHgfmve0zEyFlArVG5AkGpkoaizeFpRljZCerK8IPuBDwBuD8gCKEMAHT1GWVEFQmQpjLcZBqCtsEHjnydIU71231mRzzoOLgbDXy6mNx7t21XkgAkNVV1y6eoWiiMckpSRJ47EnSYLWGojXynvfrWlrXQeg1tZkWUYis2Zv2C7wV1XNSn+ACIGimGFt3QBpznw+ZTQ6oDaOjY0NhsMhdV0jpWQ4HLK+vs7m5ma316qqIssy0jRtANp1RKQFsnZfGGMQQrCyssJ0Ou3WWgwqETTbfet9xBQSQaJTsqw5RhOfU9c1wTvirl0kl48Cs45pBEkI7ckSECJTs9aS5ymDwQDn4wE55xGyjWIRXLTWGGMYDocdc7LWkiQJVVVRVRWTyYT5fM7GxgZ33XUXo9GIgxu7lGXZMbkInhFMlVJLAOeBBeMC8N7FxdP8rgPmpWNrwau98IvXlBA8QQr4TGDXnvw06S4QgG8ugrOeIMEaixSS2pjI6Jwlz3N0mqASjXEW6x3D1RUSnfCp++/n+S94AXd//j/h3e99D8N+j93dPe7+vGdy/eo6ly9epJiM8dailUJ4T5okzQI2IAR1VTPcPhQXlHW4YOOGq+bs7u4SgmM8HuO9Y29vj8FgwPr6OrPZnKKYo1SCThOCD5RlyWQyYXV1lcFgQJbnZFnGdD6jKAq0lJGBlRXXr1xla2OTwUq8xkmSMOwPmE+nuIb9xKDlESIQL51HyICpasrK0O9LNjc3SdKUy1cug3NoLdnb36XXy7HOUJuKlJQABO9YGQxwwTMaHTCfz5gXc3pZRp6mjPf3cfGJESBwVFWJdYYkUdS1oywLvI8MKc8H6ESThx4gMK7EOYN1EWwVEiEUKkki+KGoHAgEWmuyLCM4gZMaEcBKByICZ7uZU61RUoLQIEAqgXISrxRSRuBVSnXrcmdnp2NDrgHAdl+2azxJIgjkeR6vh6m7bGQwGMQgphOstVRVPF6lFGmWsLa2xmw2IwTHYDAg0boLYFmWopP4eYwxOOdYWVlhMBiwurpKmqacPn2a+XyOcw6tNf1+vwOvNE2RMu8ISZtZtXt/OBwyGo3iZ67rbv+FsCAxgUBdW4IXpAONUqIjNMGBqS1SBbRSKLWAsM8KZu0mbQmNQHVgJkTWRYqqquJibVBWiITRZMLa2hoHBwcMh8PuZ1VV5HneHZz3nqqqmM/n3QU4IFL09iS07ADiBRTIJZa1YEQtO1u+LYPQAvA8XviGIvvmbzwE1x0HYgGEQUCQoqHeEq2S+JiMaZRCxtcSIKQEFdBJgnOOweoKdW1RaYINnizNSPM8nngZmeneaJ+DyRihFVIpSlMhxiOKIqYMBwcH1PMpwVisszhjqU3NIM+x1jGbxTRjOByysbHBZDYn7/c5tLPNlSuXuH79OpPJGKngcY+7i8lk0kXbeM4Caaqa4xbNNYTaVKwmqxw6tE2/32f/YJ8bN13cLP0e0/mU2qacPv0wOzs7zCeTZiP1qeZzqqpGa00vyzHGUhUlQch4PlSCQKBEwJoKpQRbmxuYquTalUv084wsUVy6fIGTJ09y8uRxTj9yhqKYk/dTLl44x9raGiePH+Pc+fNsbazhvefypYusrwzRSLwD6xxFUVDXJUoJhmtrW
Gu4efNmTOFqh7WB4XCFNEsQEoQJGFeAirJKQBCEQEiB1BqhEqRIkEFiE4dNcxIvsELhEFjrqOZzvBNxc1pLcCayKqnw3nNwcAAhIAgkOjLvJNEICUpJ1jc345rGdYE9rscYjCOLcZSlBwHW1bh5fA9rLYO81wS8gLU1tjYIH8gSTapjoE9TjRQJWim8NZiqRCLo93sc3rmD8XjM/v4+3nvW19cZDodYa9nf32dvb488z0mSBKVUl2VFopMjZUxr2wzMOdfgRCBvgqP3nrKMaX/7Oi0TNQ3LE0BlIEt1lLlczOysBR1CvCZiQWY+I5jdCmyRAbV5ZwSklCRJlqiljxtZRK2oqirSNKaD29vbjMdj0jS9JW1pqetsNuPg4KBDa2NMpJhAv99HKdWkrzHNjNpO+/lu/bxRW/OPSjOXj8cYA0GgVIwCQsSUleZYlaADxkU0XPx/e8KFXkRSKSVSLyJMaWsEkOiMXGlKU0NVUzvLaDJGa01d1VjnEErxkY99lBvXrjMrCuZ1ybHjx
"text/plain": [
"<Figure size 300.01x280.01 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
]
}
]
}