mmclassification/docs/zh_CN/tutorials/MMClassification_tools_cn.ipynb

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "MMClassification_tools_cn.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "XjQxmm04iTx4",
"tags": []
},
"source": [
"<a href=\"https://colab.research.google.com/github/open-mmlab/mmclassification/blob/master/docs_zh-CN/tutorials/MMClassification_tools_cn.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4z0JDgisPRr-"
},
"source": [
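"# MMClassification Command-Line Tools Tutorial\n",
"\n",
"This tutorial covers the following topics:\n",
"\n",
"* How to install MMClassification\n",
"* Preparing the data\n",
"* Preparing the config file\n",
"* Training and testing models with shell commands"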
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "inm7Ciy5PXrU"
},
"source": [
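"## Install MMClassification\n",
"\n",
"Before using MMClassification, we need to set up the environment as follows:\n",
"\n",
"- Install Python, CUDA, a C/C++ compiler, and git\n",
"- Install PyTorch (CUDA version)\n",
"- Install mmcv\n",
"- Clone the mmcls GitHub repository and install it\n",
"\n",
"Because we are running this tutorial on Google Colab, which already provides the basic setup, we can skip the first two steps."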
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TDOxbcDvPbNk"
},
"source": [
"### Check the environment"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c6MbAw10iUJI",
"outputId": "5f95ad09-7b96-4d27-dfa8-17f31caba50d"
},
"source": [
"%cd /content"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4IyFL3MaiYRu",
"outputId": "b0ab6848-12ea-49a1-98ec-691e2c9814e1"
},
"source": [
"!pwd"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DMw7QwvpiiUO",
"outputId": "d699b9d2-22e5-431c-83d8-9317a694cb0e"
},
"source": [
"# Check the nvcc version\n",
"!nvcc -V"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"nvcc: NVIDIA (R) Cuda compiler driver\n",
"Copyright (c) 2005-2020 NVIDIA Corporation\n",
"Built on Mon_Oct_12_20:09:46_PDT_2020\n",
"Cuda compilation tools, release 11.1, V11.1.105\n",
"Build cuda_11.1.TC455_06.29190527_0\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4VIBU7Fain4D",
"outputId": "7eb1d91f-86c7-43cf-d335-3d37ae014060"
},
"source": [
"# Check the GCC version\n",
"!gcc --version"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0\n",
"Copyright (C) 2017 Free Software Foundation, Inc.\n",
"This is free software; see the source for copying conditions. There is NO\n",
"warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "24lDLCqFisZ9",
"outputId": "3c553c42-e7ac-4c6a-863e-13ad158bac22"
},
"source": [
"# Check the PyTorch installation\n",
"import torch, torchvision\n",
"print(torch.__version__)\n",
"print(torch.cuda.is_available())"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1.9.0+cu111\n",
"True\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R2aZNLUwizBs"
},
"source": [
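"### Install MMCV\n",
"\n",
"MMCV is the foundational library of the OpenMMLab codebases. Prebuilt wheel packages for Linux are available and can be downloaded and installed directly.\n",
"\n",
"Pay attention to the PyTorch and CUDA versions: the package must match them for the installation to work.\n",
"\n",
"In the previous steps we printed the CUDA and PyTorch versions of this environment, 11.1 and 1.9.0 respectively, so we need to pick the matching MMCV build.\n",
"\n",
"Alternatively, you can install the full version, MMCV-full, which includes all features and a rich set of ready-to-use CUDA operators. Note that the full version may take longer to compile."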
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nla40LrLi7oo",
"outputId": "475dcd11-0b58-45d3-ad81-a3b7772d3132"
},
"source": [
"# Install mmcv\n",
"!pip install mmcv -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html\n",
"# !pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in links: https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html\n",
"Collecting mmcv\n",
" Downloading mmcv-1.3.15.tar.gz (352 kB)\n",
"\u001b[K |████████████████████████████████| 352 kB 5.2 MB/s \n",
"\u001b[?25hCollecting addict\n",
" Downloading addict-2.4.0-py3-none-any.whl (3.8 kB)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mmcv) (1.19.5)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from mmcv) (21.0)\n",
"Requirement already satisfied: Pillow in /usr/local/lib/python3.7/dist-packages (from mmcv) (7.1.2)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from mmcv) (3.13)\n",
"Collecting yapf\n",
" Downloading yapf-0.31.0-py2.py3-none-any.whl (185 kB)\n",
"\u001b[K |████████████████████████████████| 185 kB 45.4 MB/s \n",
"\u001b[?25hRequirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->mmcv) (2.4.7)\n",
"Building wheels for collected packages: mmcv\n",
" Building wheel for mmcv (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for mmcv: filename=mmcv-1.3.15-py2.py3-none-any.whl size=509835 sha256=0296cfd1e3e858ece30623050be2953941a442daf0575389030aa25603e5c205\n",
" Stored in directory: /root/.cache/pip/wheels/b2/f4/4e/8f6d2dd2bef6b7eb8c89aa0e5d61acd7bff60aaf3d4d4b29b0\n",
"Successfully built mmcv\n",
"Installing collected packages: yapf, addict, mmcv\n",
"Successfully installed addict-2.4.0 mmcv-1.3.15 yapf-0.31.0\n"
]
}
]
},
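{
"cell_type": "markdown",
"metadata": {},
"source": [
"The find-links URL in the install cell above encodes the CUDA and PyTorch versions (`cu111`, `torch1.9.0`). The cell below is a minimal sketch that builds this URL from the versions reported by PyTorch.\n",
"\n",
"The helper `mmcv_find_links` is hypothetical. It assumes the URL pattern shown above also holds for other versions, and that CPU-only builds use `cpu` in place of the CUDA tag. Check the MMCV installation docs before relying on it."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"\n",
"\n",
"def mmcv_find_links(torch_version, cuda_version):\n",
"    # Hypothetical helper: '1.9.0+cu111' -> '1.9.0' and '11.1' -> 'cu111'.\n",
"    # Assumes the URL pattern used in the install cell above.\n",
"    base = torch_version.split('+')[0]\n",
"    cu = 'cpu' if cuda_version is None else 'cu' + cuda_version.replace('.', '')\n",
"    return f'https://download.openmmlab.com/mmcv/dist/{cu}/torch{base}/index.html'\n",
"\n",
"\n",
"# torch.version.cuda is None for CPU-only PyTorch builds\n",
"print(mmcv_find_links(torch.__version__, torch.version.cuda))"
],
"execution_count": null,
"outputs": []
},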
{
"cell_type": "markdown",
"metadata": {
"id": "GDTUrYvXjlRb"
},
"source": [
"### Clone and install MMClassification\n",
"\n",
"Next, we clone the latest mmcls code from GitHub and install it."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Bwme6tWHjl5s",
"outputId": "07c0ca6f-8a10-4ac3-a6bc-afabff6aba51"
},
"source": [
"# Clone the mmcls repository\n",
"!git clone https://github.com/open-mmlab/mmclassification.git\n",
"%cd mmclassification/\n",
"\n",
"# Install MMClassification from source\n",
"!pip install -e . "
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into 'mmclassification'...\n",
"remote: Enumerating objects: 4152, done.\u001b[K\n",
"remote: Counting objects: 100% (994/994), done.\u001b[K\n",
"remote: Compressing objects: 100% (574/574), done.\u001b[K\n",
"remote: Total 4152 (delta 476), reused 764 (delta 403), pack-reused 3158\u001b[K\n",
"Receiving objects: 100% (4152/4152), 8.20 MiB | 20.90 MiB/s, done.\n",
"Resolving deltas: 100% (2525/2525), done.\n",
"/content/mmclassification\n",
"Obtaining file:///content/mmclassification\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from mmcls==0.16.0) (3.2.2)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mmcls==0.16.0) (1.19.5)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from mmcls==0.16.0) (21.0)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.16.0) (2.8.2)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.16.0) (1.3.2)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.16.0) (0.10.0)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.16.0) (2.4.7)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from cycler>=0.10->matplotlib->mmcls==0.16.0) (1.15.0)\n",
"Installing collected packages: mmcls\n",
" Running setup.py develop for mmcls\n",
"Successfully installed mmcls-0.16.0\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hFg_oSG4j3zB",
"outputId": "521a6a75-2dbb-4ff2-ab9f-4fbe785b4400"
},
"source": [
"# Check the MMClassification installation\n",
"import mmcls\n",
"print(mmcls.__version__)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0.16.0\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "arpM46CZOPtR"
},
"source": [
"## Prepare the data"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XHCHnKb_Qd3P",
"outputId": "4f6eaa3f-7b96-46e4-e75b-aae28c8ec42d"
},
"source": [
"# Download the classification dataset (cats and dogs)\n",
"!wget https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0 -O cats_dogs_dataset.zip\n",
"!mkdir data\n",
"!unzip -q cats_dogs_dataset.zip -d ./data/"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"--2021-10-21 02:53:27-- https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0\n",
"Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.18, 2620:100:6018:18::a27d:312\n",
"Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.18|:443... connected.\n",
"HTTP request sent, awaiting response... 301 Moved Permanently\n",
"Location: /s/raw/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip [following]\n",
"--2021-10-21 02:53:27-- https://www.dropbox.com/s/raw/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip\n",
"Reusing existing connection to www.dropbox.com:443.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://uc2e142222b11f678e96f89b0223.dl.dropboxusercontent.com/cd/0/inline/BYaBa5-WWfPf_jhSt9A5JMet_BB55MzZhB2D3RXLo53VGHSIYbVMnFTdccihcsD-kwc9FxBG8qOwqA50z7XD6-3yUXWK9iA0x4L8IV5wegYKilKuDauDKWiNAsbgZoEBg4nC1UWR5pLSiH3j0Dn68b2V/file# [following]\n",
"--2021-10-21 02:53:27-- https://uc2e142222b11f678e96f89b0223.dl.dropboxusercontent.com/cd/0/inline/BYaBa5-WWfPf_jhSt9A5JMet_BB55MzZhB2D3RXLo53VGHSIYbVMnFTdccihcsD-kwc9FxBG8qOwqA50z7XD6-3yUXWK9iA0x4L8IV5wegYKilKuDauDKWiNAsbgZoEBg4nC1UWR5pLSiH3j0Dn68b2V/file\n",
"Resolving uc2e142222b11f678e96f89b0223.dl.dropboxusercontent.com (uc2e142222b11f678e96f89b0223.dl.dropboxusercontent.com)... 162.125.3.15, 2620:100:6018:15::a27d:30f\n",
"Connecting to uc2e142222b11f678e96f89b0223.dl.dropboxusercontent.com (uc2e142222b11f678e96f89b0223.dl.dropboxusercontent.com)|162.125.3.15|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: /cd/0/inline2/BYZCXE2D0HPaLzwKVyTyfirCsVVcpsp0-D9eMfo9OFpQdWubKX08yUdUJz2CZ7dn6Vm4ZF22V2hf_4XTw41KZRj5m3Dm_1Z8gH9h_kawyi4bsKn5EYJ6b89lfhXhoxgBa0Fa8h7V39gPRaIfaWDiUE0tzYAM_aEVwT30FVU4uWisNXBvjz5-yS6_XYzJIiMZ1CUrFU8DwqBis4RwPmLA7rzdCsVV7a6VV0NiTcNgOKMwLP0lMYx4bYpDDmnOtF-m-GBVArV_2Xd0akIDKSXy4LY-4ovbTNI13uvUX5U3UcjpR0UPjGtBcgm3LR4Iqcvw5D6Wt14g3PCmBMIPgdTp_IN9RnLl9AK_mfl4v1kmJ_C-BfoEr43qQP-6uqBavD3Xhz8/file [following]\n",
"--2021-10-21 02:53:27-- https://uc2e142222b11f678e96f89b0223.dl.dropboxusercontent.com/cd/0/inline2/BYZCXE2D0HPaLzwKVyTyfirCsVVcpsp0-D9eMfo9OFpQdWubKX08yUdUJz2CZ7dn6Vm4ZF22V2hf_4XTw41KZRj5m3Dm_1Z8gH9h_kawyi4bsKn5EYJ6b89lfhXhoxgBa0Fa8h7V39gPRaIfaWDiUE0tzYAM_aEVwT30FVU4uWisNXBvjz5-yS6_XYzJIiMZ1CUrFU8DwqBis4RwPmLA7rzdCsVV7a6VV0NiTcNgOKMwLP0lMYx4bYpDDmnOtF-m-GBVArV_2Xd0akIDKSXy4LY-4ovbTNI13uvUX5U3UcjpR0UPjGtBcgm3LR4Iqcvw5D6Wt14g3PCmBMIPgdTp_IN9RnLl9AK_mfl4v1kmJ_C-BfoEr43qQP-6uqBavD3Xhz8/file\n",
"Reusing existing connection to uc2e142222b11f678e96f89b0223.dl.dropboxusercontent.com:443.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 228802825 (218M) [application/zip]\n",
"Saving to: cats_dogs_dataset.zip\n",
"\n",
"cats_dogs_dataset.z 100%[===================>] 218.20M 73.2MB/s in 3.0s \n",
"\n",
"2021-10-21 02:53:31 (73.2 MB/s) - cats_dogs_dataset.zip saved [228802825/228802825]\n",
"\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e4t2P2aTQokX"
},
"source": [
"After downloading and extracting, the files in the \"Cats and Dogs Dataset\" folder are organized as follows:\n",
"```\n",
"data/cats_dogs_dataset\n",
"├── classes.txt\n",
"├── test.txt\n",
"├── val.txt\n",
"├── training_set\n",
"│ ├── training_set\n",
"│ │ ├── cats\n",
"│ │ │ ├── cat.1.jpg\n",
"│ │ │ ├── cat.2.jpg\n",
"│ │ │ ├── ...\n",
"│ │ ├── dogs\n",
"│ │ │ ├── dog.2.jpg\n",
"│ │ │ ├── dog.3.jpg\n",
"│ │ │ ├── ...\n",
"├── val_set\n",
"│ ├── val_set\n",
"│ │ ├── cats\n",
"│ │ │ ├── cat.3.jpg\n",
"│ │ │ ├── cat.5.jpg\n",
"│ │ │ ├── ...\n",
"│ │ ├── dogs\n",
"│ │ │ ├── dog.1.jpg\n",
"│ │ │ ├── dog.6.jpg\n",
"│ │ │ ├── ...\n",
"├── test_set\n",
"│ ├── test_set\n",
"│ │ ├── cats\n",
"│ │ │ ├── cat.4001.jpg\n",
"│ │ │ ├── cat.4002.jpg\n",
"│ │ │ ├── ...\n",
"│ │ ├── dogs\n",
"│ │ │ ├── dog.4001.jpg\n",
"│ │ │ ├── dog.4002.jpg\n",
"│ │ │ ├── ...\n",
"```\n",
"\n",
"You can view this structure with the shell command `tree data/cats_dogs_dataset`."
]
},
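{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tree above has no `train.txt`, and the training set in the config used later is given only a `data_prefix` with no `ann_file`. In that case, an ImageNet-style dataset is expected to infer labels from the class subfolders.\n",
"\n",
"The cell below is a rough sketch of that idea, assuming that the sorted subfolder names map to label indices 0, 1, and so on. `scan_class_folders` is a hypothetical helper, not the MMClassification implementation."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import os\n",
"\n",
"IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')\n",
"\n",
"\n",
"def scan_class_folders(root):\n",
"    # Hypothetical helper: each subfolder of `root` is one class, and the\n",
"    # sorted folder names are mapped to label indices 0, 1, ...\n",
"    classes = sorted(d for d in os.listdir(root)\n",
"                     if os.path.isdir(os.path.join(root, d)))\n",
"    samples = []\n",
"    for label, name in enumerate(classes):\n",
"        for fname in sorted(os.listdir(os.path.join(root, name))):\n",
"            # Keep only image files, stored as paths relative to `root`\n",
"            if fname.lower().endswith(IMG_EXTENSIONS):\n",
"                samples.append((f'{name}/{fname}', label))\n",
"    return classes, samples\n",
"\n",
"\n",
"train_root = 'data/cats_dogs_dataset/training_set/training_set'\n",
"if os.path.isdir(train_root):\n",
"    classes, samples = scan_class_folders(train_root)\n",
"    print(classes, len(samples))"
],
"execution_count": null,
"outputs": []
},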
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 297
},
"id": "46tyHTdtQy_Z",
"outputId": "a6e89ddb-431e-4ba0-f1f5-3581a702fd2a"
},
"source": [
"# Open an image to visualize it\n",
"from PIL import Image\n",
"Image.open('data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg')"
],
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x280 at 0x7FC63A7A82D0>"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "My5Z6p7pQ3UC"
},
"source": [
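"### Support a new dataset\n",
"\n",
"MMClassification requires the images and the label files of a dataset to be placed under the same directory. There are two ways to support a custom dataset.\n",
"\n",
"The simplest way is to convert the dataset into an existing dataset format (such as ImageNet). The other way is to create a new dataset class. See the [documentation](https://github.com/open-mmlab/mmclassification/blob/master/docs_zh-CN/tutorials/new_dataset.md) for details.\n",
"\n",
"In this tutorial, to make things easier, we have already reorganized the cats and dogs classification dataset into the ImageNet format."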
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P335gKt9Q5U-"
},
"source": [
"Besides the image files, the dataset also contains the following files:\n",
"\n",
"1. A class list, with one class per line.\n",
" ```\n",
" cats\n",
" dogs\n",
" ```\n",
"2. Training/validation/test annotations.\n",
"Each line contains a file name and its corresponding label.\n",
" ```\n",
" ...\n",
" cats/cat.3769.jpg 0\n",
" cats/cat.882.jpg 0\n",
" ...\n",
" dogs/dog.3881.jpg 1\n",
" dogs/dog.3377.jpg 1\n",
" ...\n",
" ```"
]
},
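{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both file formats above are simple enough to parse by hand, for example to sanity-check the annotations. The cell below is a minimal sketch that uses the hypothetical helpers `load_classes` and `load_annotations`.\n",
"\n",
"It assumes a single space separates the file name from the integer label, as in the examples. It is not the parser MMClassification uses internally."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import os\n",
"\n",
"\n",
"def load_classes(path):\n",
"    # Hypothetical helper: one class name per line, e.g. 'cats', 'dogs'\n",
"    with open(path) as f:\n",
"        return [line.strip() for line in f if line.strip()]\n",
"\n",
"\n",
"def load_annotations(path):\n",
"    # Hypothetical helper: each line is '<relative file name> <integer label>'\n",
"    samples = []\n",
"    with open(path) as f:\n",
"        for line in f:\n",
"            line = line.strip()\n",
"            if not line:\n",
"                continue\n",
"            filename, label = line.rsplit(' ', 1)\n",
"            samples.append((filename, int(label)))\n",
"    return samples\n",
"\n",
"\n",
"root = 'data/cats_dogs_dataset'\n",
"if os.path.isdir(root):\n",
"    classes = load_classes(os.path.join(root, 'classes.txt'))\n",
"    val_samples = load_annotations(os.path.join(root, 'val.txt'))\n",
"    print(classes, val_samples[:2])"
],
"execution_count": null,
"outputs": []
},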
{
"cell_type": "markdown",
"metadata": {
"id": "BafQ7ijBQ8N_"
},
"source": [
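"## Train and test models with shell commands\n",
"\n",
"MMCls also provides command-line tools for the following tasks:\n",
"\n",
"1. Model training\n",
"2. Model fine-tuning\n",
"3. Model testing\n",
"4. Inference\n",
"\n",
"Training follows the same process as fine-tuning. We have already seen inference and fine-tuning with the Python API; next, we will see how to do these tasks with the command-line tools. For more details, see the [documentation](https://github.com/open-mmlab/mmclassification/blob/master/docs_zh-CN/getting_started.md)."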
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Aj5cGMihURrZ"
},
"source": [
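"### Model fine-tuning\n",
"\n",
"Fine-tuning a model from the command line takes these steps:\n",
"\n",
"1. Prepare the custom dataset\n",
"2. Modify the configuration in a Python config file\n",
"3. Fine-tune the model with the command-line tool\n",
"\n",
"Step 1 has already been covered above, so we focus on the last two steps."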
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wl-FNFP8O0dh"
},
"source": [
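"#### Create a new config file\n",
"\n",
"To reuse the parts that different config files have in common, we support inheriting from multiple config files. For example, to fine-tune MobileNetV2, the new config file can inherit `configs/_base_/models/mobilenet_v2_1x.py` to get the basic model structure.\n",
"\n",
"Following common practice, we usually split a complete config into four parts: model, dataset, optimizer schedule, and runtime settings. Each part is saved in its own file under the corresponding subdirectory of `configs/_base_`.\n",
"\n",
"This way, when creating a new config file, we can inherit the files we need and override only the parts that have to change.\n",
"\n",
"The inheritance section at the top of our new config file is:\n",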
"\n",
"```python\n",
"_base_ = [\n",
" '../_base_/models/mobilenet_v2_1x.py',\n",
" '../_base_/schedules/imagenet_bs256_epochstep.py',\n",
" '../_base_/default_runtime.py'\n",
"]\n",
"```\n",
"\n",
"Because we are using a new dataset, we do not inherit any dataset config here.\n",
"\n",
"Alternatively, you can skip inheritance and write a complete config file directly, as in `configs/mnist/lenet5.py`."
]
},
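{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, inheritance behaves like a recursive dictionary merge. Keys set in the new file override the same keys from the `_base_` files, and nested dicts are merged rather than replaced wholesale.\n",
"\n",
"The cell below is a simplified sketch of that idea using a hypothetical `merge_cfg` helper. MMCV's `Config` handles more cases than this, so treat it as an illustration only. The base values are illustrative stand-ins, not the exact contents of `mobilenet_v2_1x.py`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import copy\n",
"\n",
"\n",
"def merge_cfg(base, override):\n",
"    # Hypothetical helper: recursively merge `override` into a copy of `base`.\n",
"    # Nested dicts are merged key by key; any other value replaces the base value.\n",
"    merged = copy.deepcopy(base)\n",
"    for key, value in override.items():\n",
"        if isinstance(value, dict) and isinstance(merged.get(key), dict):\n",
"            merged[key] = merge_cfg(merged[key], value)\n",
"        else:\n",
"            merged[key] = copy.deepcopy(value)\n",
"    return merged\n",
"\n",
"\n",
"# Simplified stand-in for the base model config (illustrative values)\n",
"base_model = dict(\n",
"    type='ImageClassifier',\n",
"    backbone=dict(type='MobileNetV2', widen_factor=1.0),\n",
"    neck=dict(type='GlobalAveragePooling'),\n",
"    head=dict(type='LinearClsHead', num_classes=1000, in_channels=1280, topk=(1, 5)))\n",
"\n",
"# Only the keys we want to change, as in the new config file\n",
"override = dict(head=dict(num_classes=2, topk=(1, )))\n",
"\n",
"merged = merge_cfg(base_model, override)\n",
"print(merged['head'])"
],
"execution_count": null,
"outputs": []
},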
{
"cell_type": "markdown",
"metadata": {
"id": "_UV3oBhLRG8B"
},
"source": [
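"After that, we only need to set the parts of the config we want to change; all other settings are read automatically from the inherited config files."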
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8QfM4qBeWIQh",
"outputId": "0e658dca-722e-4bed-dd0b-601731b00457"
},
"source": [
"%%writefile configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py\n",
"_base_ = [\n",
" '../_base_/models/mobilenet_v2_1x.py',\n",
" '../_base_/schedules/imagenet_bs256_epochstep.py',\n",
" '../_base_/default_runtime.py'\n",
"]\n",
"\n",
"# ---- Model settings ----\n",
"# Use init_cfg to load the pretrained model; this way, only the backbone weights are loaded.\n",
"# We also change num_classes of the classification head to match our dataset.\n",
"\n",
"model = dict(\n",
" backbone=dict(\n",
" init_cfg=dict(\n",
" type='Pretrained', \n",
" checkpoint='https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth', \n",
" prefix='backbone')\n",
" ),\n",
" head=dict(\n",
" num_classes=2,\n",
" topk=(1, )\n",
" ))\n",
"\n",
"# ---- 数据集配置 ----\n",
"# We have already reorganized the dataset into the ImageNet format\n",
"dataset_type = 'ImageNet'\n",
"img_norm_cfg = dict(\n",
" mean=[124.508, 116.050, 106.438],\n",
" std=[58.577, 57.310, 57.437],\n",
" to_rgb=True)\n",
"train_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224, backend='pillow'),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(type='Normalize', **img_norm_cfg),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
"]\n",
"test_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1), backend='pillow'),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(type='Normalize', **img_norm_cfg),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
"]\n",
"data = dict(\n",
" # Batch size and number of data-loading workers per GPU; adjust them for your hardware.\n",
" samples_per_gpu=32,\n",
" workers_per_gpu=2,\n",
" # Type and paths of the training set\n",
" train=dict(\n",
" type=dataset_type,\n",
" data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
" classes='data/cats_dogs_dataset/classes.txt',\n",
" pipeline=train_pipeline),\n",
" # Type and paths of the validation set\n",
" val=dict(\n",
" type=dataset_type,\n",
" data_prefix='data/cats_dogs_dataset/val_set/val_set',\n",
" ann_file='data/cats_dogs_dataset/val.txt',\n",
" classes='data/cats_dogs_dataset/classes.txt',\n",
" pipeline=test_pipeline),\n",
" # Type and paths of the test set\n",
" test=dict(\n",
" type=dataset_type,\n",
" data_prefix='data/cats_dogs_dataset/test_set/test_set',\n",
" ann_file='data/cats_dogs_dataset/test.txt',\n",
" classes='data/cats_dogs_dataset/classes.txt',\n",
" pipeline=test_pipeline))\n",
"\n",
"# Evaluation metric\n",
"evaluation = dict(metric='accuracy', metric_options={'topk': (1, )})\n",
"\n",
"# ---- Optimizer settings ----\n",
"# Fine-tuning usually needs a smaller learning rate and fewer epochs.\n",
"# Learning rate\n",
"optimizer = dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=0.0001)\n",
"optimizer_config = dict(grad_clip=None)\n",
"# Learning rate scheduler\n",
"lr_config = dict(policy='step', step=1, gamma=0.1)\n",
"runner = dict(type='EpochBasedRunner', max_epochs=2)\n",
"\n",
"# ---- Runtime settings ----\n",
"# Print a log entry every 10 training iterations\n",
"log_config = dict(interval=10)"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Writing configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py\n"
]
}
]
},
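{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `lr_config = dict(policy='step', step=1, gamma=0.1)` and `max_epochs=2`, the learning rate is multiplied by `gamma` every `step` epochs. The cell below is a minimal sketch of that step rule, assuming no warmup and an integer `step`.\n",
"\n",
"`step_lr` is a hypothetical helper, not MMCV's `StepLrUpdaterHook` itself. Under these assumptions, the first epoch trains at 0.005 and the second at 0.0005."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def step_lr(base_lr, epoch, step, gamma):\n",
"    # Step policy with an integer `step`: decay by `gamma` every `step` epochs\n",
"    return base_lr * gamma ** (epoch // step)\n",
"\n",
"\n",
"base_lr = 0.005  # optimizer lr in the config above\n",
"for epoch in range(2):  # max_epochs=2; epochs are counted from 0 here\n",
"    print(f'epoch {epoch + 1}: lr = {step_lr(base_lr, epoch, step=1, gamma=0.1):.6f}')"
],
"execution_count": null,
"outputs": []
},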
{
"cell_type": "markdown",
"metadata": {
"id": "chLX7bL3RP2F"
},
"source": [
"#### Fine-tune the model from the command line\n",
"\n",
"We use `tools/train.py` to fine-tune the model:\n",
"\n",
"```\n",
"python tools/train.py ${CONFIG_FILE} [optional arguments]\n",
"```\n",
"\n",
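"To choose where the files produced during training are saved, add the argument `--work-dir ${YOUR_WORK_DIR}`.\n",
"\n",
"The argument `--seed ${SEED}` sets the random seed so that results are reproducible, and `--deterministic` enables cuDNN's deterministic option for further reproducibility, possibly at a small cost in speed.\n",
"\n",
"Here we use `MobileNetV2` and the cats and dogs dataset as an example."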
]
},
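{
"cell_type": "markdown",
"metadata": {},
"source": [
"`--seed` and `--deterministic` roughly amount to seeding every random number generator and switching cuDNN to deterministic algorithms. The cell below is a minimal sketch of that idea using standard Python, NumPy and PyTorch calls.\n",
"\n",
"The name `set_random_seed` is used here for illustration. The exact function that `tools/train.py` calls internally may differ."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import torch\n",
"\n",
"\n",
"def set_random_seed(seed, deterministic=False):\n",
"    # Seed the Python, NumPy and PyTorch (CPU and all GPUs) generators\n",
"    random.seed(seed)\n",
"    np.random.seed(seed)\n",
"    torch.manual_seed(seed)\n",
"    torch.cuda.manual_seed_all(seed)\n",
"    if deterministic:\n",
"        # Request deterministic cuDNN kernels and disable autotuning,\n",
"        # which can make training somewhat slower\n",
"        torch.backends.cudnn.deterministic = True\n",
"        torch.backends.cudnn.benchmark = False\n",
"\n",
"\n",
"# Re-seeding with the same value reproduces the same random tensor\n",
"set_random_seed(0, deterministic=True)\n",
"a = torch.rand(3)\n",
"set_random_seed(0, deterministic=True)\n",
"b = torch.rand(3)\n",
"print(torch.equal(a, b))"
],
"execution_count": null,
"outputs": []
},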
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gbFGR4SBRUYN",
"outputId": "66019f0f-2ded-4fae-9a5f-ece9729a7c2d"
},
"source": [
"!python tools/train.py \\\n",
" configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py \\\n",
" --work-dir work_dirs/mobilenet_v2_1x_cats_dogs \\\n",
" --seed 0 \\\n",
" --deterministic"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:28: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/local/lib/python3.7/dist-packages/yaml/constructor.py:126: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n",
" if not isinstance(key, collections.Hashable):\n",
"2021-10-21 02:53:42,465 - mmcls - INFO - Environment info:\n",
"------------------------------------------------------------\n",
"sys.platform: linux\n",
"Python: 3.7.12 (default, Sep 10 2021, 00:21:48) [GCC 7.5.0]\n",
"CUDA available: True\n",
"GPU 0: Tesla K80\n",
"CUDA_HOME: /usr/local/cuda\n",
"NVCC: Build cuda_11.1.TC455_06.29190527_0\n",
"GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0\n",
"PyTorch: 1.9.0+cu111\n",
"PyTorch compiling details: PyTorch built with:\n",
" - GCC 7.3\n",
" - C++ Version: 201402\n",
" - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications\n",
" - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)\n",
" - OpenMP 201511 (a.k.a. OpenMP 4.5)\n",
" - NNPACK is enabled\n",
" - CPU capability usage: AVX2\n",
" - CUDA Runtime 11.1\n",
" - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86\n",
" - CuDNN 8.0.5\n",
" - Magma 2.5.2\n",
" - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, \n",
"\n",
"TorchVision: 0.10.0+cu111\n",
"OpenCV: 4.1.2\n",
"MMCV: 1.3.15\n",
"MMCV Compiler: n/a\n",
"MMCV CUDA Compiler: n/a\n",
"MMClassification: 0.16.0+77a3834\n",
"------------------------------------------------------------\n",
"\n",
"2021-10-21 02:53:42,465 - mmcls - INFO - Distributed training: False\n",
"2021-10-21 02:53:43,086 - mmcls - INFO - Config:\n",
"model = dict(\n",
" type='ImageClassifier',\n",
" backbone=dict(\n",
" type='MobileNetV2',\n",
" widen_factor=1.0,\n",
" init_cfg=dict(\n",
" type='Pretrained',\n",
" checkpoint=\n",
" 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth',\n",
" prefix='backbone')),\n",
" neck=dict(type='GlobalAveragePooling'),\n",
" head=dict(\n",
" type='LinearClsHead',\n",
" num_classes=2,\n",
" in_channels=1280,\n",
" loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n",
" topk=(1, )))\n",
"optimizer = dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=0.0001)\n",
"optimizer_config = dict(grad_clip=None)\n",
"lr_config = dict(policy='step', gamma=0.1, step=1)\n",
"runner = dict(type='EpochBasedRunner', max_epochs=2)\n",
"checkpoint_config = dict(interval=1)\n",
"log_config = dict(interval=10, hooks=[dict(type='TextLoggerHook')])\n",
"dist_params = dict(backend='nccl')\n",
"log_level = 'INFO'\n",
"load_from = None\n",
"resume_from = None\n",
"workflow = [('train', 1)]\n",
"dataset_type = 'ImageNet'\n",
"img_norm_cfg = dict(\n",
" mean=[124.508, 116.05, 106.438], std=[58.577, 57.31, 57.437], to_rgb=True)\n",
"train_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224, backend='pillow'),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[124.508, 116.05, 106.438],\n",
" std=[58.577, 57.31, 57.437],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
"]\n",
"test_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1), backend='pillow'),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[124.508, 116.05, 106.438],\n",
" std=[58.577, 57.31, 57.437],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
"]\n",
"data = dict(\n",
" samples_per_gpu=32,\n",
" workers_per_gpu=2,\n",
" train=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
" classes='data/cats_dogs_dataset/classes.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224, backend='pillow'),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[124.508, 116.05, 106.438],\n",
" std=[58.577, 57.31, 57.437],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
" ]),\n",
" val=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/val_set/val_set',\n",
" ann_file='data/cats_dogs_dataset/val.txt',\n",
" classes='data/cats_dogs_dataset/classes.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1), backend='pillow'),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[124.508, 116.05, 106.438],\n",
" std=[58.577, 57.31, 57.437],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
" ]),\n",
" test=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/test_set/test_set',\n",
" ann_file='data/cats_dogs_dataset/test.txt',\n",
" classes='data/cats_dogs_dataset/classes.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1), backend='pillow'),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[124.508, 116.05, 106.438],\n",
" std=[58.577, 57.31, 57.437],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
" ]))\n",
"evaluation = dict(metric='accuracy', metric_options=dict(topk=(1, )))\n",
"work_dir = 'work_dirs/mobilenet_v2_1x_cats_dogs'\n",
"gpu_ids = range(0, 1)\n",
"\n",
"2021-10-21 02:53:43,086 - mmcls - INFO - Set random seed to 0, deterministic: True\n",
"2021-10-21 02:53:43,251 - mmcls - INFO - initialize MobileNetV2 with init_cfg {'type': 'Pretrained', 'checkpoint': 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth', 'prefix': 'backbone'}\n",
"2021-10-21 02:53:43,252 - mmcv - INFO - load backbone in model from: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\n",
"Use load_from_http loader\n",
"Downloading: \"https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\" to /root/.cache/torch/hub/checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\n",
"100% 13.5M/13.5M [00:01<00:00, 9.62MB/s]\n",
"2021-10-21 02:53:46,164 - mmcls - INFO - initialize LinearClsHead with init_cfg {'type': 'Normal', 'layer': 'Linear', 'std': 0.01}\n",
"2021-10-21 02:54:01,365 - mmcls - INFO - Start running, host: root@3a8df14fab46, work_dir: /content/mmclassification/work_dirs/mobilenet_v2_1x_cats_dogs\n",
"2021-10-21 02:54:01,365 - mmcls - INFO - Hooks will be executed in the following order:\n",
"before_run:\n",
"(VERY_HIGH ) StepLrUpdaterHook \n",
"(NORMAL ) CheckpointHook \n",
"(LOW ) EvalHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_train_epoch:\n",
"(VERY_HIGH ) StepLrUpdaterHook \n",
"(LOW ) IterTimerHook \n",
"(LOW ) EvalHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_train_iter:\n",
"(VERY_HIGH ) StepLrUpdaterHook \n",
"(LOW ) IterTimerHook \n",
"(LOW ) EvalHook \n",
" -------------------- \n",
"after_train_iter:\n",
"(ABOVE_NORMAL) OptimizerHook \n",
"(NORMAL ) CheckpointHook \n",
"(LOW ) IterTimerHook \n",
"(LOW ) EvalHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"after_train_epoch:\n",
"(NORMAL ) CheckpointHook \n",
"(LOW ) EvalHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_val_epoch:\n",
"(LOW ) IterTimerHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_val_iter:\n",
"(LOW ) IterTimerHook \n",
" -------------------- \n",
"after_val_iter:\n",
"(LOW ) IterTimerHook \n",
" -------------------- \n",
"after_val_epoch:\n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"2021-10-21 02:54:01,365 - mmcls - INFO - workflow: [('train', 1)], max: 2 epochs\n",
"2021-10-21 02:54:07,010 - mmcls - INFO - Epoch [1][10/201]\tlr: 5.000e-03, eta: 0:03:34, time: 0.548, data_time: 0.260, memory: 1709, loss: 0.3917\n",
"2021-10-21 02:54:09,888 - mmcls - INFO - Epoch [1][20/201]\tlr: 5.000e-03, eta: 0:02:39, time: 0.288, data_time: 0.021, memory: 1709, loss: 0.3508\n",
"2021-10-21 02:54:12,795 - mmcls - INFO - Epoch [1][30/201]\tlr: 5.000e-03, eta: 0:02:19, time: 0.291, data_time: 0.020, memory: 1709, loss: 0.3955\n",
"2021-10-21 02:54:15,744 - mmcls - INFO - Epoch [1][40/201]\tlr: 5.000e-03, eta: 0:02:08, time: 0.295, data_time: 0.019, memory: 1709, loss: 0.2485\n",
"2021-10-21 02:54:18,667 - mmcls - INFO - Epoch [1][50/201]\tlr: 5.000e-03, eta: 0:02:00, time: 0.292, data_time: 0.021, memory: 1709, loss: 0.4196\n",
"2021-10-21 02:54:21,590 - mmcls - INFO - Epoch [1][60/201]\tlr: 5.000e-03, eta: 0:01:54, time: 0.293, data_time: 0.022, memory: 1709, loss: 0.4994\n",
"2021-10-21 02:54:24,496 - mmcls - INFO - Epoch [1][70/201]\tlr: 5.000e-03, eta: 0:01:48, time: 0.291, data_time: 0.021, memory: 1709, loss: 0.4372\n",
"2021-10-21 02:54:27,400 - mmcls - INFO - Epoch [1][80/201]\tlr: 5.000e-03, eta: 0:01:44, time: 0.290, data_time: 0.020, memory: 1709, loss: 0.3179\n",
"2021-10-21 02:54:30,313 - mmcls - INFO - Epoch [1][90/201]\tlr: 5.000e-03, eta: 0:01:39, time: 0.292, data_time: 0.020, memory: 1709, loss: 0.3175\n",
"2021-10-21 02:54:33,208 - mmcls - INFO - Epoch [1][100/201]\tlr: 5.000e-03, eta: 0:01:35, time: 0.289, data_time: 0.020, memory: 1709, loss: 0.3412\n",
"2021-10-21 02:54:36,129 - mmcls - INFO - Epoch [1][110/201]\tlr: 5.000e-03, eta: 0:01:31, time: 0.292, data_time: 0.021, memory: 1709, loss: 0.2985\n",
"2021-10-21 02:54:39,067 - mmcls - INFO - Epoch [1][120/201]\tlr: 5.000e-03, eta: 0:01:28, time: 0.294, data_time: 0.021, memory: 1709, loss: 0.2778\n",
"2021-10-21 02:54:41,963 - mmcls - INFO - Epoch [1][130/201]\tlr: 5.000e-03, eta: 0:01:24, time: 0.289, data_time: 0.020, memory: 1709, loss: 0.2229\n",
"2021-10-21 02:54:44,861 - mmcls - INFO - Epoch [1][140/201]\tlr: 5.000e-03, eta: 0:01:21, time: 0.290, data_time: 0.021, memory: 1709, loss: 0.2318\n",
"2021-10-21 02:54:47,782 - mmcls - INFO - Epoch [1][150/201]\tlr: 5.000e-03, eta: 0:01:17, time: 0.293, data_time: 0.020, memory: 1709, loss: 0.2333\n",
"2021-10-21 02:54:50,682 - mmcls - INFO - Epoch [1][160/201]\tlr: 5.000e-03, eta: 0:01:14, time: 0.290, data_time: 0.020, memory: 1709, loss: 0.2783\n",
"2021-10-21 02:54:53,595 - mmcls - INFO - Epoch [1][170/201]\tlr: 5.000e-03, eta: 0:01:11, time: 0.291, data_time: 0.019, memory: 1709, loss: 0.2132\n",
"2021-10-21 02:54:56,499 - mmcls - INFO - Epoch [1][180/201]\tlr: 5.000e-03, eta: 0:01:07, time: 0.290, data_time: 0.021, memory: 1709, loss: 0.2096\n",
"2021-10-21 02:54:59,381 - mmcls - INFO - Epoch [1][190/201]\tlr: 5.000e-03, eta: 0:01:04, time: 0.288, data_time: 0.023, memory: 1709, loss: 0.1729\n",
"2021-10-21 02:55:02,270 - mmcls - INFO - Epoch [1][200/201]\tlr: 5.000e-03, eta: 0:01:01, time: 0.288, data_time: 0.020, memory: 1709, loss: 0.1969\n",
"2021-10-21 02:55:02,313 - mmcls - INFO - Saving checkpoint at 1 epochs\n",
"[ ] 0/1601, elapsed: 0s, ETA:[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)\n",
"[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)\n",
"[>>] 1601/1601, 171.8 task/s, elapsed: 9s, ETA: 0s2021-10-21 02:55:11,743 - mmcls - INFO - Epoch(val) [1][51]\taccuracy_top-1: 95.6277\n",
"2021-10-21 02:55:16,920 - mmcls - INFO - Epoch [2][10/201]\tlr: 5.000e-04, eta: 0:00:59, time: 0.501, data_time: 0.237, memory: 1709, loss: 0.1764\n",
"2021-10-21 02:55:19,776 - mmcls - INFO - Epoch [2][20/201]\tlr: 5.000e-04, eta: 0:00:56, time: 0.286, data_time: 0.021, memory: 1709, loss: 0.1514\n",
"2021-10-21 02:55:22,637 - mmcls - INFO - Epoch [2][30/201]\tlr: 5.000e-04, eta: 0:00:52, time: 0.286, data_time: 0.019, memory: 1709, loss: 0.1395\n",
"2021-10-21 02:55:25,497 - mmcls - INFO - Epoch [2][40/201]\tlr: 5.000e-04, eta: 0:00:49, time: 0.286, data_time: 0.020, memory: 1709, loss: 0.1508\n",
"2021-10-21 02:55:28,338 - mmcls - INFO - Epoch [2][50/201]\tlr: 5.000e-04, eta: 0:00:46, time: 0.284, data_time: 0.018, memory: 1709, loss: 0.1771\n",
"2021-10-21 02:55:31,214 - mmcls - INFO - Epoch [2][60/201]\tlr: 5.000e-04, eta: 0:00:43, time: 0.287, data_time: 0.019, memory: 1709, loss: 0.1438\n",
"2021-10-21 02:55:34,075 - mmcls - INFO - Epoch [2][70/201]\tlr: 5.000e-04, eta: 0:00:40, time: 0.286, data_time: 0.020, memory: 1709, loss: 0.1321\n",
"2021-10-21 02:55:36,921 - mmcls - INFO - Epoch [2][80/201]\tlr: 5.000e-04, eta: 0:00:36, time: 0.285, data_time: 0.023, memory: 1709, loss: 0.1629\n",
"2021-10-21 02:55:39,770 - mmcls - INFO - Epoch [2][90/201]\tlr: 5.000e-04, eta: 0:00:33, time: 0.285, data_time: 0.018, memory: 1709, loss: 0.1574\n",
"2021-10-21 02:55:42,606 - mmcls - INFO - Epoch [2][100/201]\tlr: 5.000e-04, eta: 0:00:30, time: 0.284, data_time: 0.019, memory: 1709, loss: 0.1220\n",
"2021-10-21 02:55:45,430 - mmcls - INFO - Epoch [2][110/201]\tlr: 5.000e-04, eta: 0:00:27, time: 0.282, data_time: 0.021, memory: 1709, loss: 0.2550\n",
"2021-10-21 02:55:48,280 - mmcls - INFO - Epoch [2][120/201]\tlr: 5.000e-04, eta: 0:00:24, time: 0.285, data_time: 0.021, memory: 1709, loss: 0.1528\n",
"2021-10-21 02:55:51,131 - mmcls - INFO - Epoch [2][130/201]\tlr: 5.000e-04, eta: 0:00:21, time: 0.285, data_time: 0.020, memory: 1709, loss: 0.1223\n",
"2021-10-21 02:55:53,983 - mmcls - INFO - Epoch [2][140/201]\tlr: 5.000e-04, eta: 0:00:18, time: 0.285, data_time: 0.019, memory: 1709, loss: 0.1734\n",
"2021-10-21 02:55:56,823 - mmcls - INFO - Epoch [2][150/201]\tlr: 5.000e-04, eta: 0:00:15, time: 0.284, data_time: 0.022, memory: 1709, loss: 0.1527\n",
"2021-10-21 02:55:59,645 - mmcls - INFO - Epoch [2][160/201]\tlr: 5.000e-04, eta: 0:00:12, time: 0.283, data_time: 0.021, memory: 1709, loss: 0.1910\n",
"2021-10-21 02:56:02,514 - mmcls - INFO - Epoch [2][170/201]\tlr: 5.000e-04, eta: 0:00:09, time: 0.287, data_time: 0.019, memory: 1709, loss: 0.1922\n",
"2021-10-21 02:56:05,375 - mmcls - INFO - Epoch [2][180/201]\tlr: 5.000e-04, eta: 0:00:06, time: 0.286, data_time: 0.018, memory: 1709, loss: 0.1760\n",
"2021-10-21 02:56:08,241 - mmcls - INFO - Epoch [2][190/201]\tlr: 5.000e-04, eta: 0:00:03, time: 0.287, data_time: 0.019, memory: 1709, loss: 0.1739\n",
"2021-10-21 02:56:11,081 - mmcls - INFO - Epoch [2][200/201]\tlr: 5.000e-04, eta: 0:00:00, time: 0.282, data_time: 0.019, memory: 1709, loss: 0.1654\n",
"2021-10-21 02:56:11,125 - mmcls - INFO - Saving checkpoint at 2 epochs\n",
"[>>] 1601/1601, 170.9 task/s, elapsed: 9s, ETA: 0s2021-10-21 02:56:20,592 - mmcls - INFO - Epoch(val) [2][51]\taccuracy_top-1: 97.5016\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m_ZSkwB5Rflb"
},
"source": [
"### 测试模型\n",
"\n",
"使用 `tools/test.py` 对模型进行测试:\n",
"\n",
"```\n",
"python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments]\n",
"```\n",
"\n",
"这里有一些可选参数可以进行配置:\n",
"\n",
"- `--metrics`: 评价指标。可以在数据集类中找到所有可用的选择,一般对单标签分类任务,我们都可以使用 \"accuracy\" 进行评价。\n",
"- `--metric-options`: 传递给评价指标的自定义参数。比如指定了 \"topk=1\",那么就会计算 \"top-1 accuracy\"。\n",
"\n",
"更多细节请参看 `tools/test.py` 的帮助文档。\n",
"\n",
"这里使用我们微调好的 `MobileNetV2` 模型进行测试"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Zd4EM00QRtyc",
"outputId": "e0be9ba6-47f5-45d9-cca2-d2c5a38b1407"
},
"source": [
"!python tools/test.py configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py work_dirs/mobilenet_v2_1x_cats_dogs/latest.pth --metrics accuracy --metric-options topk=1"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:28: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/local/lib/python3.7/dist-packages/yaml/constructor.py:126: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n",
" if not isinstance(key, collections.Hashable):\n",
"Use load_from_local loader\n",
"[>>] 2023/2023, 169.7 task/s, elapsed: 12s, ETA: 0s\n",
"accuracy : 97.38\n"
]
}
]
},
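{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `accuracy` printed above is the top-1 accuracy in percent: the share of test images whose highest-scoring class is the ground-truth class. With `topk=k`, a prediction counts as correct when the true class is among the k highest-scoring classes. The sketch below illustrates this in plain Python; `topk_accuracy` is a hypothetical helper written for illustration, not the MMClassification implementation.\n",
"\n",
"```python\n",
"def topk_accuracy(scores, labels, k=1):\n",
"    # scores: one list of per-class scores per sample\n",
"    # labels: the ground-truth class index of each sample\n",
"    hits = 0\n",
"    for row, label in zip(scores, labels):\n",
"        # indices of the k highest-scoring classes for this sample\n",
"        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]\n",
"        hits += label in top\n",
"    return 100.0 * hits / len(labels)\n",
"\n",
"\n",
"# Made-up scores for three samples and two classes\n",
"scores = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]\n",
"labels = [0, 1, 1]\n",
"print(topk_accuracy(scores, labels, k=1))  # 2 of 3 top-1 predictions are correct\n",
"print(topk_accuracy(scores, labels, k=2))  # 100.0: with two classes every label is in the top 2\n",
"```"
]
},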
{
"cell_type": "markdown",
"metadata": {
"id": "IwThQkjaRwF7"
},
"source": [
"### 推理计算\n",
"\n",
"有时我们会希望保存模型在数据集上的推理结果,可以使用如下命令:\n",
"\n",
"```shell\n",
"python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}]\n",
"```\n",
"\n",
"参数:\n",
"\n",
"- `--out`: 输出结果的文件名。如果不指定计算结果不会被保存。支持的格式包括json, pkl 和 yml\n",
"- `--out-items`: 哪些推理结果需要被保存,可以从 \"class_scores\", \"pred_score\", \"pred_label\" 和 \"pred_class\" 中选择若干个,或者使用 \"all\" 来保存所有推理结果。\n",
"\n",
"这些项的具体含义:\n",
"- `class_scores`: 各个样本在每个类上的分类得分。\n",
"- `pred_score`: 各个样本在预测类上的分类得分。\n",
"- `pred_label`: 各个样本预测类的标签。标签文本将会从模型权重文件中读取,如果模型权重文件中没有标签文本,则会使用 ImageNet 的标签文本。\n",
"- `pred_class`: 各个样本预测类的 id为一组整数。\n",
"- `all`: 保存以上所有项。\n",
"- `none`: 不保存以上任何项。因为输出文件除了推理结果,还会保存评价指标,如果你只希望保存总体评价指标,可以设置不保存任何项,可以大幅减小输出文件大小。"
]
},
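{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assuming `class_scores` holds one score per class for every sample, the other three items follow from it: `pred_label` is the index of the highest score, `pred_score` is that score, and `pred_class` is the class name at that index. The minimal sketch below shows this relationship; `derive_items` is a hypothetical helper, and the class order `('cats', 'dogs')` is assumed to match this dataset's `classes.txt`.\n",
"\n",
"```python\n",
"def derive_items(class_scores, classes):\n",
"    # class_scores: one list of per-class scores per sample\n",
"    # classes: class names, indexed by label\n",
"    pred_label, pred_score, pred_class = [], [], []\n",
"    for row in class_scores:\n",
"        label = max(range(len(row)), key=lambda i: row[i])  # argmax\n",
"        pred_label.append(label)\n",
"        pred_score.append(row[label])\n",
"        pred_class.append(classes[label])\n",
"    return {'pred_label': pred_label, 'pred_score': pred_score,\n",
"            'pred_class': pred_class}\n",
"\n",
"\n",
"# class_scores of the first sample, as printed later in this tutorial\n",
"items = derive_items([[1.0, 5.184615757547473e-13]], ('cats', 'dogs'))\n",
"print(items)  # {'pred_label': [0], 'pred_score': [1.0], 'pred_class': ['cats']}\n",
"```"
]
},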
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6GVKloPHR0Fn",
"outputId": "1efde0e4-97cd-4e62-ce98-1cbc79da3a6c"
},
"source": [
"!python tools/test.py configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py work_dirs/mobilenet_v2_1x_cats_dogs/latest.pth --out results.json --out-items all"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:28: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/local/lib/python3.7/dist-packages/yaml/constructor.py:126: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n",
" if not isinstance(key, collections.Hashable):\n",
"Use load_from_local loader\n",
"[>>] 2023/2023, 170.3 task/s, elapsed: 12s, ETA: 0s\n",
"dumping results to results.json\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G0NJI1s6e3FD"
},
"source": [
"导出的json 文件中保存了所有样本的推理结果、分类结果和分类得分"
]
},
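{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each saved item holds one entry per sample, so index `i` of every item refers to the same image. The sketch below relies on that layout to load a results file and list the least confident predictions. `load_results` and `least_confident` are hypothetical helpers; this sketch reads only json and pkl files with the standard library (`mmcv.load` can also read yml).\n",
"\n",
"```python\n",
"import json\n",
"import os\n",
"import pickle\n",
"import tempfile\n",
"\n",
"\n",
"def load_results(path):\n",
"    # Read a file written with --out (json or pkl only in this sketch)\n",
"    if path.endswith('.json'):\n",
"        with open(path) as f:\n",
"            return json.load(f)\n",
"    if path.endswith('.pkl'):\n",
"        with open(path, 'rb') as f:\n",
"            return pickle.load(f)\n",
"    raise ValueError('unsupported format: ' + path)\n",
"\n",
"\n",
"def least_confident(results, k=5):\n",
"    # Return (index, pred_class, pred_score) for the k lowest pred_score values\n",
"    scores = results['pred_score']\n",
"    order = sorted(range(len(scores)), key=lambda i: scores[i])[:k]\n",
"    return [(i, results['pred_class'][i], scores[i]) for i in order]\n",
"\n",
"\n",
"# Demo on a small made-up results dict instead of the real results.json\n",
"demo = {'pred_label': [0, 1, 0], 'pred_class': ['cats', 'dogs', 'cats'],\n",
"        'pred_score': [0.99, 0.55, 0.80]}\n",
"path = os.path.join(tempfile.mkdtemp(), 'demo.json')\n",
"with open(path, 'w') as f:\n",
"    json.dump(demo, f)\n",
"print(least_confident(load_results(path), k=2))  # [(1, 'dogs', 0.55), (2, 'cats', 0.8)]\n",
"```\n",
"\n",
"On the real output, `least_confident(load_results('results.json'))` would list the test images the model is least sure about."
]
},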
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 370
},
"id": "HJdJeLUafFhX",
"outputId": "486c0652-2124-419a-ec7d-fd3583baedb1"
},
"source": [
"import json\n",
"\n",
"with open(\"./results.json\", 'r') as f:\n",
" results = json.load(f)\n",
"\n",
"# 展示第一张图片的结果信息\n",
"print('class_scores:', results['class_scores'][0])\n",
"print('pred_class:', results['pred_class'][0])\n",
"print('pred_label:', results['pred_label'][0])\n",
"print('pred_score:', results['pred_score'][0])\n",
"Image.open('data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg')"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"class_scores: [1.0, 5.184615757547473e-13]\n",
"pred_class: cats\n",
"pred_label: 0\n",
"pred_score: 1.0\n"
]
},
{
"output_type": "execute_result",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAASwAAAEYCAIAAABp9FyZAAEAAElEQVR4nJT96ZMkWXIfCKrqe88uv+KOvCuzju7qrmo00A00QZAcoZAznPmwIkP+p/NhZWVnZTAHSBALAmig+u6ursrKMzIyDj/teofqflB3S89qckTWpCTKM8Lc3ezZ0+unP1XF737+KQAgGEQDQCAoIgD04MGDul7f3Nx0fTsalbPZJCa/XC5dOQ4hGGMAoGka55y1tu/7LMu89865zWbz+PHj+Xy+2WxGo1G3rokIAJgZEa3N9L3C4Jwzxngfu64TkbIsq6pqu40xJqXU922MkYiMMSKS56Uh61zubGmtM5QjGkTMDZRl2fd9CKGoSkTx3gthSgkRBCGllFIKnAAAEbGXPM/1q7vOM3OM0ZBLKcWYUkrM/PHH33n58mWe5yklnlW3t7e5M6fHJ7PRaLOY3769zAxNsvz06PDJwwe5ofnV1eNHD2+uru7du3O7aeu6Nha97/7n//l/fvHimQ8dIi4Wi/V6vVk34/H0+voa0Xzy8XdfvXrVtt39+/c//vjjt2/fPn32TZ7nDBKZY4x37997+/btn/zpj9++fftX//E//vCHP1yv1yWS98FaK2i+//3vv3nztutDluWcQMjEwNfXN/cfPHrz5o0IVlVlQvfpp5+mlH72s5+NZ9MY48nZ6Xy5XK1Wn37/e3fu3Hnx6uVvf/vbPoQ7d+6cn5/317ezg8lqteq6phqVV1eXm83q0aNHTdOklFKSlASBrM2yLHM2ny9eiAgB5nleFBWhjZG7kMAUybibTXdxs1yHaMspOtf7aAiMMdZaRERhRHTOOUNd13FKIfQxRokJUay1xpibxTzGKCLWWiJCkizL8jy/vb1JKT169Kiu133fHx8f397etm07nU67rqubdex9nueTycQY5JjeXF4cTmfHx8fG4Pz2drVaTSaTO3furNb1YrHw3mdZllKq6xoRx+NxWZbM7JyLMW42G+dcVVUppc1mlee5916/6/T0tGmaZ8+e3bt3DxHLsgSAvu+dcyIyn8/J8PHxIVFs13Mi/z/+9/9qOrWzcX7n7lEKTQi9QLKICAAg8K2j7/uUEiKKSAih67rEIYRQ94uiKEII0+n06urq5OQkhAAAMUa9iJQSAFRVZa1dr9ciwszGGN36RDal5L1HIGYGABHRaxARXe5vXwrAVoaBmVl2B4CIiE/RGOO9DyGQNYMQhhCICA3q240xiGj0f4jW2rqui6Lquq4sS9/HoiiOjo5fvXo1mUx+/OMf379//2/+5m8ePnz4y9fPfvSjH/3t3/y1M/aPP//8F/+0qKpqeXuTCSBi0zTZeDSbzbquy/P86uqqmB5WVeUyU9dS1/XV1VVMvqqqs7Oz5XK5XC6rajydTlerze3tbUppPB4bY4qi0KUmIu9749xqtdKlOz8/J6I/+7M/CyE450Lb6foj4dOnT7OsCCGEEIu8stYSwsHBASBba5m5qoo337yq6/WdO3em0/Hd+3frun79+uXZ2Vni8PTpVwcH0zx3JydHTdN88+zrsspzgL7vjTHT6bQo87atQ+ibpgkhWGuzLBPBGFIIqWkakM7lGcfEzMzsvUdIPiYfOJ9UQGSMIWsNCxElkARCAqr4AACFRcR7j8LMzCnF6FNKFinLbFEUWZYdnZ40TdN1XQghhOBD1/e9bjYVifV6HUIYjUZElBeubVuBlGWZI6OLEH3o+x4AyrIUkfl8EUI4Pj4moouLCxa01qrCreu6russy0RkPB7Xda17O8syIkopxRiZue97773esu7YsixVjJ1zRDRsY2stS4wxWgPWWmt1JwMiJh9i4pQSS6Rhl8v7h/ceEVVy9r9YHwYijkYjABiNRnoPw14XEbWBjx8/zvNcf6NXP0ia3ttwV2ru9CnuiyUi6tbUW8L3D/3GuDtCivsHMydhERARRgAAIkLEBNLH4FNcbtYhJUHMigKtuV0uNm
2zbuou+Murq+9//rnNsuvb2zzPz8/PP//88+l0aoxxzumNq6LRVTo6Ouq67ujoaL1e3717dzwep5SyLBuNRrPZjJlXq9VqtXLO6T7I81y1qd71crns+17dAedc13UoYoy5vLw8OTl5/vz5D37wg/Ozs6PDwzzPmVkfZmbp5cuXKSUU0MtAxLars8xJ4uOTQ/VN7t69+0//9E+67Z49e6ZKJ8Z4cHDw4ZMnP//5z5tNrdvx0f0Hq/miqde31zd92xhCjoEADdJ6ueqaNoVokMo8K/LMEkYfurZGREEQER9DCMHHwMyM0HVt27ad72OMMaXOt13X6T3GGL33IXh1Urz3XddlWWYt7R7rdqsw87Nnz54/f/7ixYu3b99u6lVKyVqbZS7LsizLdFPpNiADRJQ4OOemo/F0Oi2KAoD188fjMQC0bRtDyLJMDd1yuby+vo4xZllmjCGi8Xh8dnZ2584dAPDe13UdY9RtrIrDWqsbTLdTCAERp9Op6ik9Qfe8un4JpO9DCMHaLM9LvVRh8N7HGEUQwZCezcxpdwwibq0djUZlWaoa0C/QbeScCyGMx2MRybJMnQoR6ft+vV5fXV0h4ocffnh6ejoajVSRqIx1Xafr+L6PalWwU0qDJRzEbBBLRCQilVh98U4g3700QoiI+pCstcZYEBSGFDlF7r0XgN57Y20fvMszm7nRZFxUZe/94dFRluf/5e//7uPvfPKTP/9n63pzfHz8/Pnzf/2v//WDBw9+9atf6dqpsiyKIqWkznnTNKenp9PpVNVn0zSbzeby8lK90BDCb37zG2vt2dlZ3/d1XRdFoU+rbdurq6vlcpmEp9Pp0dGRcw4AiqLo2857//vf/q5erfURTMbjsiy7rh3WIYTgMiPCSGJAdspbPnz85OT0qOub2XTcdvXzF98kDin0PnQI7H2XGTo7OyGU2/n1Yn5jCI4OZ02z6ft+uVyq8VksFl3XuMxYR71v62a92azUThZFUZalc86nkFIMKaaUYkqIaDNXlCUQCoIxJitcVhY2z4xzLs8QkZm93+p0IrKWdD/oHTnnNMbRM+/cuXN0dKSO1aCRdds4ZxCxKAqXGR+6EEJd11VejMtK/UlriQCNMS6zR7ODvu/bth6NRqPRqGmaul7nuauqSpcxxjgej+/fv3/37t3pdHp9fb1er9u2TSk553RzWmuLotDtp0ZPhVDV7mCo9F7UzAij97Hvg6AGVsAhMbP3UZIYpNwVpHel4qcirnKoFl9XRKVFhTDP89VqhYjz+Xwymcznc0TUB6OGmJnrum6aRgVYTZBzTkVCpUtvYJCbQa50lQch3JdD3h3f0hr6sIczhRCAEI2xmTHGGAN7QktEYMgVeQKZHh4kEJtnPsWsKNQe5lXJCJfXV6t644ocDLVtu1gs1Dm/vr6u63qz2YiIOkjOOdU7RHRwcPDJJ58Q0ZMnTz755JOyLL/88stvvvkGAL773e/OZjO9AFVn6l8YY5yzIty2zXg8un//3tnZ6cHBjAhR+PDo4JunX2eZ+9//978cFcXLZ8/KLJtOp23b6uecHh2v18uUUgpxvV7H5HNrEofet5PJ+N6dO9Px+M2b13/ywz/65puvr64u//zPf3JyeHB5eVGV+Xqz/Onf/5c//qPPLSGBfP797/3iZ19Yg9aStSTAXdfU9TrGOBqNzs/PsyyLMa5Wq+Vy7n3nMluWeVnmiCiEYBAIhRhIyFrnXDmqqqoqRuVkNj04OJhOp7OD6eHxgbWGCNWl6vs+hF631mJ5u9lsvPcAbK3VHQIAR0dHBwcHk+moKDPdIQCit8/MbdsmDjHGuq5VIZZlqVYhRh99UIMxLisAcM6URWGtTSl436FAmRcff/zxeDxWpVMUxWQySSldXl7e3Nzoe3UXqQEnIjUY+vtB8FRpMrPamGG3ExGA+rEphMBJhJETxMgxcIoCQIho33dHh+BQ+r4vikz3SkoJgMkYRHRZpjAMEZ2cnCyXyyzLxuNxjH
Ew6HVdLxYL9bJ0rVXwjDEpSQhBRPKsGELBwdskIpb0LXdUzwkhgKAxKaWESAxRBEXEUhRJzFFEEggyi0hiF
"text/plain": [
"<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x280 at 0x7FC639727950>"
]
},
"metadata": {},
"execution_count": 15
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1bEUwwzcVG8o"
},
"source": [
"也可以使用 MMClassification 提供的可视化函数 imshow_infos 更好地展示预测结果。"
]
},
{
"cell_type": "code",
"metadata": {
"id": "BcSNyvAWRx20",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 304
},
"outputId": "1db811f7-9637-44c4-8aec-330f5765e20c"
},
"source": [
"from mmcls.core.visualization import imshow_infos\n",
"\n",
"filepath = 'data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg'\n",
"\n",
"result = {\n",
" 'pred_class': results['pred_class'][0],\n",
" 'pred_label': results['pred_label'][0],\n",
" 'pred_score': results['pred_score'][0],\n",
"}\n",
"\n",
"img = imshow_infos(filepath, result)"
],
"execution_count": 16,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAATMAAAEfCAYAAAAtNiETAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAFiQAABYkBbWid+gAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9d7Rk2VXm+TvmmjDPv8yXmZWmSt4hUDWSCiEE3QgBYnCSMMPMYgYaqUFIgxHQTTdDQ4uhEaIb0EA3vqGBxkiAJOQwLYMTKpUk5MtlVnr/TNhrjps/zr03IrMk/pq1ZvWsilqxXuWLeBFx7z3n29/+9rd3iBACj90euz12e+z2P/pN/n/9AR67PXZ77PbY7f+N22Ng9tjtsdtjt/9f3PTtv3jK5zw1AAgkIBBCA4LgBSDY3t6mrktGoxHzYkKapgyHfXywjCcTnIMkzwkhsLKywuXLl9na2sJai/ceIQQAw+GQ2WzG4cOHqeuauq6ZTCaE2gKglEJrTZIkSCnxHuq6BgRaaaTUOOeo65oQAmmakmUJxhiUlnhvqaoKay1CCJRS8YB1ipKaJMlIkz5aJyiVIESCICAD9PKUuq4xxpDlGUJKjKkJImCa1xNKIIQghEAIAaRASwVeIIMgTVPKsiRJsviZpMZah1KKtbU19vb2SZKE5zznHqy1fPjDHyYf9Lk22uf4Xac498gjiOD4J8+6m73rV9m/cYPx/h49pTh2aJudQ9tsrqwggUwr8A5nHVmvjweEEtR1yZOe9ESqqmQ02iNJE1ZXV9nd3eXChUvkWU6a5sznJVmax+scYHV1lVOnThFC4Oy5sxRlCQICUFQVw5UVhBTc+bi7OHT4MBcvXWJ/f5+qqqiLEhVACIEPgTTNyLIeztMcf0Ka9ZBSU9cWhMA6z2w6J0sznKnItOSJj388ZVlyMB7RHwzY39+ntoY0y5gVczY3N9k5eoTdvT12d3cZjUYIIbhj5ygJgkQLQnDoROO9Y3f3BlVVkmUZ3nu0jmvI+4A1Hmsdcd17pHQ4bxEBtNYoleCDpDYOJxRpb5U6SHYnBdfHEwoLMuvhUBjrUBKklEgpIQQIHojMIQSP9x7nLM4avPckUpGmKXmekmUZOkuZz+dUVYUxBmOrbp1LKanruK43NtaYTqd479ne3qKsirjPHIQQcD6+fqo0Wmu8NVRVybyYs7OzQz/LmUwmWFPR7/cRQlAUJYG4tqWUOOcYjUZMJnGvb2xscOjQIWazGd57QghYa5FSopTCWktRzBBCUNc1zjlWV1fZ3Nykqipu3rzJYDAgTVPW1tYQQjCfz7u9NJlMCDjW1gZoGXBmgtaOf/qC5zIcajbWemxvDLGuwNqaEBzP+6L/XXxGMGvBhkZKazW1VlprASXLMowtCSFQVRUBh/ceYzy9ocZay2AwAGAwGDCbzQCw1hJCQClFCIHpdMr6+jpHjx7loYcewriAcw7nXFwMzWew1mGMIdEpzjm8p7u4IQS899R13Z2U9jiEWIDO8jE+6t485qzFWolxFuMs0iqkjP8OxPdBClSI4OjjCyKlbDYw1M6Ck0zmMwb9eAxZluFFzWg8RiaayXxGr9djd3+Ppz/9GTz48MPsjvYRQrC1tUWepuzv3iBNE5IkodfrMRvH98zynOAjiK6urlIWM4b9AXt7e2yvriCkYjIdo5Si1+uRZSnT2Zj5fN6dzzzLkFKidUKWBYJfnLP5fN4tMGstSilqYwgEpBBMp1PW1tcoy5LhcMhdd96JkpLxaMRBVSFCQIkI6tPJhKKoGA5XcdYhUEgEzroGVBJ6WYIUgro25L2cYjLmwoUL3HHHHXjnuXHjBkqrLsDt7OxgnePc2XNkWUY5LzDGsLmxwXg8IhUCLQS9Xk6aJHgfmve0zEyFlArVG5AkGpkoaizeFpRljZCerK8IPuBDwBuD8gCKEMAHT1GWVEFQmQpjLcZBqCtsEHjnyd
IU71231mRzzoOLgbDXy6mNx7t21XkgAkNVV1y6eoWiiMckpSRJ47EnSYLWGojXynvfrWlrXQeg1tZkWUYis2Zv2C7wV1XNSn+ACIGimGFt3QBpznw+ZTQ6oDaOjY0NhsMhdV0jpWQ4HLK+vs7m5ma316qqIssy0jRtANp1RKQFsnZfGGMQQrCyssJ0Ou3WWgwqETTbfet9xBQSQaJTsqw5RhOfU9c1wTvirl0kl48Cs45pBEkI7ckSECJTs9aS5ymDwQDn4wE55xGyjWIRXLTWGGMYDocdc7LWkiQJVVVRVRWTyYT5fM7GxgZ33XUXo9GIgxu7lGXZMbkInhFMlVJLAOeBBeMC8N7FxdP8rgPmpWNrwau98IvXlBA8QQr4TGDXnvw06S4QgG8ugrOeIMEaixSS2pjI6Jwlz3N0mqASjXEW6x3D1RUSnfCp++/n+S94AXd//j/h3e99D8N+j93dPe7+vGdy/eo6ly9epJiM8dailUJ4T5okzQI2IAR1VTPcPhQXlHW4YOOGq+bs7u4SgmM8HuO9Y29vj8FgwPr6OrPZnKKYo1SCThOCD5RlyWQyYXV1lcFgQJbnZFnGdD6jKAq0lJGBlRXXr1xla2OTwUq8xkmSMOwPmE+nuIb9xKDlESIQL51HyICpasrK0O9LNjc3SdKUy1cug3NoLdnb36XXy7HOUJuKlJQABO9YGQxwwTMaHTCfz5gXc3pZRp6mjPf3cfGJESBwVFWJdYYkUdS1oywLvI8MKc8H6ESThx4gMK7EOYN1EWwVEiEUKkki+KGoHAgEWmuyLCM4gZMaEcBKByICZ7uZU61RUoLQIEAqgXISrxRSRuBVSnXrcmdnp2NDrgHAdl+2azxJIgjkeR6vh6m7bGQwGMQgphOstVRVPF6lFGmWsLa2xmw2IwTHYDAg0boLYFmWopP4eYwxOOdYWVlhMBiwurpKmqacPn2a+XyOcw6tNf1+vwOvNE2RMu8ISZtZtXt/OBwyGo3iZ67rbv+FsCAxgUBdW4IXpAONUqIjNMGBqS1SBbRSKLWAsM8KZu0mbQmNQHVgJkTWRYqqquJibVBWiITRZMLa2hoHBwcMh8PuZ1VV5HneHZz3nqqqmM/n3QU4IFL09iS07ADiBRTIJZa1YEQtO1u+LYPQAvA8XviGIvvmbzwE1x0HYgGEQUCQoqHeEq2S+JiMaZRCxtcSIKQEFdBJgnOOweoKdW1RaYINnizNSPM8nngZmeneaJ+DyRihFVIpSlMhxiOKIqYMBwcH1PMpwVisszhjqU3NIM+x1jGbxTRjOByysbHBZDYn7/c5tLPNlSuXuH79OpPJGKngcY+7i8lk0kXbeM4Caaqa4xbNNYTaVKwmqxw6tE2/32f/YJ8bN13cLP0e0/mU2qacPv0wOzs7zCeTZiP1qeZzqqpGa00vyzHGUhUlQch4PlSCQKBEwJoKpQRbmxuYquTalUv084wsUVy6fIGTJ09y8uRxTj9yhqKYk/dTLl44x9raGiePH+Pc+fNsbazhvefypYusrwzRSLwD6xxFUVDXJUoJhmtrWGu4efNmTOFqh7WB4XCFNEsQEoQJGFeAirJKQBCEQEiB1BqhEqRIkEFiE4dNcxIvsELhEFjrqOZzvBNxc1pLcCayKqnw3nNwcAAhIAgkOjLvJNEICUpJ1jc345rGdYE9rscYjCOLcZSlBwHW1bh5fA9rLYO81wS8gLU1tjYIH8gSTapjoE9TjRQJWim8NZiqRCLo93sc3rmD8XjM/v4+3nvW19cZDodYa9nf32dvb488z0mSBKVUl2VFopMjZUxr2wzMOdfgRCBvgqP3nrKMaX/7Oi0TNQ3LE0BlIEt1lLlczOysBR1CvCZiQWY+I5jdCmyRAbV5ZwSklCRJlqiljxtZRK2oqirSNKaD29vbjMdj0jS9JW1pqetsNuPg4KBDa2NMpJhAv99HKdWkrzHNjNpO+/lu/bxRW/OPSjOXj8cYA0GgVIwCQsSUleZYla
ADxkU0XPx/e8KFXkRSKSVSLyJMaWsEkOiMXGlKU0NVUzvLaDJGa01d1VjnEErxkY99lBvXrjMrCuZ1ybHjx
"text/plain": [
"<Figure size 300.01x280.01 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
]
}
]
}