{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "MMClassification_tools_cn.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "XjQxmm04iTx4",
"tags": []
},
"source": [
"<a href=\"https://colab.research.google.com/github/open-mmlab/mmclassification/blob/master/docs_zh-CN/tutorials/MMClassification_tools_cn.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4z0JDgisPRr-"
},
"source": [
"# MMClassification Tools Tutorial on Colab\n",
"\n",
"This tutorial covers the following:\n",
"\n",
"* How to install MMCls\n",
"* Downloading the data\n",
"* Preparing the config files\n",
"* Using the shell command-line tools"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "inm7Ciy5PXrU"
},
"source": [
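"## Install MMClassification\n",
"\n",
"Before using MMClassification, we need to set up the environment as follows:\n",
"\n",
"- Install Python, CUDA, a C/C++ compiler and git\n",
"- Install PyTorch (CUDA version)\n",
"- Install mmcv\n",
"- Clone the mmcls GitHub repository and install it\n",
"\n",
"Because we are running on Google Colab, where the basic setup is already done, we can skip the first two steps."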
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TDOxbcDvPbNk"
},
"source": [
"### Check the environment"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c6MbAw10iUJI",
"outputId": "e8582a6e-4244-473a-d78b-5c46d1140eba"
},
"source": [
"%cd /content"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4IyFL3MaiYRu",
"outputId": "b3eba535-018c-4bb0-f61a-081e4a812f3a"
},
"source": [
"!pwd"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DMw7QwvpiiUO",
"outputId": "6ba0d2c0-b245-4cf5-963e-8acf61e2f6f6"
},
"source": [
"# Check the nvcc version\n",
"!nvcc -V"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"nvcc: NVIDIA (R) Cuda compiler driver\n",
"Copyright (c) 2005-2020 NVIDIA Corporation\n",
"Built on Mon_Oct_12_20:09:46_PDT_2020\n",
"Cuda compilation tools, release 11.1, V11.1.105\n",
"Build cuda_11.1.TC455_06.29190527_0\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4VIBU7Fain4D",
"outputId": "72fd879a-acb2-449f-904a-71e3df71edbd"
},
"source": [
"# Check the GCC version\n",
"!gcc --version"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0\n",
"Copyright (C) 2017 Free Software Foundation, Inc.\n",
"This is free software; see the source for copying conditions. There is NO\n",
"warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "24lDLCqFisZ9",
"outputId": "9e7a84b7-7e9b-4508-a7e9-57a902cfae73"
},
"source": [
"# Check the PyTorch installation\n",
"import torch, torchvision\n",
"print(torch.__version__)\n",
"print(torch.cuda.is_available())"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1.9.0+cu111\n",
"True\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R2aZNLUwizBs"
},
"source": [
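"### Install MMCV\n",
"\n",
"MMCV is the foundational library of the OpenMMLab codebases. Prebuilt wheel packages are provided for Linux, so you can download and install them directly.\n",
"\n",
"Pay attention to the PyTorch and CUDA versions to make sure the installation succeeds.\n",
"\n",
"In the previous steps we printed the CUDA and PyTorch versions of this environment, 11.1 and 1.9.0 respectively, so we need to pick the matching MMCV build.\n",
"\n",
"Alternatively, you can install the full version, MMCV-full, which includes all features and a rich set of ready-to-use CUDA operators. Note that the full version may take longer to build."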
]
},
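{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `-f` (find-links) URL in the next cell encodes the CUDA and PyTorch versions as `cu111/torch1.9.0`. The sketch below derives that URL from the versions printed earlier. It assumes the same `cu{version without dot}/torch{version}` pattern holds for other combinations, and the helper name `mmcv_find_links` is hypothetical; check the MMCV installation guide for the builds that actually exist."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def mmcv_find_links(cuda_version, torch_version):\n",
"    # Hypothetical helper: assumes the URL pattern used by the pip command below\n",
"    cuda_tag = 'cu' + cuda_version.replace('.', '')  # '11.1' -> 'cu111'\n",
"    torch_tag = torch_version.split('+')[0]  # '1.9.0+cu111' -> '1.9.0'\n",
"    return ('https://download.openmmlab.com/mmcv/dist/'\n",
"            f'{cuda_tag}/torch{torch_tag}/index.html')\n",
"\n",
"\n",
"# CUDA 11.1 and torch.__version__ as printed above\n",
"print(mmcv_find_links('11.1', '1.9.0+cu111'))"
],
"execution_count": null,
"outputs": []
},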
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nla40LrLi7oo",
"outputId": "5373aef3-c65b-4b3a-f1bd-44f0ea015a4e"
},
"source": [
"# Install mmcv\n",
"!pip install mmcv -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html\n",
"# !pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in links: https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html\n",
"Requirement already satisfied: mmcv in /usr/local/lib/python3.7/dist-packages (1.3.14)\n",
"Requirement already satisfied: Pillow in /usr/local/lib/python3.7/dist-packages (from mmcv) (7.1.2)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from mmcv) (21.0)\n",
"Requirement already satisfied: addict in /usr/local/lib/python3.7/dist-packages (from mmcv) (2.4.0)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from mmcv) (3.13)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mmcv) (1.19.5)\n",
"Requirement already satisfied: yapf in /usr/local/lib/python3.7/dist-packages (from mmcv) (0.31.0)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->mmcv) (2.4.7)\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GDTUrYvXjlRb"
},
"source": [
"### Clone and install MMCls\n",
"\n",
"Next, we clone the latest mmcls code from GitHub and install it."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Bwme6tWHjl5s",
"outputId": "bafc6818-20ea-47c3-f2de-7f89b2033e1b"
},
"source": [
"# Clone the mmcls repository\n",
"!git clone https://github.com/open-mmlab/mmclassification.git\n",
"%cd mmclassification/\n",
"\n",
"# Install MMClassification from source\n",
"!pip install -e ."
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"fatal: destination path 'mmclassification' already exists and is not an empty directory.\n",
"/content/mmclassification\n",
"Obtaining file:///content/mmclassification\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from mmcls==0.16.0) (3.2.2)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mmcls==0.16.0) (1.19.5)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from mmcls==0.16.0) (21.0)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.16.0) (2.4.7)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.16.0) (0.10.0)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.16.0) (2.8.2)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.16.0) (1.3.2)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from cycler>=0.10->matplotlib->mmcls==0.16.0) (1.15.0)\n",
"Installing collected packages: mmcls\n",
" Attempting uninstall: mmcls\n",
" Found existing installation: mmcls 0.16.0\n",
" Can't uninstall 'mmcls'. No files were found to uninstall.\n",
" Running setup.py develop for mmcls\n",
"Successfully installed mmcls-0.16.0\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hFg_oSG4j3zB",
"outputId": "3357ad97-fef6-4d3e-e343-8629bf4094dc"
},
"source": [
"# Check the MMClassification installation\n",
"import mmcls\n",
"print(mmcls.__version__)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0.16.0\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PkfqxfLIQVFM",
"outputId": "3a4a7dce-0ebc-44e9-aa0a-eb9975f51279"
},
"source": [
"# Download the pretrained model (-nc skips the download if the file already exists)\n",
"!mkdir -p checkpoints\n",
"!wget -nc https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth -P checkpoints"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"mkdir: cannot create directory ‘ checkpoints’ : File exists\n",
"--2021-10-11 08:18:28-- https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\n",
"Resolving download.openmmlab.com (download.openmmlab.com)... 47.252.96.35\n",
"Connecting to download.openmmlab.com (download.openmmlab.com)|47.252.96.35|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 14206911 (14M) [application/octet-stream]\n",
"Saving to: ‘ checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth.5’ \n",
"\n",
"mobilenet_v2_batch2 100%[===================>] 13.55M 7.65MB/s in 1.8s \n",
"\n",
"2021-10-11 08:18:31 (7.65 MB/s) - ‘ checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth.5’ saved [14206911/14206911]\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XHCHnKb_Qd3P",
"outputId": "ba447c6e-e135-4813-b07e-9e624c14a45b"
},
"source": [
"# Download the classification dataset\n",
"!wget \"https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0\" -O cats_dogs_dataset.zip\n",
"!mkdir -p data\n",
"!unzip -q -o cats_dogs_dataset.zip -d ./data/"
],
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"--2021-10-11 08:18:31-- https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0\n",
"Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.18, 2620:100:601b:18::a27d:812\n",
"Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.18|:443... connected.\n",
"HTTP request sent, awaiting response... 301 Moved Permanently\n",
"Location: /s/raw/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip [following]\n",
"--2021-10-11 08:18:31-- https://www.dropbox.com/s/raw/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip\n",
"Reusing existing connection to www.dropbox.com:443.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://uc4cc3369ee3aa8a59d8e5baf143.dl.dropboxusercontent.com/cd/0/inline/BX20wcgVboqmCYLywEnuZlaxl50WNpSeRrmVoNJIwLyYpY8rl8nhAEbNV__ve9DJcXYZvw7on-Jt95gFhbek5DuwMxGT4d6nJbO9uofnVgMt8GaFD3Tsl7A33kuHVwRAGFsNpcNgElFfuUWo8AWsc54H/file# [following]\n",
"--2021-10-11 08:18:31-- https://uc4cc3369ee3aa8a59d8e5baf143.dl.dropboxusercontent.com/cd/0/inline/BX20wcgVboqmCYLywEnuZlaxl50WNpSeRrmVoNJIwLyYpY8rl8nhAEbNV__ve9DJcXYZvw7on-Jt95gFhbek5DuwMxGT4d6nJbO9uofnVgMt8GaFD3Tsl7A33kuHVwRAGFsNpcNgElFfuUWo8AWsc54H/file\n",
"Resolving uc4cc3369ee3aa8a59d8e5baf143.dl.dropboxusercontent.com (uc4cc3369ee3aa8a59d8e5baf143.dl.dropboxusercontent.com)... 162.125.3.15, 2620:100:601b:15::a27d:80f\n",
"Connecting to uc4cc3369ee3aa8a59d8e5baf143.dl.dropboxusercontent.com (uc4cc3369ee3aa8a59d8e5baf143.dl.dropboxusercontent.com)|162.125.3.15|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: /cd/0/inline2/BX0Ph6Weiw7Ymc6XPsTFW2VF_7IQgb546X7Hz1O0vc5e9CP4hsVC1taDbH2WWQu9ift-oTxQLk3OJHhwsmsgmkLW4aNQfDtZQ6TtOnRneXV3DtxxNMLOnFYCH5NTdt5RNzONmFkuRy9N11GBndC4_NDlqXqc3ctwoE_TVL0eM-ah25dcBpGEMvL-51yWxBfHYI5_nZXlgLaCAbGkVl3E3aVqTrVmorAmaHNCPD6sU8PnFlrXJnn6zoXP8UhiuvcUAhVqZ8EjRshto6vu2w08hbMv2U4Ax7DMY9jU5EGBqFbIL91bF3tPldNO7iGRkz-DfCkVblXPS2SeVRFWibjoZMZmXc3DPQSHTNLXTmDewko4lVjNg6vuNKr3ClJuo_LTEPQ/file [following]\n",
"--2021-10-11 08:18:32-- https://uc4cc3369ee3aa8a59d8e5baf143.dl.dropboxusercontent.com/cd/0/inline2/BX0Ph6Weiw7Ymc6XPsTFW2VF_7IQgb546X7Hz1O0vc5e9CP4hsVC1taDbH2WWQu9ift-oTxQLk3OJHhwsmsgmkLW4aNQfDtZQ6TtOnRneXV3DtxxNMLOnFYCH5NTdt5RNzONmFkuRy9N11GBndC4_NDlqXqc3ctwoE_TVL0eM-ah25dcBpGEMvL-51yWxBfHYI5_nZXlgLaCAbGkVl3E3aVqTrVmorAmaHNCPD6sU8PnFlrXJnn6zoXP8UhiuvcUAhVqZ8EjRshto6vu2w08hbMv2U4Ax7DMY9jU5EGBqFbIL91bF3tPldNO7iGRkz-DfCkVblXPS2SeVRFWibjoZMZmXc3DPQSHTNLXTmDewko4lVjNg6vuNKr3ClJuo_LTEPQ/file\n",
"Reusing existing connection to uc4cc3369ee3aa8a59d8e5baf143.dl.dropboxusercontent.com:443.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 228802825 (218M) [application/zip]\n",
"Saving to: ‘ cats_dogs_dataset.zip’ \n",
"\n",
"cats_dogs_dataset.z 100%[===================>] 218.20M 74.8MB/s in 2.9s \n",
"\n",
"2021-10-11 08:18:35 (74.8 MB/s) - ‘ cats_dogs_dataset.zip’ saved [228802825/228802825]\n",
"\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e4t2P2aTQokX"
},
"source": [
"After downloading and extracting, the files under the \"Cats and Dogs Dataset\" folder are organized as follows:\n",
"```\n",
"data/cats_dogs_dataset\n",
"├── classes.txt\n",
"├── test.txt\n",
"├── val.txt\n",
"├── training_set\n",
"│ ├── training_set\n",
"│ │ ├── cats\n",
"│ │ │ ├── cat.1.jpg\n",
"│ │ │ ├── cat.2.jpg\n",
"│ │ │ ├── ...\n",
"│ │ ├── dogs\n",
"│ │ │ ├── dog.2.jpg\n",
"│ │ │ ├── dog.3.jpg\n",
"│ │ │ ├── ...\n",
"├── val_set\n",
"│ ├── val_set\n",
"│ │ ├── cats\n",
"│ │ │ ├── cat.3.jpg\n",
"│ │ │ ├── cat.5.jpg\n",
"│ │ │ ├── ...\n",
"│ │ ├── dogs\n",
"│ │ │ ├── dog.1.jpg\n",
"│ │ │ ├── dog.6.jpg\n",
"│ │ │ ├── ...\n",
"├── test_set\n",
"│ ├── test_set\n",
"│ │ ├── cats\n",
"│ │ │ ├── cat.4001.jpg\n",
"│ │ │ ├── cat.4002.jpg\n",
"│ │ │ ├── ...\n",
"│ │ ├── dogs\n",
"│ │ │ ├── dog.4001.jpg\n",
"│ │ │ ├── dog.4002.jpg\n",
"│ │ │ ├── ...\n",
"```\n",
"\n",
"You can view this structure with the shell command `tree data/cats_dogs_dataset`."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 297
},
"id": "46tyHTdtQy_Z",
"outputId": "000012ac-01da-4294-f997-058d46470667"
},
"source": [
"# Load one image and display it\n",
"from PIL import Image\n",
"Image.open('data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg')"
],
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x280 at 0x7F8B4EEF2F90>"
]
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "My5Z6p7pQ3UC"
},
"source": [
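"### Supporting a new dataset\n",
"\n",
"MMClassification expects the images and the label files of a dataset to be placed in the same directory. There are two ways to support a custom dataset.\n",
"\n",
"The simplest way is to convert the dataset into an existing dataset format (such as ImageNet). The other way is to create a new dataset class. See the [documentation](https://github.com/open-mmlab/mmclassification/blob/master/docs_zh-CN/tutorials/new_dataset.md) for details.\n",
"\n",
"In this tutorial, for convenience, we have already reorganized the cats and dogs dataset into the ImageNet format."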
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P335gKt9Q5U-"
},
"source": [
"The standard files are:\n",
"\n",
"1. A class list, with one class per line.\n",
" ```\n",
" cats\n",
" dogs\n",
" ```\n",
"2. Training/validation/test annotations, where each line contains a file name and its corresponding label.\n",
" ```\n",
" ...\n",
" cats/cat.3769.jpg 0\n",
" cats/cat.882.jpg 0\n",
" ...\n",
" dogs/dog.3881.jpg 1\n",
" dogs/dog.3377.jpg 1\n",
" ...\n",
" ```"
]
},
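{
"cell_type": "markdown",
"metadata": {},
"source": [
"Annotation files in this format can also be generated from a class-per-folder layout. The sketch below is a minimal example, assuming each class has its own sub-folder under `data_prefix` and that a class's label is its line index in `classes.txt`. `write_ann_file` and the output file `val_regenerated.txt` are hypothetical names, not part of MMCls."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import os\n",
"\n",
"\n",
"def write_ann_file(data_prefix, classes, ann_path):\n",
"    # Hypothetical helper: writes one '<class>/<file> <label>' line per image\n",
"    exts = ('.jpg', '.jpeg', '.png')\n",
"    lines = []\n",
"    for label, name in enumerate(classes):\n",
"        class_dir = os.path.join(data_prefix, name)\n",
"        for fname in sorted(os.listdir(class_dir)):\n",
"            if fname.lower().endswith(exts):\n",
"                # Path is relative to data_prefix, e.g. 'cats/cat.3.jpg 0'\n",
"                lines.append(f'{name}/{fname} {label}')\n",
"    with open(ann_path, 'w') as f:\n",
"        f.write('\\n'.join(lines) + '\\n')\n",
"    return len(lines)\n",
"\n",
"\n",
"# Label = line index in classes.txt (cats -> 0, dogs -> 1)\n",
"with open('data/cats_dogs_dataset/classes.txt') as f:\n",
"    classes = [line.strip() for line in f if line.strip()]\n",
"num = write_ann_file('data/cats_dogs_dataset/val_set/val_set', classes,\n",
"                     'val_regenerated.txt')\n",
"print(num, 'images written')"
],
"execution_count": null,
"outputs": []
},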
{
"cell_type": "markdown",
"metadata": {
"id": "BafQ7ijBQ8N_"
},
"source": [
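"## Using the command-line tools\n",
"\n",
"MMCls also provides command-line tools for the following tasks:\n",
"\n",
"1. Model training\n",
"2. Model fine-tuning\n",
"3. Model testing\n",
"4. Inference\n",
"\n",
"Training and fine-tuning follow the same procedure, and we have already seen how to run inference and fine-tuning with the Python API. Next, we show how to carry out these tasks with the command-line tools. For more details, see the [documentation](https://github.com/open-mmlab/mmclassification/blob/master/docs_zh-CN/getting_started.md)."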
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Aj5cGMihURrZ"
},
"source": [
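"### Model fine-tuning\n",
"\n",
"Fine-tuning a model from the command line takes the following steps:\n",
"\n",
"1. Prepare the custom dataset\n",
"2. Adapt the dataset to the MMCls format\n",
"3. Modify the configuration in .py config files\n",
"4. Fine-tune the model with the command-line tool\n",
"\n",
"Steps 1 and 2 are the same as described above, so we cover the last two steps.\n",
"\n",
"#### Modify the configuration in .py config files\n",
"\n",
"To reuse the parts that different config files have in common, MMCls supports inheriting from multiple config files. For example, to fine-tune MobileNetV2, the new config file can build the basic model structure on `configs/_base_/models/mobilenet_v2_1x.py` (through the model config created below), inherit `configs/_base_/datasets/cats_dogs_dataset.py` for the dataset settings and `configs/_base_/schedules/cats_dogs_finetune.py` for a custom learning-rate schedule, and finally inherit `configs/_base_/default_runtime.py` for the runtime settings.\n",
"\n",
"The final config file should look like this:\n",
"\n",
"```\n",
"# Save to \"configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py\"\n",
"_base_ = [\n",
"    '../_base_/models/mobilenet_v2_1x_cats_dogs.py',\n",
"    '../_base_/datasets/cats_dogs_dataset.py',\n",
"    '../_base_/schedules/cats_dogs_finetune.py',\n",
"    '../_base_/default_runtime.py'\n",
"]\n",
"```\n",
"\n",
"Alternatively, instead of using inheritance, you can write a complete config file directly, such as `configs/mnist/lenet5.py`.\n",
"\n",
"Here we use a dataset that has already been reorganized. To use a fully custom dataset, you also need to write a new dataset config that overrides the inherited settings."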
]
},
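{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate how a child config overrides its `_base_` files, here is a simplified sketch of the merge rule: nested dicts are merged key by key, other values are replaced, and a `_delete_=True` key replaces a whole sub-dict instead of merging it. It is an illustration written for this tutorial (`merge_cfg` is a hypothetical name and the toy values are made up), not MMCV's actual `Config` implementation."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import copy\n",
"\n",
"\n",
"def merge_cfg(base, child):\n",
"    # Simplified sketch of _base_ inheritance; not mmcv's implementation\n",
"    merged = copy.deepcopy(base)\n",
"    for key, value in child.items():\n",
"        if isinstance(value, dict) and isinstance(merged.get(key), dict):\n",
"            value = dict(value)\n",
"            if value.pop('_delete_', False):\n",
"                merged[key] = value  # replace the whole sub-dict\n",
"            else:\n",
"                merged[key] = merge_cfg(merged[key], value)  # merge recursively\n",
"        else:\n",
"            merged[key] = copy.deepcopy(value)  # plain values are overridden\n",
"    return merged\n",
"\n",
"\n",
"base = dict(model=dict(head=dict(type='LinearClsHead', num_classes=1000, in_channels=1280)))\n",
"child = dict(model=dict(head=dict(num_classes=2)))\n",
"print(merge_cfg(base, child))\n",
"# {'model': {'head': {'type': 'LinearClsHead', 'num_classes': 2, 'in_channels': 1280}}}"
],
"execution_count": null,
"outputs": []
},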
{
"cell_type": "markdown",
"metadata": {
"id": "_UV3oBhLRG8B"
},
"source": [
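"First, modify the model config and save it as `configs/_base_/models/mobilenet_v2_1x_cats_dogs.py`. This new config sets `num_classes` in the model `head` to the number of classes in our task. The pretrained weights are usually reused for every part of the model except the final linear layer; here `init_cfg` loads the backbone weights from the downloaded ImageNet checkpoint."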
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8QfM4qBeWIQh",
"outputId": "745b9519-88f0-4dda-c510-ef324886509e"
},
"source": [
"%%writefile configs/_base_/models/mobilenet_v2_1x_cats_dogs.py\n",
"_base_ = ['./mobilenet_v2_1x.py']\n",
"model = dict(\n",
"    backbone=dict(\n",
"        # Initialize the backbone from the ImageNet-pretrained checkpoint\n",
"        init_cfg=dict(\n",
"            type='Pretrained',\n",
"            checkpoint='checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth',\n",
"            prefix='backbone')\n",
"    ),\n",
"    head=dict(\n",
"        # Two classes: cats and dogs\n",
"        num_classes=2,\n",
"        topk=(1, )\n",
"    ))"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Overwriting configs/_base_/models/mobilenet_v2_1x_cats_dogs.py\n"
]
}
]
},
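{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `prefix='backbone'`, only the checkpoint weights whose names start with `backbone.` are loaded, with that prefix stripped, so the new 2-class head keeps its fresh initialization. The toy sketch below illustrates the idea on a plain dict; `strip_prefix` is a hypothetical helper and the values are placeholders, not MMCV's loading code."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def strip_prefix(state_dict, prefix):\n",
"    # Keep entries under `prefix` and drop 'prefix.' from their names\n",
"    # (illustration only; not mmcv's implementation)\n",
"    prefix = prefix + '.'\n",
"    return {k[len(prefix):]: v for k, v in state_dict.items()\n",
"            if k.startswith(prefix)}\n",
"\n",
"\n",
"# Toy checkpoint: parameter names mapped to placeholder values\n",
"ckpt = {'backbone.conv1.weight': 1, 'backbone.layer1.0.weight': 2,\n",
"        'head.fc.weight': 3}\n",
"print(strip_prefix(ckpt, 'backbone'))\n",
"# {'conv1.weight': 1, 'layer1.0.weight': 2}"
],
"execution_count": null,
"outputs": []
},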
{
"cell_type": "markdown",
"metadata": {
"id": "F2bjgpsZRKp1"
},
"source": [
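"Second, the data config, saved as `configs/_base_/datasets/cats_dogs_dataset.py`. Note that the training set has no `ann_file`; in that case the `ImageNet` dataset class takes the labels from the sub-folder names (`cats`, `dogs`)."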
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DMIp07L4Wn80",
"outputId": "d09c386e-a61a-4e47-cff4-0fbf6bb89d23"
},
"source": [
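"%%writefile configs/_base_/datasets/cats_dogs_dataset.py\n",
"_base_ = ['./imagenet_bs32.py']\n",
"# Custom normalization parameters (RGB order)\n",
"img_norm_cfg = dict(\n",
"    mean=[124.508, 116.050, 106.438],\n",
"    std=[58.577, 57.310, 57.437],\n",
"    to_rgb=True)\n",
"# The pipelines inherited from the base config still use the base img_norm_cfg,\n",
"# so they are redefined here for the values above to take effect\n",
"train_pipeline = [\n",
"    dict(type='LoadImageFromFile'),\n",
"    dict(type='RandomResizedCrop', size=224),\n",
"    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
"    dict(type='Normalize', **img_norm_cfg),\n",
"    dict(type='ImageToTensor', keys=['img']),\n",
"    dict(type='ToTensor', keys=['gt_label']),\n",
"    dict(type='Collect', keys=['img', 'gt_label'])\n",
"]\n",
"test_pipeline = [\n",
"    dict(type='LoadImageFromFile'),\n",
"    dict(type='Resize', size=(256, -1)),\n",
"    dict(type='CenterCrop', crop_size=224),\n",
"    dict(type='Normalize', **img_norm_cfg),\n",
"    dict(type='ImageToTensor', keys=['img']),\n",
"    dict(type='Collect', keys=['img'])\n",
"]\n",
"\n",
"data = dict(\n",
"    # Batch size and number of data-loading workers per GPU; adjust to your machine\n",
"    samples_per_gpu=32,\n",
"    workers_per_gpu=2,\n",
"    # Training set path (no ann_file: labels come from the sub-folder names)\n",
"    train=dict(\n",
"        data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
"        classes='data/cats_dogs_dataset/classes.txt',\n",
"        pipeline=train_pipeline\n",
"    ),\n",
"    # Validation set paths\n",
"    val=dict(\n",
"        data_prefix='data/cats_dogs_dataset/val_set/val_set',\n",
"        ann_file='data/cats_dogs_dataset/val.txt',\n",
"        classes='data/cats_dogs_dataset/classes.txt',\n",
"        pipeline=test_pipeline\n",
"    ),\n",
"    # Test set paths\n",
"    test=dict(\n",
"        data_prefix='data/cats_dogs_dataset/test_set/test_set',\n",
"        ann_file='data/cats_dogs_dataset/test.txt',\n",
"        classes='data/cats_dogs_dataset/classes.txt',\n",
"        pipeline=test_pipeline\n",
"    )\n",
")\n",
"# Evaluation settings: report top-1 accuracy only\n",
"evaluation = dict(metric_options={'topk': (1, )})"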
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Overwriting configs/_base_/datasets/cats_dogs_dataset.py\n"
]
}
]
},
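{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both the head (`topk=(1, )`) and `evaluation` restrict accuracy to top-1. Top-k accuracy counts a sample as correct when its true label is among the k highest-scoring classes, so k cannot exceed the number of classes; with only two classes, top-1 is the meaningful metric. A minimal sketch with made-up scores (`topk_accuracy` is a hypothetical helper, not the MMCls metric function):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def topk_accuracy(scores, labels, k):\n",
"    # scores: per-sample lists of class scores; labels: ground-truth class indices\n",
"    num_classes = len(scores[0])\n",
"    if k > num_classes:\n",
"        raise ValueError(f'k={k} exceeds the number of classes ({num_classes})')\n",
"    hits = 0\n",
"    for s, y in zip(scores, labels):\n",
"        # Indices of the k highest scores\n",
"        top = sorted(range(num_classes), key=lambda i: s[i], reverse=True)[:k]\n",
"        hits += y in top\n",
"    return 100.0 * hits / len(labels)\n",
"\n",
"\n",
"scores = [[0.9, 0.1], [0.3, 0.7], [0.6, 0.4]]  # 2 classes: cats, dogs\n",
"labels = [0, 1, 1]\n",
"print(topk_accuracy(scores, labels, 1))  # 2 of 3 correct"
],
"execution_count": null,
"outputs": []
},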
{
"cell_type": "markdown",
"metadata": {
"id": "_lxAl1cSRM_D"
},
"source": [
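"Third, the learning-rate schedule. The fine-tuning schedule differs considerably from the default one: fine-tuning usually calls for a smaller learning rate and fewer epochs. Save it as `configs/_base_/schedules/cats_dogs_finetune.py`."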
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6-JTFNaDWzFQ",
"outputId": "6cd2a5a6-c6c8-4649-d2a9-61852691d847"
},
"source": [
"%%writefile configs/_base_/schedules/cats_dogs_finetune.py\n",
"# Optimizer settings\n",
"# Learning rate set for a batch size of 128\n",
"optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)\n",
"optimizer_config = dict(grad_clip=None)\n",
"# Learning-rate policy: decay the learning rate once, from epoch 1 on\n",
"lr_config = dict(policy='step', step=[1])\n",
"runner = dict(type='EpochBasedRunner', max_epochs=2)"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Overwriting configs/_base_/schedules/cats_dogs_finetune.py\n"
]
}
]
},
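{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `policy='step'` and `step=[1]`, the learning rate is multiplied by a decay factor once training reaches epoch 1 (epochs counted from 0). The sketch below reproduces that rule, assuming MMCV's default decay factor `gamma=0.1`; `step_lr` is a hypothetical helper, not MMCV's learning-rate hook."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def step_lr(base_lr, epoch, steps, gamma=0.1):\n",
"    # The number of milestones already reached is the decay exponent\n",
"    passed = sum(1 for s in steps if epoch >= s)\n",
"    return base_lr * gamma ** passed\n",
"\n",
"\n",
"# max_epochs=2: epoch 0 uses lr=0.01, epoch 1 uses about 0.001\n",
"for epoch in range(2):\n",
"    print(epoch, step_lr(0.01, epoch, steps=[1]))"
],
"execution_count": null,
"outputs": []
},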
{
"cell_type": "markdown",
"metadata": {
"id": "ofZoBfseROf1"
},
"source": [
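"Finally, for the runtime settings we simply use the default config. We then combine all of the config files modified and saved above into a single file, saved as `configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py`.\n"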
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3tp9C42uXgRD",
"outputId": "f7c30bc5-5338-4677-98f4-539cdf35b5e5"
},
"source": [
"%%writefile configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py\n",
"_base_ = [\n",
" '../_base_/models/mobilenet_v2_1x_cats_dogs.py', '../_base_/datasets/cats_dogs_dataset.py',\n",
" '../_base_/schedules/cats_dogs_finetune.py', '../_base_/default_runtime.py'\n",
"]"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Overwriting configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "chLX7bL3RP2F"
},
"source": [
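"#### Fine-tune the model from the command line\n",
"\n",
"We use `tools/train.py` to fine-tune the model:\n",
"\n",
"```\n",
"python tools/train.py ${CONFIG_FILE} [optional arguments]\n",
"```\n",
"\n",
"To choose where the files produced during training are saved, add the argument `--work-dir ${YOUR_WORK_DIR}`.\n",
"\n",
"Adding `--seed ${SEED}` sets the random seed so that results are reproducible, and `--deterministic` enables cuDNN's deterministic mode to further improve reproducibility, at the cost of some speed.\n",
"\n",
"Here we use `MobileNetV2` and the cats and dogs dataset as an example."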
]
},
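{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of what `--seed` and `--deterministic` control, the cell below seeds Python, NumPy and PyTorch and switches cuDNN to deterministic mode. It assumes behaviour similar to the seeding done by MMCls's training script but is not the code `tools/train.py` runs; `set_seed` is a hypothetical name."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import torch\n",
"\n",
"\n",
"def set_seed(seed, deterministic=False):\n",
"    # Seed every random number generator used during training\n",
"    random.seed(seed)\n",
"    np.random.seed(seed)\n",
"    torch.manual_seed(seed)\n",
"    torch.cuda.manual_seed_all(seed)\n",
"    if deterministic:\n",
"        # Deterministic cuDNN kernels: reproducible, but possibly slower\n",
"        torch.backends.cudnn.deterministic = True\n",
"        torch.backends.cudnn.benchmark = False\n",
"\n",
"\n",
"set_seed(0, deterministic=True)"
],
"execution_count": null,
"outputs": []
},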
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gbFGR4SBRUYN",
"outputId": "82b288da-c0d5-4230-a611-5b403469d11e"
},
"source": [
"!python tools/train.py \\\n",
" configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py \\\n",
" --work-dir work_dirs/mobilenet_v2_1x_cats_dogs \\\n",
" --seed 0 \\\n",
" --deterministic"
],
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:28: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/local/lib/python3.7/dist-packages/yaml/constructor.py:126: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n",
" if not isinstance(key, collections.Hashable):\n",
"2021-10-11 08:18:45,882 - mmcls - INFO - Environment info:\n",
"------------------------------------------------------------\n",
"sys.platform: linux\n",
"Python: 3.7.12 (default, Sep 10 2021, 00:21:48) [GCC 7.5.0]\n",
"CUDA available: True\n",
"GPU 0: Tesla K80\n",
"CUDA_HOME: /usr/local/cuda\n",
"NVCC: Build cuda_11.1.TC455_06.29190527_0\n",
"GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0\n",
"PyTorch: 1.9.0+cu111\n",
"PyTorch compiling details: PyTorch built with:\n",
" - GCC 7.3\n",
" - C++ Version: 201402\n",
" - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications\n",
" - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)\n",
" - OpenMP 201511 (a.k.a. OpenMP 4.5)\n",
" - NNPACK is enabled\n",
" - CPU capability usage: AVX2\n",
" - CUDA Runtime 11.1\n",
" - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86\n",
" - CuDNN 8.0.5\n",
" - Magma 2.5.2\n",
" - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, \n",
"\n",
"TorchVision: 0.10.0+cu111\n",
"OpenCV: 4.1.2\n",
"MMCV: 1.3.14\n",
"MMCV Compiler: n/a\n",
"MMCV CUDA Compiler: n/a\n",
"MMClassification: 0.16.0+6fba107\n",
"------------------------------------------------------------\n",
"\n",
"2021-10-11 08:18:45,882 - mmcls - INFO - Distributed training: False\n",
"2021-10-11 08:18:46,128 - mmcls - INFO - Config:\n",
"model = dict(\n",
" type='ImageClassifier',\n",
" backbone=dict(\n",
" type='MobileNetV2',\n",
" widen_factor=1.0,\n",
" init_cfg=dict(\n",
" type='Pretrained',\n",
" checkpoint=\n",
" 'checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth',\n",
" prefix='backbone')),\n",
" neck=dict(type='GlobalAveragePooling'),\n",
" head=dict(\n",
" type='LinearClsHead',\n",
" num_classes=2,\n",
" in_channels=1280,\n",
" loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n",
" topk=(1, )))\n",
"dataset_type = 'ImageNet'\n",
"img_norm_cfg = dict(\n",
" mean=[124.508, 116.05, 106.438], std=[58.577, 57.31, 57.437], to_rgb=True)\n",
"train_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
"]\n",
"test_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1)),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
"]\n",
"data = dict(\n",
" samples_per_gpu=32,\n",
" workers_per_gpu=2,\n",
" train=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
" ],\n",
" classes='data/cats_dogs_dataset/classes.txt'),\n",
" val=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/val_set/val_set',\n",
" ann_file='data/cats_dogs_dataset/val.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1)),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
" ],\n",
" classes='data/cats_dogs_dataset/classes.txt'),\n",
" test=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/test_set/test_set',\n",
" ann_file='data/cats_dogs_dataset/test.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1)),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
" ],\n",
" classes='data/cats_dogs_dataset/classes.txt'))\n",
"evaluation = dict(\n",
" interval=1, metric='accuracy', metric_options=dict(topk=(1, )))\n",
"optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)\n",
"optimizer_config = dict(grad_clip=None)\n",
"lr_config = dict(policy='step', step=[1])\n",
"runner = dict(type='EpochBasedRunner', max_epochs=2)\n",
"checkpoint_config = dict(interval=1)\n",
"log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])\n",
"dist_params = dict(backend='nccl')\n",
"log_level = 'INFO'\n",
"load_from = None\n",
"resume_from = None\n",
"workflow = [('train', 1)]\n",
"work_dir = 'work_dirs/mobilenet_v2_1x_cats_dogs'\n",
"gpu_ids = range(0, 1)\n",
"\n",
"2021-10-11 08:18:46,128 - mmcls - INFO - Set random seed to 0, deterministic: True\n",
"2021-10-11 08:18:46,212 - mmcls - INFO - initialize MobileNetV2 with init_cfg {'type': 'Pretrained', 'checkpoint': 'checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth', 'prefix': 'backbone'}\n",
"2021-10-11 08:18:46,212 - mmcv - INFO - load backbone in model from: checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\n",
"Use load_from_local loader\n",
"2021-10-11 08:18:46,258 - mmcls - INFO - initialize LinearClsHead with init_cfg {'type': 'Normal', 'layer': 'Linear', 'std': 0.01}\n",
"2021-10-11 08:18:48,789 - mmcls - INFO - Start running, host: root@779e8ce3556b, work_dir: /content/mmclassification/work_dirs/mobilenet_v2_1x_cats_dogs\n",
"2021-10-11 08:18:48,790 - mmcls - INFO - Hooks will be executed in the following order:\n",
"before_run:\n",
"(VERY_HIGH ) StepLrUpdaterHook \n",
"(NORMAL ) CheckpointHook \n",
"(NORMAL ) EvalHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_train_epoch:\n",
"(VERY_HIGH ) StepLrUpdaterHook \n",
"(NORMAL ) EvalHook \n",
"(LOW ) IterTimerHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_train_iter:\n",
"(VERY_HIGH ) StepLrUpdaterHook \n",
"(NORMAL ) EvalHook \n",
"(LOW ) IterTimerHook \n",
" -------------------- \n",
"after_train_iter:\n",
"(ABOVE_NORMAL) OptimizerHook \n",
"(NORMAL ) CheckpointHook \n",
"(NORMAL ) EvalHook \n",
"(LOW ) IterTimerHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"after_train_epoch:\n",
"(NORMAL ) CheckpointHook \n",
"(NORMAL ) EvalHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_val_epoch:\n",
"(LOW ) IterTimerHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_val_iter:\n",
"(LOW ) IterTimerHook \n",
" -------------------- \n",
"after_val_iter:\n",
"(LOW ) IterTimerHook \n",
" -------------------- \n",
"after_val_epoch:\n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"2021-10-11 08:18:48,790 - mmcls - INFO - workflow: [('train', 1)], max: 2 epochs\n",
"2021-10-11 08:19:19,160 - mmcls - INFO - Epoch [1][100/201]\tlr: 1.000e-02, eta: 0:01:31, time: 0.302, data_time: 0.024, memory: 1709, loss: 0.9694\n",
"2021-10-11 08:19:46,980 - mmcls - INFO - Epoch [1][200/201]\tlr: 1.000e-02, eta: 0:00:58, time: 0.278, data_time: 0.002, memory: 1709, loss: 0.6289\n",
"2021-10-11 08:19:47,026 - mmcls - INFO - Saving checkpoint at 1 epochs\n",
"[ ] 0/1601, elapsed: 0s, ETA:[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)\n",
"[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)\n",
"[>>] 1601/1601, 215.8 task/s, elapsed: 7s, ETA: 0s2021-10-11 08:19:54,550 - mmcls - INFO - Epoch(val) [1][51]\taccuracy_top-1: 76.7021\n",
"2021-10-11 08:20:24,188 - mmcls - INFO - Epoch [2][100/201]\tlr: 1.000e-03, eta: 0:00:29, time: 0.295, data_time: 0.023, memory: 1709, loss: 0.4870\n",
"2021-10-11 08:20:51,578 - mmcls - INFO - Epoch [2][200/201]\tlr: 1.000e-03, eta: 0:00:00, time: 0.274, data_time: 0.002, memory: 1709, loss: 0.4671\n",
"2021-10-11 08:20:51,619 - mmcls - INFO - Saving checkpoint at 2 epochs\n",
"[>>] 1601/1601, 218.4 task/s, elapsed: 7s, ETA: 0s2021-10-11 08:20:59,051 - mmcls - INFO - Epoch(val) [2][51]\taccuracy_top-1: 83.6352\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m_ZSkwB5Rflb"
},
"source": [
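"### Test the model\n",
"\n",
"Use `tools/test.py` to evaluate a trained model:\n",
"\n",
"```shell\n",
"python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments]\n",
"```\n",
"\n",
"Some of the optional arguments you can set:\n",
"\n",
"- `--metrics`: the evaluation metrics, which depend on the dataset, e.g. `accuracy`.\n",
"- `--metric-options`: custom options for the evaluation process, e.g. `topk=1`.\n",
"\n",
"See `tools/test.py` for the full list of arguments.\n",
"\n",
"We again use the `MobileNetV2` example."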
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Zd4EM00QRtyc",
"outputId": "45b23cc5-f12e-4ccf-d375-2cf8a320bb0b"
},
"source": [
"!python tools/test.py configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py work_dirs/mobilenet_v2_1x_cats_dogs/latest.pth --metrics=accuracy --metric-options=topk=1"
],
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:28: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/local/lib/python3.7/dist-packages/yaml/constructor.py:126: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n",
" if not isinstance(key, collections.Hashable):\n",
"Use load_from_local loader\n",
"[>>] 2023/2023, 218.8 task/s, elapsed: 9s, ETA: 0s\n",
"accuracy : 83.54\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IwThQkjaRwF7"
},
"source": [
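"### Inference\n",
"\n",
"Use the following command to run inference and save the results:\n",
"\n",
"```shell\n",
"python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}]\n",
"```\n",
"\n",
"Optional arguments:\n",
"\n",
"- `RESULT_FILE`: the name of the output file. If it is not specified, the results are not saved. Supported formats are json, pkl and yml.\n",
"\n",
"We again use the `MobileNetV2` example."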
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6GVKloPHR0Fn",
"outputId": "efd9ff34-f91c-4a06-d0ba-e76d9cee6a4e"
},
"source": [
"!python tools/test.py configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py work_dirs/mobilenet_v2_1x_cats_dogs/latest.pth --out=results.json"
],
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:28: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
" return f(*args, **kwds)\n",
"/usr/local/lib/python3.7/dist-packages/yaml/constructor.py:126: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working\n",
" if not isinstance(key, collections.Hashable):\n",
"Use load_from_local loader\n",
"[>>] 2023/2023, 219.2 task/s, elapsed: 9s, ETA: 0s\n",
"dumping results to results.json\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G0NJI1s6e3FD"
},
"source": [
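"The exported JSON file stores the inference results of every sample: the score of each class (`class_scores`), the predicted label index (`pred_label`), the predicted class name (`pred_class`) and the score of the predicted class (`pred_score`)."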
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 365
},
"id": "HJdJeLUafFhX",
"outputId": "2ee6762a-9289-4b62-b884-b1eeb933ca85"
},
"source": [
"import json\n",
"from PIL import Image\n",
"\n",
"with open(\"./results.json\", 'r') as f:\n",
" results = json.load(f)\n",
"\n",
"# Show the results for the first image\n",
"print('class_scores:', results['class_scores'][0])\n",
"print('pred_class:', results['pred_class'][0])\n",
"print('pred_label:', results['pred_label'][0])\n",
"print('pred_score:', results['pred_score'][0])\n",
"Image.open('data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg')"
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"class_scores: [0.9462895393371582, 0.0537104494869709]\n",
"pred_class: cats\n",
"pred_label: 0\n",
"pred_score: 0.9462895393371582\n"
]
},
{
"output_type": "execute_result",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAASwAAAEYCAIAAABp9FyZAAEAAElEQVR4nJT96ZMkWXIfCKrqe88uv+KOvCuzju7qrmo00A00QZAcoZAznPmwIkP+p/NhZWVnZTAHSBALAmig+u6ursrKMzIyDj/teofqflB3S89qckTWpCTKM8Lc3ezZ0+unP1XF737+KQAgGEQDQCAoIgD04MGDul7f3Nx0fTsalbPZJCa/XC5dOQ4hGGMAoGka55y1tu/7LMu89865zWbz+PHj+Xy+2WxGo1G3rokIAJgZEa3N9L3C4Jwzxngfu64TkbIsq6pqu40xJqXU922MkYiMMSKS56Uh61zubGmtM5QjGkTMDZRl2fd9CKGoSkTx3gthSgkRBCGllFIKnAAAEbGXPM/1q7vOM3OM0ZBLKcWYUkrM/PHH33n58mWe5yklnlW3t7e5M6fHJ7PRaLOY3769zAxNsvz06PDJwwe5ofnV1eNHD2+uru7du3O7aeu6Nha97/7n//l/fvHimQ8dIi4Wi/V6vVk34/H0+voa0Xzy8XdfvXrVtt39+/c//vjjt2/fPn32TZ7nDBKZY4x37997+/btn/zpj9++fftX//E//vCHP1yv1yWS98FaK2i+//3vv3nztutDluWcQMjEwNfXN/cfPHrz5o0IVlVlQvfpp5+mlH72s5+NZ9MY48nZ6Xy5XK1Wn37/e3fu3Hnx6uVvf/vbPoQ7d+6cn5/317ezg8lqteq6phqVV1eXm83q0aNHTdOklFKSlASBrM2yLHM2ny9eiAgB5nleFBWhjZG7kMAUybibTXdxs1yHaMspOtf7aAiMMdZaRERhRHTOOUNd13FKIfQxRokJUay1xpibxTzGKCLWWiJCkizL8jy/vb1JKT169Kiu133fHx8f397etm07nU67rqubdex9nueTycQY5JjeXF4cTmfHx8fG4Pz2drVaTSaTO3furNb1YrHw3mdZllKq6xoRx+NxWZbM7JyLMW42G+dcVVUppc1mlee5916/6/T0tGmaZ8+e3bt3DxHLsgSAvu+dcyIyn8/J8PHxIVFs13Mi/z/+9/9qOrWzcX7n7lEKTQi9QLKICAAg8K2j7/uUEiKKSAih67rEIYRQ94uiKEII0+n06urq5OQkhAAAMUa9iJQSAFRVZa1dr9ciwszGGN36RDal5L1HIGYGABHRaxARXe5vXwrAVoaBmVl2B4CIiE/RGOO9DyGQNYMQhhCICA3q240xiGj0f4jW2rqui6Lquq4sS9/HoiiOjo5fvXo1mUx+/OMf379//2/+5m8ePnz4y9fPfvSjH/3t3/y1M/aPP//8F/+0qKpqeXuTCSBi0zTZeDSbzbquy/P86uqqmB5WVeUyU9dS1/XV1VVMvqqqs7Oz5XK5XC6rajydTlerze3tbUppPB4bY4qi0KUmIu9749xqtdKlOz8/J6I/+7M/CyE450Lb6foj4dOnT7OsCCGEEIu8stYSwsHBASBba5m5qoo337yq6/WdO3em0/Hd+3frun79+uXZ2Vni8PTpVwcH0zx3JydHTdN88+zrsspzgL7vjTHT6bQo87atQ+ibpgkhWGuzLBPBGFIIqWkakM7lGcfEzMzsvUdIPiYfOJ9UQGSMIWsNCxElkARCAqr4AACFRcR7j8LMzCnF6FNKFinLbFEUWZYdnZ40TdN1XQghhOBD1/e9bjYVifV6HUIYjUZElBeubVuBlGWZI6OLEH3o+x4AyrIUkfl8EUI4Pj4moouLCxa01qrCreu6russy0RkPB7Xda17O8syIkopxRiZue97773esu7YsixVjJ1zRDRsY2stS4wxWgPWWmt1JwMiJh9i4pQSS6Rhl8v7h/ceEVVy9r9YHwYijkYjABiNRnoPw14XEbWBjx8/zvNcf6NXP0ia3ttwV2ru9CnuiyUi6tbUW8L3D/3GuDtCivsHMydhERARRgAAIkLEBNLH4FNcbtYhJUHMigKtuV0uNm
2zbuou+Murq+9//rnNsuvb2zzPz8/PP//88+l0aoxxzumNq6LRVTo6Ouq67ujoaL1e3717dzwep5SyLBuNRrPZjJlXq9VqtXLO6T7I81y1qd71crns+17dAedc13UoYoy5vLw8OTl5/vz5D37wg/Ozs6PDwzzPmVkfZmbp5cuXKSUU0MtAxLars8xJ4uOTQ/VN7t69+0//9E+67Z49e6ZKJ8Z4cHDw4ZMnP//5z5tNrdvx0f0Hq/miqde31zd92xhCjoEADdJ6ueqaNoVokMo8K/LMEkYfurZGREEQER9DCMHHwMyM0HVt27ad72OMMaXOt13X6T3GGL33IXh1Urz3XddlWWYt7R7rdqsw87Nnz54/f/7ixYu3b99u6lVKyVqbZS7LsizLdFPpNiADRJQ4OOemo/F0Oi2KAoD188fjMQC0bRtDyLJMDd1yuby+vo4xZllmjCGi8Xh8dnZ2584dAPDe13UdY9RtrIrDWqsbTLdTCAERp9Op6ik9Qfe8un4JpO9DCMHaLM9LvVRh8N7HGEUQwZCezcxpdwwibq0djUZlWaoa0C/QbeScCyGMx2MRybJMnQoR6ft+vV5fXV0h4ocffnh6ejoajVSRqIx1Xafr+L6PalWwU0qDJRzEbBBLRCQilVh98U4g3700QoiI+pCstcZYEBSGFDlF7r0XgN57Y20fvMszm7nRZFxUZe/94dFRluf/5e//7uPvfPKTP/9n63pzfHz8/Pnzf/2v//WDBw9+9atf6dqpsiyKIqWkznnTNKenp9PpVNVn0zSbzeby8lK90BDCb37zG2vt2dlZ3/d1XRdFoU+rbdurq6vlcpmEp9Pp0dGRcw4AiqLo2857//vf/q5erfURTMbjsiy7rh3WIYTgMiPCSGJAdspbPnz85OT0qOub2XTcdvXzF98kDin0PnQI7H2XGTo7OyGU2/n1Yn5jCI4OZ02z6ft+uVyq8VksFl3XuMxYR71v62a92azUThZFUZalc86nkFIMKaaUYkqIaDNXlCUQCoIxJitcVhY2z4xzLs8QkZm93+p0IrKWdD/oHTnnNMbRM+/cuXN0dKSO1aCRdds4ZxCxKAqXGR+6EEJd11VejMtK/UlriQCNMS6zR7ODvu/bth6NRqPRqGmaul7nuauqSpcxxjgej+/fv3/37t3pdHp9fb1er9u2TSk553RzWmuLotDtp0ZPhVDV7mCo9F7UzAij97Hvg6AGVsAhMbP3UZIYpNwVpHel4qcirnKoFl9XRKVFhTDP89VqhYjz+Xwymcznc0TUB6OGmJnrum6aRgVYTZBzTkVCpUtvYJCbQa50lQch3JdD3h3f0hr6sIczhRCAEI2xmTHGGAN7QktEYMgVeQKZHh4kEJtnPsWsKNQe5lXJCJfXV6t644ocDLVtu1gs1Dm/vr6u63qz2YiIOkjOOdU7RHRwcPDJJ58Q0ZMnTz755JOyLL/88stvvvkGAL773e/OZjO9AFVn6l8YY5yzIty2zXg8un//3tnZ6cHBjAhR+PDo4JunX2eZ+9//978cFcXLZ8/KLJtOp23b6uecHh2v18uUUgpxvV7H5HNrEofet5PJ+N6dO9Px+M2b13/ywz/65puvr64u//zPf3JyeHB5eVGV+Xqz/Onf/5c//qPPLSGBfP797/3iZ19Yg9aStSTAXdfU9TrGOBqNzs/PsyyLMa5Wq+Vy7n3nMluWeVnmiCiEYBAIhRhIyFrnXDmqqqoqRuVkNj04OJhOp7OD6eHxgbWGCNWl6vs+hF631mJ5u9lsvPcAbK3VHQIAR0dHBwcHk+moKDPdIQCit8/MbdsmDjHGuq5VIZZlqVYhRh99UIMxLisAcM6URWGtTSl436FAmRcff/zxeDxWpVMUxWQySSldXl7e3Nzoe3UXqQEnIjUY+vtB8FRpMrPamGG3ExGA+rEphMBJhJETxMgxcIoCQIho33dHh+BQ+r4vikz3SkoJgMkYRHRZpjAMEZ2cnCyXyyzLxuNxjH
Ew6HVdLxYL9bJ0rVXwjDEpSQhBRPKsGELBwdskIpb0LXdUzwkhgKAxKaWESAxRBEXEUhRJzFFEEggyi0hiF
"text/plain": [
"<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x280 at 0x7F8B4DE4C450>"
]
},
"metadata": {},
"execution_count": 19
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1bEUwwzcVG8o"
},
"source": [
"You can also use `imshow_infos`, the visualization function provided by MMCls, to display the prediction results more clearly."
]
},
{
"cell_type": "code",
"metadata": {
"id": "BcSNyvAWRx20",
"outputId": "bd07d41a-3e34-4142-9a8f-a73fcaece020",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 304
}
},
"source": [
"from mmcls.core.visualization import imshow_infos\n",
"\n",
"filepath = 'data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg'\n",
"\n",
"result = {\n",
"    'pred_class': results['pred_class'][0],\n",
"    'pred_label': results['pred_label'][0],\n",
"    'pred_score': results['pred_score'][0]\n",
"}\n",
"\n",
"img = imshow_infos(filepath, result)"
],
"execution_count": 20,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAATMAAAEfCAYAAAAtNiETAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAFiQAABYkBbWid+gAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9ebRlaVnm+fuGPZzhzjfixpQRGWQmSTIKyqANODSgpe0EWg6rWkttKEVc1YrSq+xqS8SiBO1WKbFFHFDKshxQQSZRGURQIAGZciDJyMiMebjDGffwTf3Ht/c+JzLBv2qtXt0rD+twM+6wz9l7f9/zPu/zPu97RAiBRx+PPh59PPr4//pD/r/9Bh59PPp49PHo47/H41Ewe/Tx6OPRx/8vHvrh33jck+4IAAIJCITQgCB4AQi2t7ep65LRaMS8mJCmKcNhHx8s48kE5yDJc0IIrKyscPHiRba2trDW4r1HCAHAcDhkNptx+PBh6rqmrmsmkwmhtgAopdBakyQJUkq8h7quAYFWGik1zjnquiaEQJqmZFmCMQalJd5bqqrCWosQAqVUPGGdoqQmSTLSpI/WCUolCJEgCMgAvTylrmuMMWR5hpASY2qCCJjmeEIJhBCEEAghgBRoqcALZBCkaUpZliRJFt+T1FjrUEqxtrbG3t4+SZLwjGc8C2stn/jEJ8gHfa6M9jlx+hQPPvAAIji+/KlPY+/qZfavXWO8v0dPKY4d2mbn0DabKytIINMKvMNZR9br4wGhBHVd8tjH3kZVlYxGeyRpwurqKru7u5w7d4E8y0nTnPm8JEvzeJ8DrK6ucurUKUIInH3wLEVZgoAAFFXFcGUFIQU3P+Y0hw4f5vyFC+zv71NVFXVRogIIIfAhkKYZWdbDeZrzT0izHlJq6tqCEFjnmU3nZGmGMxWZltx2yy2UZcnBeER/MGB/f5/aGtIsY1bM2dzcZOfoEXb39tjd3WU0GiGE4PjOURIEiRaE4NCJxnvH7u41qqokyzK892gd15D3AWs81jriuvdI6XDeIgJorVEqwQdJbRxOKNLeKnWQ7E4Kro4nFBZk1sOhMNahJEgpkVJCCBA8EJlDCB7vPc5ZnDV470mkIk1T8jwlyzJ0ljKfz6mqCmMMxlbdOpdSUtdxXW9srDGdTvHes729RVkVcZ85CCHgfDx+qjRaa7w1VFXJvJizs7NDP8uZTCZYU9Hv9xFCUBQlgbi2pZQ45xiNRkwmca9vbGxw6NAhZrMZ3ntCCFhrkVKilMJaS1HMEEJQ1zXOOVZXV9nc3KSqKq5fv85gMCBNU9bW1hBCMJ/Pu700mUwIONbWBmgZcGaC1o6vfe4zGQ41G2s9tjeGWFdgbU0Ijq96zr8WXxTMWrChkdJaTa2V1lpAybIMY0tCCFRVRcDhvccYT2+osdYyGAwAGAwGzGYzAKy1hBBQShFCYDqdsr6+ztGjR7nvvvswLuCcwzkXF0PzHqx1GGNIdIpzDu/pbm4IAe89dV13F6U9DyEWoLN8jo94Nj9z1mKtxDiLcRZpFVLGfwfi6yAFKkRw9PGASCmbDQy1s+Akk/mMQT+eQ5ZleFEzGo+RiWYyn9Hr9djd3+MJT3gin//CF9gd7SOEYGtrizxN2d+9RpomJElCr9djNo6vmeU5wUcQXV1dpSxmDPsD9vb22F5dQUjFZDpGKUWv1yPLUqazMfP5vLueeZYhpUTrhCwLBL+4ZvP5vFtg1lqUUtTGEAhIIZhOp6ytr1GWJcPhkNM334ySkvFoxEFVIUJAiQjq08mEoqgYDldx1iFQSATOugZUEnpZghSCujbkvZxiMubcuXMcP34c7zzXrl1DadUFuJ2dHaxzPHj2QbIso5wXGGPY3NhgPB6RCoEWgl4vJ00SvA/Na1pmpkJKheoNSBKNTBQ1Fm8LyrJGSE/WVwQf8CHgjUF5AEUI4IOnKEuqIKhMhbEW4yDUFTYIvP
NkaYr3rltrsrnmwcVA2Ovl1MbjXbvqPBCBoaorLly+RFHEc5JSkqTx3JMkQWsNxHvlve/WtLWuA1Bra7IsI5FZszdsF/irqmalP0CEQFHMsLZugDRnPp8yGh1QG8fGxgbD4ZC6rpFSMhwOWV9fZ3Nzs9trVVWRZRlpmjYA7Toi0gJZuy+MMQghWFlZYTqddmstBpUImu2+9T5iCokg0SlZ1pyjib9T1zXBO+KuXSSXjwCzjmkESQjtxRIQIlOz1pLnKYPBAOfjCTnnEbKNYhFctNYYYxgOhx1zstaSJAlVVVFVFZPJhPl8zsbGBqdPn2Y0GnFwbZeyLDsmF8EzgqlSagngPLBgXADeu7h4mu91wLx0bi14tTd+cUwJwROkgC8Gdu3FT5PuBgH45iY46wkSrLFIIamNiYzOWfI8R6cJKtEYZ7HeMVxdIdEJd91zD89+7nN52ld8Oe99//sY9nvs7u7xtC97Mlcvr3Px/HmKyRhvLVophPekSdIsYANCUFc1w+1DcUFZhws2brhqzu7uLiE4xuMx3jv29vYYDAasr68zm80pijlKJeg0IfhAWZZMJhNWV1cZDAZkeU6WZUznM4qiQEsZGVhZcfXSZbY2NhmsxHucJAnD/oD5dIpr2E8MWh4hAvHWeYQMmKqmrAz9vmRzc5MkTbl46SI4h9aSvf1der0c6wy1qUhJCUDwjpXBABc8o9EB8/mMeTGnl2Xkacp4fx8XfzECBI6qKrHOkCSKunaUZYH3kSHl+QCdaPLQAwTGlThnsC6CrUIihEIlSQQ/FJUDgUBrTZZlBCdwUiMCWOlAROBsN3OqNUpKEBoESCVQTuKVQsoIvEqpbl3u7Ox0bMg1ANjuy3aNJ0kEgTzP4/0wdZeNDAaDGMR0grWWqornq5QizRLW1taYzWaE4BgMBiRadwEsy1J0Et+PMQbnHCsrKwwGA1ZXV0nTlPvvv5/5fI5zDq01/X6/A680TZEy7whJm1m1e384HDIajeJ7rutu/4WwIDGBQF1bghekA41SoiM0wYGpLVIFtFIotYCwLwlm7SZtCY1AdWAmRNZFiqqq4mJtUFaIhNFkwtraGgcHBwyHw+5rVVXked6dnPeeqqqYz+fdDTggUvT2IrTsAOINFMgllrVgRC07W34sg9AC8Dxe+IYi++ZvPATXnQdiAYRBQJCiod4SrZL4MxnTKIWMxxIgpAQV0EmCc47B6gp1bVFpgg2eLM1I8zxeeBmZ6d5on4PJGKEVUilKUyHGI4oipgwHBwfU8ynBWKyzOGOpTc0gz7HWMZvFNGM4HLKxscFkNifv9zm0s82lSxe4evUqk8kYqeAxjznNZDLpom28ZoE0Vc15i+YeQm0qVpNVDh3apt/vs3+wz7XrLm6Wfo/pfEptU+6//wvs7Owwn0yajdSnms+pqhqtNb0sxxhLVZQEIeP1UAkCgRIBayqUEmxtbmCqkiuXLtDPM7JEceHiOU6ePMnJkye4/4EzFMWcvJ9y/tyDrK2tcfLEMR586CG2Ntbw3nPxwnnWV4ZoJN6BdY6iKKjrEqUEw7U1rDVcv349pnC1w9rAcLhCmiUICcIEjCtARVklIAhCIKRAao1QCVIkyCCxicOmOYkXWKFwCKx1VPM53om4Oa0lOBNZlVR47zk4OIAQEAQSHZl3kmiEBKUk65ubcU3jusAe12MMxpHFOMrSgwDratw8voa1lkHeawJewNoaWxuED2SJJtUx0KepRooErRTeGkxVIhH0+z0O7xxnPB6zv7+P95719XWGwyHWWvb399nb2yPPc5IkQSnVZVmR6ORIGdPaNgNzzjU4Ecib4Oi9pyxj2t8ep2WipmF5AqgMZKmOMpeLmZ21oEOI90QsyMwXBbMbgS0yoDbvjICUkiTJErX0cSOLqBVVVUWaxnRwe3ub8XhMmqY3pC0tdZ3NZhwcHHRobYyJFBPo9/sopZr0NaaZUdtp39+N7zdqa/4Raeby+RhjIAiUilFAiJiy0p
yrEnTAuIiGi/9uL7jQi0gqpUTqRYQpbY0AEp2RK01paqhqamcZTcZoramrGuscQik++elPce3KVWZFwbwuO
"text/plain": [
"<Figure size 300.01x280.01 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aEfjTbP4TWGM"
},
"source": [
""
],
"execution_count": 20,
"outputs": []
}
]
}