mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
1224 lines
422 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "XjQxmm04iTx4",
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"<a href=\"https://github.com/open-mmlab/mmclassification/blob/master/docs_zh-CN/tutorials/MMClassification_tools_cn.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "4z0JDgisPRr-"
|
|||
|
},
|
|||
|
"source": [
"# MMClassification tools tutorial on Colab\n",
"\n",
"This tutorial covers the following topics:\n",
"\n",
"* How to install MMCls\n",
"* Downloading the data\n",
"* Preparing the config files\n",
"* Using the shell command line"
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "inm7Ciy5PXrU"
|
|||
|
},
|
|||
|
"source": [
"## Install MMClassification\n",
"\n",
"Before using MMClassification, we need to set up the environment as follows:\n",
"\n",
"- Install Python, CUDA, a C/C++ compiler, and git\n",
"- Install PyTorch (CUDA build)\n",
"- Install mmcv\n",
"- Clone the mmcls GitHub repository and install it\n",
"\n",
"Because we are working in Google Colab, which already provides the basic setup, we can skip the first two steps."
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "TDOxbcDvPbNk"
|
|||
|
},
|
|||
|
"source": [
"### Check the environment"
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "c6MbAw10iUJI",
|
|||
|
"outputId": "6306a973-e45c-4067-dc52-2045068f64da"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/content\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%cd /content"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "4IyFL3MaiYRu",
|
|||
|
"outputId": "ec9f7005-a2f7-4a24-d785-0e3190b14342"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/content\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"!pwd"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "DMw7QwvpiiUO",
|
|||
|
"outputId": "9b61c9c5-5287-4b8b-fc3c-6bb4b66d44f0"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"nvcc: NVIDIA (R) Cuda compiler driver\n",
|
|||
|
"Copyright (c) 2005-2020 NVIDIA Corporation\n",
|
|||
|
"Built on Wed_Jul_22_19:09:09_PDT_2020\n",
|
|||
|
"Cuda compilation tools, release 11.0, V11.0.221\n",
|
|||
|
"Build cuda_11.0_bu.TC445_37.28845127_0\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# Check the nvcc version\n",
|
|||
|
"!nvcc -V"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "4VIBU7Fain4D",
|
|||
|
"outputId": "cf671c76-7d4f-45df-b9a9-7c30aaba4666"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0\n",
|
|||
|
"Copyright (C) 2017 Free Software Foundation, Inc.\n",
|
|||
|
"This is free software; see the source for copying conditions. There is NO\n",
|
|||
|
"warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# Check the GCC version\n",
|
|||
|
"!gcc --version"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "24lDLCqFisZ9",
|
|||
|
"outputId": "5a45acfc-7581-4925-f4bd-00be7ae9cba0"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"1.9.0+cu102\n",
|
|||
|
"True\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# Check the PyTorch installation\n",
|
|||
|
"import torch, torchvision\n",
|
|||
|
"print(torch.__version__)\n",
|
|||
|
"print(torch.cuda.is_available())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "R2aZNLUwizBs"
|
|||
|
},
|
|||
|
"source": [
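"### Install MMCV\n",
"\n",
"MMCV is the foundational library of the OpenMMLab codebases. Prebuilt wheels for Linux are available, so it can be downloaded and installed directly.\n",
"\n",
"Pay attention to the PyTorch and CUDA versions so that a matching package is installed.\n",
"\n",
"The previous steps printed the CUDA and PyTorch versions of this environment, 11.0 and 1.9.0 respectively, so we need to choose the corresponding MMCV build.\n",
"\n",
"Alternatively, you can install the full version, MMCV-full, which includes all features and a rich set of out-of-the-box CUDA operators. Note that the full version may take longer to build."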
]
|
|||
|
},
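Choosing the wheel index for a given CUDA/PyTorch pair can be sketched in Python. The helper `mmcv_find_links_url` below is hypothetical and not part of MMCV. It assumes the index URL follows the pattern used by the install command in this tutorial (`.../mmcv/dist/cu<CUDA>/torch<TORCH>/index.html`); an index is not guaranteed to exist for every version pair.

```python
def mmcv_find_links_url(cuda_version: str, torch_version: str) -> str:
    """Build a pip --find-links URL for a CUDA/PyTorch version pair.

    Hypothetical helper: the URL pattern is inferred from the install
    command used in this tutorial and may not exist for every pair.
    """
    # "11.0" -> "cu110": the index path uses the CUDA version without dots.
    cuda_tag = "cu" + cuda_version.replace(".", "")
    # "1.9.0+cu102" -> "torch1.9.0": drop the local build suffix, if any.
    torch_tag = "torch" + torch_version.split("+")[0]
    return (
        "https://download.openmmlab.com/mmcv/dist/"
        f"{cuda_tag}/{torch_tag}/index.html"
    )


if __name__ == "__main__":
    # Versions reported by the environment checks in this notebook.
    print(mmcv_find_links_url("11.0", "1.9.0+cu102"))
```

The resulting URL can be passed to `pip install mmcv -f <URL>`; check that the index page actually exists before relying on it.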
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "nla40LrLi7oo",
|
|||
|
"outputId": "bcf8a25b-5d9a-4b7e-8c9d-546b451d5efc"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Looking in links: https://download.openmmlab.com/mmcv/dist/cu110/torch1.9.0/index.html\n",
|
|||
|
"Collecting mmcv\n",
|
|||
|
" Downloading mmcv-1.3.9.tar.gz (313 kB)\n",
|
|||
|
"\u001b[K |████████████████████████████████| 313 kB 7.6 MB/s \n",
|
|||
|
"\u001b[?25hCollecting addict\n",
|
|||
|
" Downloading addict-2.4.0-py3-none-any.whl (3.8 kB)\n",
|
|||
|
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mmcv) (1.19.5)\n",
|
|||
|
"Requirement already satisfied: Pillow in /usr/local/lib/python3.7/dist-packages (from mmcv) (7.1.2)\n",
|
|||
|
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from mmcv) (3.13)\n",
|
|||
|
"Collecting yapf\n",
|
|||
|
" Downloading yapf-0.31.0-py2.py3-none-any.whl (185 kB)\n",
|
|||
|
"\u001b[K |████████████████████████████████| 185 kB 62.8 MB/s \n",
|
|||
|
"\u001b[?25hBuilding wheels for collected packages: mmcv\n",
|
|||
|
" Building wheel for mmcv (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
|||
|
" Created wheel for mmcv: filename=mmcv-1.3.9-py2.py3-none-any.whl size=451832 sha256=2ae32dad5995b9cd1b677d2d755fa7a8a64ec7705b98f208c727cdf9860ca1db\n",
|
|||
|
" Stored in directory: /root/.cache/pip/wheels/88/48/bf/655e136aea5534d7a9a85fe247fee7957178fc19cf79dda602\n",
|
|||
|
"Successfully built mmcv\n",
|
|||
|
"Installing collected packages: yapf, addict, mmcv\n",
|
|||
|
"Successfully installed addict-2.4.0 mmcv-1.3.9 yapf-0.31.0\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# Install mmcv\n",
|
|||
|
"!pip install mmcv -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.9.0/index.html\n",
|
|||
|
"# !pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.9.0/index.html"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "GDTUrYvXjlRb"
|
|||
|
},
|
|||
|
"source": [
"### Clone and install MMCls\n",
"\n",
"Next, we clone the latest mmcls repository from GitHub and install it."
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "Bwme6tWHjl5s",
|
|||
|
"outputId": "de4b0126-ab3f-40be-b3e0-2a79f9b1ef10"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Cloning into 'mmclassification'...\n",
|
|||
|
"remote: Enumerating objects: 3161, done.\u001b[K\n",
|
|||
|
"remote: Counting objects: 100% (12/12), done.\u001b[K\n",
|
|||
|
"remote: Compressing objects: 100% (12/12), done.\u001b[K\n",
|
|||
|
"remote: Total 3161 (delta 2), reused 5 (delta 0), pack-reused 3149\u001b[K\n",
|
|||
|
"Receiving objects: 100% (3161/3161), 2.81 MiB | 7.35 MiB/s, done.\n",
|
|||
|
"Resolving deltas: 100% (2039/2039), done.\n",
|
|||
|
"/content/mmclassification\n",
|
|||
|
"Obtaining file:///content/mmclassification\n",
|
|||
|
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from mmcls==0.13.0) (3.2.2)\n",
|
|||
|
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mmcls==0.13.0) (1.19.5)\n",
|
|||
|
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.13.0) (2.4.7)\n",
|
|||
|
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.13.0) (1.3.1)\n",
|
|||
|
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.13.0) (2.8.1)\n",
|
|||
|
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.13.0) (0.10.0)\n",
|
|||
|
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from cycler>=0.10->matplotlib->mmcls==0.13.0) (1.15.0)\n",
|
|||
|
"Installing collected packages: mmcls\n",
|
|||
|
" Running setup.py develop for mmcls\n",
|
|||
|
"Successfully installed mmcls-0.13.0\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# Clone the mmcls repository\n",
|
|||
|
"!git clone https://github.com/open-mmlab/mmclassification.git\n",
|
|||
|
"%cd mmclassification/\n",
|
|||
|
"\n",
|
|||
"# Install MMClassification from source\n",
|
|||
|
"!pip install -e . "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "hFg_oSG4j3zB",
|
|||
|
"outputId": "f56eff4e-70ee-4a0c-9b3b-384c5f982c61"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"0.13.0\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# Check the MMClassification installation\n",
|
|||
|
"import mmcls\n",
|
|||
|
"print(mmcls.__version__)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "PkfqxfLIQVFM",
|
|||
|
"outputId": "cfff514f-0642-4743-c30c-1fec4190c856"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"--2021-07-27 03:02:21-- https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\n",
|
|||
|
"Resolving download.openmmlab.com (download.openmmlab.com)... 47.88.36.78\n",
|
|||
|
"Connecting to download.openmmlab.com (download.openmmlab.com)|47.88.36.78|:443... connected.\n",
|
|||
|
"HTTP request sent, awaiting response... 200 OK\n",
|
|||
|
"Length: 14206911 (14M) [application/octet-stream]\n",
|
|||
|
"Saving to: ‘checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth’\n",
|
|||
|
"\n",
|
|||
|
"mobilenet_v2_batch2 100%[===================>] 13.55M 13.3MB/s in 1.0s \n",
|
|||
|
"\n",
|
|||
|
"2021-07-27 03:02:23 (13.3 MB/s) - ‘checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth’ saved [14206911/14206911]\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# Download the pretrained model\n",
|
|||
|
"!mkdir checkpoints\n",
|
|||
|
"!wget https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth -P checkpoints"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "sAt5v0s3Qei4"
|
|||
|
},
|
|||
|
"source": [
"### Download the dataset\n",
"\n",
"Here we use the [Cats and Dogs](https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0) dataset as an example. For convenience, we have repackaged the dataset. The original dataset can be found on [Kaggle](https://www.kaggle.com/tongpython/cat-and-dog). It contains 8,000 training images and 2,000 test images in two classes: cats and dogs."
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "XHCHnKb_Qd3P",
|
|||
|
"outputId": "cd6a97e7-f520-4213-d5b4-9f006532c2b5"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"--2021-07-27 03:02:23-- https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0\n",
|
|||
|
"Resolving www.dropbox.com (www.dropbox.com)... 162.125.2.18, 2620:100:6017:18::a27d:212\n",
|
|||
|
"Connecting to www.dropbox.com (www.dropbox.com)|162.125.2.18|:443... connected.\n",
|
|||
|
"HTTP request sent, awaiting response... 301 Moved Permanently\n",
|
|||
|
"Location: /s/raw/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip [following]\n",
|
|||
|
"--2021-07-27 03:02:23-- https://www.dropbox.com/s/raw/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip\n",
|
|||
|
"Reusing existing connection to www.dropbox.com:443.\n",
|
|||
|
"HTTP request sent, awaiting response... 302 Found\n",
|
|||
|
"Location: https://uc16f4d1c7c9628cf76d673f1cfe.dl.dropboxusercontent.com/cd/0/inline/BTEEPjXBsrs1O8BGfaQGlADK-9WCEv8RbwE1jpSbGL3q27x3kOn1ySXJCas_OonX-JcAjjFSu0t500sHPkhKTJ_8hfj0H3N0dbkL8gXVwGSapdV20KixJMkch9MIZPrSI5lhphCba5Q7J9htSx-fvHqV/file# [following]\n",
|
|||
|
"--2021-07-27 03:02:24-- https://uc16f4d1c7c9628cf76d673f1cfe.dl.dropboxusercontent.com/cd/0/inline/BTEEPjXBsrs1O8BGfaQGlADK-9WCEv8RbwE1jpSbGL3q27x3kOn1ySXJCas_OonX-JcAjjFSu0t500sHPkhKTJ_8hfj0H3N0dbkL8gXVwGSapdV20KixJMkch9MIZPrSI5lhphCba5Q7J9htSx-fvHqV/file\n",
|
|||
|
"Resolving uc16f4d1c7c9628cf76d673f1cfe.dl.dropboxusercontent.com (uc16f4d1c7c9628cf76d673f1cfe.dl.dropboxusercontent.com)... 162.125.2.15, 2620:100:6022:15::a27d:420f\n",
|
|||
|
"Connecting to uc16f4d1c7c9628cf76d673f1cfe.dl.dropboxusercontent.com (uc16f4d1c7c9628cf76d673f1cfe.dl.dropboxusercontent.com)|162.125.2.15|:443... connected.\n",
|
|||
|
"HTTP request sent, awaiting response... 302 Found\n",
|
|||
|
"Location: /cd/0/inline2/BTEEQ4z4k-xmaTVGrSZDExfzyc1S3rZCuMd6YvpvOtprRVBBqT7f4ISwbPtdDGZ9x2_i4R8JWxI5pdYX8UAy4VLR89S73nF_WueY3yhALGzDa3KNweauQAZeUXACVu7buhI96dCJCrs1FggmD9WNXdq9-EQ-Odgh3xdaWCsViBfc6Dj8DuGKW0sK5j-iIzrIJ7uMVsZN2cLju7BRLez3Njjy4sRGxQ58ZY3tagNpBIUzEZIn5Yobif2aNmP9po79DGSPt0_LWsnINmFADcQxnpAPweCU6l6VEL5rHd31rEsrEvgOkH8h7CWpSohdrUjKVo79FRmYz65LzNKwzg_H3Dq0LNty5thfmZ7cZKMv_DBDg2EnzbaDXgSRAKXJNlOocSQ/file [following]\n",
|
|||
|
"--2021-07-27 03:02:24-- https://uc16f4d1c7c9628cf76d673f1cfe.dl.dropboxusercontent.com/cd/0/inline2/BTEEQ4z4k-xmaTVGrSZDExfzyc1S3rZCuMd6YvpvOtprRVBBqT7f4ISwbPtdDGZ9x2_i4R8JWxI5pdYX8UAy4VLR89S73nF_WueY3yhALGzDa3KNweauQAZeUXACVu7buhI96dCJCrs1FggmD9WNXdq9-EQ-Odgh3xdaWCsViBfc6Dj8DuGKW0sK5j-iIzrIJ7uMVsZN2cLju7BRLez3Njjy4sRGxQ58ZY3tagNpBIUzEZIn5Yobif2aNmP9po79DGSPt0_LWsnINmFADcQxnpAPweCU6l6VEL5rHd31rEsrEvgOkH8h7CWpSohdrUjKVo79FRmYz65LzNKwzg_H3Dq0LNty5thfmZ7cZKMv_DBDg2EnzbaDXgSRAKXJNlOocSQ/file\n",
|
|||
|
"Reusing existing connection to uc16f4d1c7c9628cf76d673f1cfe.dl.dropboxusercontent.com:443.\n",
|
|||
|
"HTTP request sent, awaiting response... 200 OK\n",
|
|||
|
"Length: 228802825 (218M) [application/zip]\n",
|
|||
|
"Saving to: ‘cats_dogs_dataset.zip’\n",
|
|||
|
"\n",
|
|||
|
"cats_dogs_dataset.z 100%[===================>] 218.20M 97.3MB/s in 2.2s \n",
|
|||
|
"\n",
|
|||
|
"2021-07-27 03:02:27 (97.3 MB/s) - ‘cats_dogs_dataset.zip’ saved [228802825/228802825]\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# Download the classification dataset archive\n",
|
|||
|
"!wget https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0 -O cats_dogs_dataset.zip\n",
|
|||
|
"!mkdir data\n",
|
|||
|
"!unzip -q cats_dogs_dataset.zip -d ./data/"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "e4t2P2aTQokX"
|
|||
|
},
|
|||
|
"source": [
"After downloading and extracting, the file structure under the \"Cats and Dogs Dataset\" folder is as follows:\n",
"```\n",
"data/cats_dogs_dataset\n",
"├── classes.txt\n",
"├── test.txt\n",
"├── val.txt\n",
"├── training_set\n",
"│   ├── training_set\n",
"│   │   ├── cats\n",
"│   │   │   ├── cat.1.jpg\n",
"│   │   │   ├── cat.2.jpg\n",
"│   │   │   ├── ...\n",
"│   │   ├── dogs\n",
"│   │   │   ├── dog.2.jpg\n",
"│   │   │   ├── dog.3.jpg\n",
"│   │   │   ├── ...\n",
"├── val_set\n",
"│   ├── val_set\n",
"│   │   ├── cats\n",
"│   │   │   ├── cat.3.jpg\n",
"│   │   │   ├── cat.5.jpg\n",
"│   │   │   ├── ...\n",
"│   │   ├── dogs\n",
"│   │   │   ├── dog.1.jpg\n",
"│   │   │   ├── dog.6.jpg\n",
"│   │   │   ├── ...\n",
"├── test_set\n",
"│   ├── test_set\n",
"│   │   ├── cats\n",
"│   │   │   ├── cat.4001.jpg\n",
"│   │   │   ├── cat.4002.jpg\n",
"│   │   │   ├── ...\n",
"│   │   ├── dogs\n",
"│   │   │   ├── dog.4001.jpg\n",
"│   │   │   ├── dog.4002.jpg\n",
"│   │   │   ├── ...\n",
"```\n",
"\n",
"You can view the file structure with the shell command `tree data/cats_dogs_dataset`."
]
|
|||
|
},
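If `tree` is not available, the layout can also be checked from Python. The sketch below is only an illustration: `count_images_per_class` is a hypothetical helper, and it assumes each split folder contains one subfolder per class holding `.jpg` files, as in the listing above.

```python
from pathlib import Path
from typing import Dict


def count_images_per_class(split_dir: str, suffix: str = ".jpg") -> Dict[str, int]:
    """Count files ending in `suffix` in each class subfolder of `split_dir`."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            # Only regular files with the expected suffix are counted.
            counts[class_dir.name] = sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.name.endswith(suffix)
            )
    return counts


if __name__ == "__main__":
    split = Path("data/cats_dogs_dataset/training_set/training_set")
    if split.is_dir():  # only runs once the dataset has been extracted
        print(count_images_per_class(str(split)))
```

Running it on each split is a quick way to confirm that the archive was extracted into the expected folders.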
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/",
|
|||
|
"height": 297
|
|||
|
},
|
|||
|
"id": "46tyHTdtQy_Z",
|
|||
|
"outputId": "f363f236-7a2a-466c-84f3-ae3feca63b8c"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x280 at 0x7FB6DBE1EA90>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# Load one image for visualization\n",
|
|||
|
"from PIL import Image\n",
|
|||
|
"Image.open('data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "My5Z6p7pQ3UC"
|
|||
|
},
|
|||
|
"source": [
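"### Support new datasets\n",
"\n",
"MMClassification requires the images and the labels of a dataset to sit at the same directory level. There are two ways to support a custom dataset.\n",
"\n",
"The simplest way is to convert the dataset into an existing dataset format (such as ImageNet). The other way is to create a new dataset class. See the [documentation](https://github.com/open-mmlab/mmclassification/blob/master/docs_zh-CN/tutorials/new_dataset.md) for details.\n",
"\n",
"In this tutorial, for ease of learning, we have already organized the Cats and Dogs dataset in the ImageNet dataset format."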
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "P335gKt9Q5U-"
|
|||
|
},
|
|||
|
"source": [
"The standard files include:\n",
"\n",
"1. A class list, with one class per line.\n",
"   ```\n",
"   cats\n",
"   dogs\n",
"   ```\n",
"2. Training/validation/test annotations.\n",
"   Each line contains a file name and its corresponding label.\n",
"   ```\n",
"   ...\n",
"   cats/cat.3769.jpg 0\n",
"   cats/cat.882.jpg 0\n",
"   ...\n",
"   dogs/dog.3881.jpg 1\n",
"   dogs/dog.3377.jpg 1\n",
"   ...\n",
"   ```"
]
|
|||
|
},
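To make the two file formats concrete, here is a minimal parsing sketch. It is not MMClassification's own loader; `parse_annotations` and `load_classes` are hypothetical helpers that assume each annotation line has the form `<relative path> <label index>` and that the label indexes into the class list.

```python
from typing import Iterable, List, Tuple


def load_classes(lines: Iterable[str]) -> List[str]:
    """Read a class list with one class name per line."""
    return [line.strip() for line in lines if line.strip()]


def parse_annotations(lines: Iterable[str]) -> List[Tuple[str, int]]:
    """Parse `<relative path> <label index>` lines into (path, label) pairs."""
    samples = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        # Split on the last whitespace run so paths containing spaces survive.
        path, label = line.rsplit(maxsplit=1)
        samples.append((path, int(label)))
    return samples


if __name__ == "__main__":
    classes = load_classes(["cats\n", "dogs\n"])
    annotations = ["cats/cat.3769.jpg 0\n", "dogs/dog.3881.jpg 1\n"]
    for path, label in parse_annotations(annotations):
        print(path, "->", classes[label])
```

In practice the lines would come from `open('data/cats_dogs_dataset/val.txt')` and `open('data/cats_dogs_dataset/classes.txt')`.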
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "BafQ7ijBQ8N_"
|
|||
|
},
|
|||
|
"source": [
"## Using the command-line tools\n",
"\n",
"MMCls also provides command-line tools for the following tasks:\n",
"\n",
"1. Model training\n",
"2. Model fine-tuning\n",
"3. Model testing\n",
"4. Inference\n",
"\n",
"Model training follows the same process as fine-tuning, and we have already seen inference and fine-tuning with the Python API. Next we will see how to carry out these tasks with the command-line tools. More details can be found in the [documentation](https://github.com/open-mmlab/mmclassification/blob/master/docs_zh-CN/getting_started.md)."
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "Aj5cGMihURrZ"
|
|||
|
},
|
|||
|
"source": [
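"### Fine-tuning\n",
"\n",
"Fine-tuning a model from the command line takes the following steps:\n",
"\n",
"1. Prepare a custom dataset\n",
"2. Adapt the dataset to the MMCls requirements\n",
"3. Modify the config files in py scripts\n",
"4. Fine-tune the model with the command-line tools\n",
"\n",
"Steps 1 and 2 are the same as described earlier, so we cover the last two steps here.\n",
"\n",
"#### Modify the config files in py scripts\n",
"\n",
"To reuse the parts that different config files have in common, we support inheriting from multiple config files. For example, to fine-tune MobileNetV2, the new config file can inherit `configs/_base_/models/mobilenet_v2_1x.py` to build the basic model structure, inherit `configs/_base_/datasets/cats_dogs_dataset.py` to use the dataset config defined for this tutorial, and inherit `configs/_base_/schedules/cats_dogs_finetune.py` to customize the learning rate schedule. To run with that schedule, it also needs to inherit `configs/_base_/default_runtime.py`.\n",
"\n",
"The final config file should look like this:\n",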
"\n",
|
|||
|
"```\n",
|
|||
|
"# Save to \"configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py\"\n",
|
|||
|
"_base_ = [\n",
|
|||
"    '../_base_/models/mobilenet_v2_1x.py',\n",
"    '../_base_/datasets/cats_dogs_dataset.py',\n",
"    '../_base_/schedules/cats_dogs_finetune.py',\n",
"    '../_base_/default_runtime.py'\n",
|
|||
|
"]\n",
|
|||
|
"```\n",
|
|||
|
"\n",
|
|||
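"Alternatively, you can skip inheritance and write a complete config file directly, such as `configs/mnist/lenet5.py`.\n",
"\n",
"Here we use the reorganized dataset. To use a fully custom dataset, you also need to write a new dataset config, which overrides the earlier settings."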
]
|
|||
|
},
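The effect of inheriting from a base config and overriding a few keys can be illustrated with a recursive dictionary merge. This is a simplified sketch of the idea, not MMCV's actual `Config` implementation, which handles more cases. `merge_config` is a hypothetical helper and the dictionary values are illustrative only.

```python
from copy import deepcopy
from typing import Any, Dict


def merge_config(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `base` with `override` merged in recursively."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge key by key instead of replacing.
            merged[key] = merge_config(merged[key], value)
        else:
            # Scalars, lists and new keys simply replace the base value.
            merged[key] = deepcopy(value)
    return merged


if __name__ == "__main__":
    # Illustrative values only, not copied from the real base config.
    base = {
        "model": {
            "backbone": {"type": "MobileNetV2"},
            "head": {"type": "LinearClsHead", "num_classes": 1000},
        }
    }
    override = {"model": {"head": {"num_classes": 2, "topk": (1,)}}}
    print(merge_config(base, override)["model"]["head"])
```

Only the keys named in the child config change; everything else is taken from the base, which is why a fine-tuning config can stay short.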
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "_UV3oBhLRG8B"
|
|||
|
},
|
|||
|
"source": [
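"First, modify the model config and save it as `configs/_base_/models/mobilenet_v2_1x_cats_dogs.py`. The new config needs to set `num_classes` of the model `head` to the number of classes in the classification problem. The pretrained weights are generally reused for everything except the final linear layer."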
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 12,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "8QfM4qBeWIQh",
|
|||
|
"outputId": "ec006536-5586-446a-a45b-b9cbe44a57de"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Writing configs/_base_/models/mobilenet_v2_1x_cats_dogs.py\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%%writefile configs/_base_/models/mobilenet_v2_1x_cats_dogs.py\n",
|
|||
|
"_base_ = ['./mobilenet_v2_1x.py']\n",
|
|||
|
"model = dict(\n",
|
|||
|
" backbone=dict(\n",
|
|||
|
" init_cfg = dict(\n",
|
|||
|
" type='Pretrained', \n",
|
|||
|
" checkpoint='checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth', \n",
|
|||
|
" prefix='backbone')\n",
|
|||
|
" ),\n",
|
|||
|
" head=dict(\n",
|
|||
|
" num_classes=2,\n",
|
|||
|
" topk = (1, )\n",
|
|||
|
" ))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "F2bjgpsZRKp1"
|
|||
|
},
|
|||
|
"source": [
"Second, the dataset config, saved as `configs/_base_/datasets/cats_dogs_dataset.py`."
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 13,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "DMIp07L4Wn80",
|
|||
|
"outputId": "d35bdbd0-a65f-4354-86df-f53d595492fa"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Writing configs/_base_/datasets/cats_dogs_dataset.py\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
"%%writefile configs/_base_/datasets/cats_dogs_dataset.py\n",
"_base_ = ['./imagenet_bs32.py']\n",
"img_norm_cfg = dict(\n",
"    mean=[124.508, 116.050, 106.438],\n",
"    std=[58.577, 57.310, 57.437],\n",
"    to_rgb=True)\n",
"\n",
"data = dict(\n",
"    # Per-GPU batch size and number of workers; adjust to your machine\n",
"    samples_per_gpu=32,\n",
"    workers_per_gpu=2,\n",
"    # Path to the training set\n",
"    train=dict(\n",
"        data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
"        classes='data/cats_dogs_dataset/classes.txt'\n",
"    ),\n",
"    # Path to the validation set\n",
"    val=dict(\n",
"        data_prefix='data/cats_dogs_dataset/val_set/val_set',\n",
"        ann_file='data/cats_dogs_dataset/val.txt',\n",
"        classes='data/cats_dogs_dataset/classes.txt'\n",
"    ),\n",
"    # Path to the test set\n",
"    test=dict(\n",
"        data_prefix='data/cats_dogs_dataset/test_set/test_set',\n",
"        ann_file='data/cats_dogs_dataset/test.txt',\n",
"        classes='data/cats_dogs_dataset/classes.txt'\n",
"    )\n",
")\n",
"# Change the evaluation metric settings\n",
"evaluation = dict(metric_options={'topk': (1, )})"
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "_lxAl1cSRM_D"
|
|||
|
},
|
|||
|
"source": [
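"Third, the learning rate schedule. A fine-tuning schedule differs considerably from the default one: fine-tuning generally calls for a smaller learning rate and fewer training epochs. Save it as `configs/_base_/schedules/cats_dogs_finetune.py`."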
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "6-JTFNaDWzFQ",
|
|||
|
"outputId": "69ceb2be-69e6-47bc-d014-04e3e65a92f4"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Writing configs/_base_/schedules/cats_dogs_finetune.py\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%%writefile configs/_base_/schedules/cats_dogs_finetune.py\n",
|
|||
"# Optimizer settings\n",
"# Learning rate set for a batch size of 128\n",
|
|||
|
"optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)\n",
|
|||
|
"optimizer_config = dict(grad_clip=None)\n",
|
|||
"# Learning rate schedule\n",
|
|||
|
"lr_config = dict(policy='step', step=[1])\n",
|
|||
|
"runner = dict(type='EpochBasedRunner', max_epochs=2)"
|
|||
|
]
|
|||
|
},
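What `policy='step', step=[1]` means for the two training epochs can be sketched as follows. This is an illustration of a step schedule, not the runner's own code: `step_lr` is a hypothetical helper, and the decay factor `gamma=0.1` is an assumption, since the config above does not set it.

```python
from typing import Sequence


def step_lr(base_lr: float, epoch: int, steps: Sequence[int], gamma: float = 0.1) -> float:
    """Learning rate at a 0-based `epoch` under a step schedule.

    The rate is multiplied by `gamma` once for every milestone in `steps`
    that has been reached. `gamma=0.1` is an assumed default.
    """
    # Count how many milestones have been reached by this epoch.
    passed = sum(1 for milestone in steps if epoch >= milestone)
    return base_lr * gamma ** passed


if __name__ == "__main__":
    # lr=0.01, step=[1], max_epochs=2 as in the schedule config.
    for epoch in range(2):
        print(epoch, step_lr(0.01, epoch, steps=[1]))
```

Under these assumptions the first epoch runs at the base learning rate and the second at a tenth of it.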
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "ofZoBfseROf1"
|
|||
|
},
|
|||
|
"source": [
"Finally, the runtime settings. We use the default config directly. We combine all of the config files modified and saved above into a single file and save it as `configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py`.\n"
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "3tp9C42uXgRD",
|
|||
|
"outputId": "c0c9e205-a18c-412f-a031-081b3928c0df"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Writing configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%%writefile configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py\n",
|
|||
|
"_base_ = [\n",
|
|||
|
" '../_base_/models/mobilenet_v2_1x_cats_dogs.py', '../_base_/datasets/cats_dogs_dataset.py',\n",
|
|||
|
" '../_base_/schedules/cats_dogs_finetune.py', '../_base_/default_runtime.py'\n",
|
|||
|
"]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "chLX7bL3RP2F"
|
|||
|
},
|
|||
|
"source": [
"#### Fine-tune from the command line\n",
"\n",
"We use `tools/train.py` to fine-tune the model:\n",
|
|||
|
"\n",
|
|||
|
"```\n",
|
|||
|
"python tools/train.py ${CONFIG_FILE} [optional arguments]\n",
|
|||
|
"```\n",
|
|||
|
"\n",
|
|||
"To specify where the files produced during training are saved, add the argument `--work-dir ${YOUR_WORK_DIR}`.\n",
|
|||
|
"\n",
|
|||
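"The argument `--seed ${SEED}` sets the random seed so that results are reproducible, and `--deterministic` enables cuDNN's deterministic option to further improve reproducibility, possibly at a small cost in speed.\n",
"\n",
"Here we use `MobileNetV2` and the `CatsDogsDataset` dataset as an example."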
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 16,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "gbFGR4SBRUYN",
|
|||
|
"outputId": "2472d8ef-d2ae-4d20-f9ac-0c94ab62df8a"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:27: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
|
|||
|
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n",
|
|||
|
"2021-07-27 03:02:36,230 - mmcls - INFO - Environment info:\n",
|
|||
|
"------------------------------------------------------------\n",
|
|||
|
"sys.platform: linux\n",
|
|||
|
"Python: 3.7.11 (default, Jul 3 2021, 18:01:19) [GCC 7.5.0]\n",
|
|||
|
"CUDA available: True\n",
|
|||
|
"GPU 0: Tesla T4\n",
|
|||
|
"CUDA_HOME: /usr/local/cuda\n",
|
|||
|
"NVCC: Build cuda_11.0_bu.TC445_37.28845127_0\n",
|
|||
|
"GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0\n",
|
|||
|
"PyTorch: 1.9.0+cu102\n",
|
|||
|
"PyTorch compiling details: PyTorch built with:\n",
|
|||
|
" - GCC 7.3\n",
|
|||
|
" - C++ Version: 201402\n",
|
|||
|
" - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications\n",
|
|||
|
" - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)\n",
|
|||
|
" - OpenMP 201511 (a.k.a. OpenMP 4.5)\n",
|
|||
|
" - NNPACK is enabled\n",
|
|||
|
" - CPU capability usage: AVX2\n",
|
|||
|
" - CUDA Runtime 10.2\n",
|
|||
|
" - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70\n",
|
|||
|
" - CuDNN 7.6.5\n",
|
|||
|
" - Magma 2.5.2\n",
|
|||
|
" - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=10.2, CUDNN_VERSION=7.6.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, \n",
|
|||
|
"\n",
|
|||
|
"TorchVision: 0.10.0+cu102\n",
|
|||
|
"OpenCV: 4.1.2\n",
|
|||
|
"MMCV: 1.3.9\n",
|
|||
|
"MMCV Compiler: n/a\n",
|
|||
|
"MMCV CUDA Compiler: n/a\n",
|
|||
|
"MMClassification: 0.13.0+899047a\n",
|
|||
|
"------------------------------------------------------------\n",
|
|||
|
"\n",
|
|||
|
"2021-07-27 03:02:36,231 - mmcls - INFO - Distributed training: False\n",
|
|||
|
"2021-07-27 03:02:36,458 - mmcls - INFO - Config:\n",
|
|||
|
"model = dict(\n",
|
|||
|
" type='ImageClassifier',\n",
|
|||
|
" backbone=dict(\n",
|
|||
|
" type='MobileNetV2',\n",
|
|||
|
" widen_factor=1.0,\n",
|
|||
|
" init_cfg=dict(\n",
|
|||
|
" type='Pretrained',\n",
|
|||
|
" checkpoint=\n",
|
|||
|
" 'checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth',\n",
|
|||
|
" prefix='backbone')),\n",
|
|||
|
" neck=dict(type='GlobalAveragePooling'),\n",
|
|||
|
" head=dict(\n",
|
|||
|
" type='LinearClsHead',\n",
|
|||
|
" num_classes=2,\n",
|
|||
|
" in_channels=1280,\n",
|
|||
|
" loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n",
|
|||
|
" topk=(1, )))\n",
|
|||
|
"dataset_type = 'ImageNet'\n",
|
|||
|
"img_norm_cfg = dict(\n",
|
|||
|
" mean=[124.508, 116.05, 106.438], std=[58.577, 57.31, 57.437], to_rgb=True)\n",
|
|||
|
"train_pipeline = [\n",
|
|||
|
" dict(type='LoadImageFromFile'),\n",
|
|||
|
" dict(type='RandomResizedCrop', size=224),\n",
|
|||
|
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
|
|||
|
" dict(\n",
|
|||
|
" type='Normalize',\n",
|
|||
|
" mean=[123.675, 116.28, 103.53],\n",
|
|||
|
" std=[58.395, 57.12, 57.375],\n",
|
|||
|
" to_rgb=True),\n",
|
|||
|
" dict(type='ImageToTensor', keys=['img']),\n",
|
|||
|
" dict(type='ToTensor', keys=['gt_label']),\n",
|
|||
|
" dict(type='Collect', keys=['img', 'gt_label'])\n",
|
|||
|
"]\n",
|
|||
|
"test_pipeline = [\n",
|
|||
|
" dict(type='LoadImageFromFile'),\n",
|
|||
|
" dict(type='Resize', size=(256, -1)),\n",
|
|||
|
" dict(type='CenterCrop', crop_size=224),\n",
|
|||
|
" dict(\n",
|
|||
|
" type='Normalize',\n",
|
|||
|
" mean=[123.675, 116.28, 103.53],\n",
|
|||
|
" std=[58.395, 57.12, 57.375],\n",
|
|||
|
" to_rgb=True),\n",
|
|||
|
" dict(type='ImageToTensor', keys=['img']),\n",
|
|||
|
" dict(type='Collect', keys=['img'])\n",
|
|||
|
"]\n",
|
|||
|
"data = dict(\n",
|
|||
|
" samples_per_gpu=32,\n",
|
|||
|
" workers_per_gpu=2,\n",
|
|||
|
" train=dict(\n",
|
|||
|
" type='ImageNet',\n",
|
|||
|
" data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
|
|||
|
" pipeline=[\n",
|
|||
|
" dict(type='LoadImageFromFile'),\n",
|
|||
|
" dict(type='RandomResizedCrop', size=224),\n",
|
|||
|
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
|
|||
|
" dict(\n",
|
|||
|
" type='Normalize',\n",
|
|||
|
" mean=[123.675, 116.28, 103.53],\n",
|
|||
|
" std=[58.395, 57.12, 57.375],\n",
|
|||
|
" to_rgb=True),\n",
|
|||
|
" dict(type='ImageToTensor', keys=['img']),\n",
|
|||
|
" dict(type='ToTensor', keys=['gt_label']),\n",
|
|||
|
" dict(type='Collect', keys=['img', 'gt_label'])\n",
|
|||
|
" ],\n",
|
|||
|
" classes='data/cats_dogs_dataset/classes.txt'),\n",
|
|||
|
" val=dict(\n",
|
|||
|
" type='ImageNet',\n",
|
|||
|
" data_prefix='data/cats_dogs_dataset/val_set/val_set',\n",
|
|||
|
" ann_file='data/cats_dogs_dataset/val.txt',\n",
|
|||
|
" pipeline=[\n",
|
|||
|
" dict(type='LoadImageFromFile'),\n",
|
|||
|
" dict(type='Resize', size=(256, -1)),\n",
|
|||
|
" dict(type='CenterCrop', crop_size=224),\n",
|
|||
|
" dict(\n",
|
|||
|
" type='Normalize',\n",
|
|||
|
" mean=[123.675, 116.28, 103.53],\n",
|
|||
|
" std=[58.395, 57.12, 57.375],\n",
|
|||
|
" to_rgb=True),\n",
|
|||
|
" dict(type='ImageToTensor', keys=['img']),\n",
|
|||
|
" dict(type='Collect', keys=['img'])\n",
|
|||
|
" ],\n",
|
|||
|
" classes='data/cats_dogs_dataset/classes.txt'),\n",
|
|||
|
" test=dict(\n",
|
|||
|
" type='ImageNet',\n",
|
|||
|
" data_prefix='data/cats_dogs_dataset/test_set/test_set',\n",
|
|||
|
" ann_file='data/cats_dogs_dataset/test.txt',\n",
|
|||
|
" pipeline=[\n",
|
|||
|
" dict(type='LoadImageFromFile'),\n",
|
|||
|
" dict(type='Resize', size=(256, -1)),\n",
|
|||
|
" dict(type='CenterCrop', crop_size=224),\n",
|
|||
|
" dict(\n",
|
|||
|
" type='Normalize',\n",
|
|||
|
" mean=[123.675, 116.28, 103.53],\n",
|
|||
|
" std=[58.395, 57.12, 57.375],\n",
|
|||
|
" to_rgb=True),\n",
|
|||
|
" dict(type='ImageToTensor', keys=['img']),\n",
|
|||
|
" dict(type='Collect', keys=['img'])\n",
|
|||
|
" ],\n",
|
|||
|
" classes='data/cats_dogs_dataset/classes.txt'))\n",
|
|||
|
"evaluation = dict(\n",
|
|||
|
" interval=1, metric='accuracy', metric_options=dict(topk=(1, )))\n",
|
|||
|
"optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)\n",
|
|||
|
"optimizer_config = dict(grad_clip=None)\n",
|
|||
|
"lr_config = dict(policy='step', step=[1])\n",
|
|||
|
"runner = dict(type='EpochBasedRunner', max_epochs=2)\n",
|
|||
|
"checkpoint_config = dict(interval=1)\n",
|
|||
|
"log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])\n",
|
|||
|
"dist_params = dict(backend='nccl')\n",
|
|||
|
"log_level = 'INFO'\n",
|
|||
|
"load_from = None\n",
|
|||
|
"resume_from = None\n",
|
|||
|
"workflow = [('train', 1)]\n",
|
|||
|
"work_dir = 'work_dirs/mobilenet_v2_1x_cats_dogs'\n",
|
|||
|
"gpu_ids = range(0, 1)\n",
|
|||
|
"\n",
|
|||
|
"2021-07-27 03:02:36,459 - mmcls - INFO - Set random seed to 0, deterministic: True\n",
|
|||
|
"2021-07-27 03:02:36,629 - mmcv - INFO - load backbone in model from: checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\n",
|
|||
|
"Use load_from_local loader\n",
|
|||
|
"2021-07-27 03:02:48,486 - mmcls - INFO - Start running, host: root@13385ee59ab7, work_dir: /content/mmclassification/work_dirs/mobilenet_v2_1x_cats_dogs\n",
|
|||
|
"2021-07-27 03:02:48,487 - mmcls - INFO - Hooks will be executed in the following order:\n",
|
|||
|
"before_run:\n",
|
|||
|
"(VERY_HIGH ) StepLrUpdaterHook \n",
|
|||
|
"(NORMAL ) CheckpointHook \n",
|
|||
|
"(NORMAL ) EvalHook \n",
|
|||
|
"(VERY_LOW ) TextLoggerHook \n",
|
|||
|
" -------------------- \n",
|
|||
|
"before_train_epoch:\n",
|
|||
|
"(VERY_HIGH ) StepLrUpdaterHook \n",
|
|||
|
"(NORMAL ) EvalHook \n",
|
|||
|
"(LOW ) IterTimerHook \n",
|
|||
|
"(VERY_LOW ) TextLoggerHook \n",
|
|||
|
" -------------------- \n",
|
|||
|
"before_train_iter:\n",
|
|||
|
"(VERY_HIGH ) StepLrUpdaterHook \n",
|
|||
|
"(NORMAL ) EvalHook \n",
|
|||
|
"(LOW ) IterTimerHook \n",
|
|||
|
" -------------------- \n",
|
|||
|
"after_train_iter:\n",
|
|||
|
"(ABOVE_NORMAL) OptimizerHook \n",
|
|||
|
"(NORMAL ) CheckpointHook \n",
|
|||
|
"(NORMAL ) EvalHook \n",
|
|||
|
"(LOW ) IterTimerHook \n",
|
|||
|
"(VERY_LOW ) TextLoggerHook \n",
|
|||
|
" -------------------- \n",
|
|||
|
"after_train_epoch:\n",
|
|||
|
"(NORMAL ) CheckpointHook \n",
|
|||
|
"(NORMAL ) EvalHook \n",
|
|||
|
"(VERY_LOW ) TextLoggerHook \n",
|
|||
|
" -------------------- \n",
|
|||
|
"before_val_epoch:\n",
|
|||
|
"(LOW ) IterTimerHook \n",
|
|||
|
"(VERY_LOW ) TextLoggerHook \n",
|
|||
|
" -------------------- \n",
|
|||
|
"before_val_iter:\n",
|
|||
|
"(LOW ) IterTimerHook \n",
|
|||
|
" -------------------- \n",
|
|||
|
"after_val_iter:\n",
|
|||
|
"(LOW ) IterTimerHook \n",
|
|||
|
" -------------------- \n",
|
|||
|
"after_val_epoch:\n",
|
|||
|
"(VERY_LOW ) TextLoggerHook \n",
|
|||
|
" -------------------- \n",
|
|||
|
"2021-07-27 03:02:48,487 - mmcls - INFO - workflow: [('train', 1)], max: 2 epochs\n",
|
|||
|
"2021-07-27 03:03:03,247 - mmcls - INFO - Epoch [1][100/201]\tlr: 1.000e-02, eta: 0:00:44, time: 0.147, data_time: 0.054, memory: 1709, loss: 0.7429\n",
|
|||
|
"2021-07-27 03:03:15,110 - mmcls - INFO - Epoch [1][200/201]\tlr: 1.000e-02, eta: 0:00:26, time: 0.118, data_time: 0.025, memory: 1709, loss: 0.4858\n",
|
|||
|
"2021-07-27 03:03:15,134 - mmcls - INFO - Saving checkpoint at 1 epochs\n",
|
|||
|
"[ ] 0/1601, elapsed: 0s, ETA:[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)\n",
|
|||
|
"[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)\n",
|
|||
|
"[>>] 1601/1601, 282.3 task/s, elapsed: 6s, ETA: 0s2021-07-27 03:03:20,872 - mmcls - INFO - Epoch(val) [1][51]\taccuracy_top-1: 82.3235\n",
|
|||
|
"2021-07-27 03:03:35,108 - mmcls - INFO - Epoch [2][100/201]\tlr: 1.000e-03, eta: 0:00:13, time: 0.142, data_time: 0.046, memory: 1709, loss: 0.3329\n",
|
|||
|
"2021-07-27 03:03:47,285 - mmcls - INFO - Epoch [2][200/201]\tlr: 1.000e-03, eta: 0:00:00, time: 0.122, data_time: 0.030, memory: 1709, loss: 0.3197\n",
|
|||
|
"2021-07-27 03:03:47,307 - mmcls - INFO - Saving checkpoint at 2 epochs\n",
|
|||
|
"[>>] 1601/1601, 271.9 task/s, elapsed: 6s, ETA: 0s2021-07-27 03:03:53,269 - mmcls - INFO - Epoch(val) [2][51]\taccuracy_top-1: 92.1924\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"!python tools/train.py \\\n",
|
|||
|
" configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py \\\n",
|
|||
|
" --work-dir work_dirs/mobilenet_v2_1x_cats_dogs \\\n",
|
|||
|
" --seed 0 \\\n",
|
|||
|
" --deterministic"
|
|||
|
]
|
|||
|
},
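{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before testing, it can help to check which checkpoint files training left in the work directory. The cell below is only a minimal sketch: it assumes the training command above has finished and wrote its checkpoints to `work_dirs/mobilenet_v2_1x_cats_dogs`, and the variable names are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"# Assumption: the work directory passed to tools/train.py via --work-dir above.\n",
"work_dir = Path('work_dirs/mobilenet_v2_1x_cats_dogs')\n",
"\n",
"# List every checkpoint file (*.pth) found directly in the work directory.\n",
"for ckpt in sorted(work_dir.glob('*.pth')):\n",
"    print(ckpt.name)"
]
},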
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "m_ZSkwB5Rflb"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"### 测试模型\n",
|
|||
|
"\n",
|
|||
|
"使用 `tools/test.py` 对模型进行测试:\n",
|
|||
|
"\n",
|
|||
|
"```\n",
|
|||
|
"python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments]\n",
|
|||
|
"```\n",
|
|||
|
"\n",
|
|||
|
"这里有一些可选参数可以进行配置:\n",
|
|||
|
"\n",
|
|||
|
"- `--metrics`: 评价方式,这依赖于数据集,比如准确率 a\n",
|
|||
|
"- `--metric-options`: 对于评估过程的自定义操作,如 topk=1.\n",
|
|||
|
"\n",
|
|||
|
"更多细节请参看 `tools.test.py` 。\n",
|
|||
|
"\n",
|
|||
|
"这里依然使用示例 `MobileNetV2`."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 17,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "Zd4EM00QRtyc",
|
|||
|
"outputId": "799899d3-1a90-40c0-cfdf-bbaa10ec4869"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:27: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
|
|||
|
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n",
|
|||
|
"Use load_from_local loader\n",
|
|||
|
"[>>] 2023/2023, 284.3 task/s, elapsed: 7s, ETA: 0s\n",
|
|||
|
"accuracy : 92.19\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"!python tools/test.py configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py work_dirs/mobilenet_v2_1x_cats_dogs/latest.pth --metrics=accuracy --metric-options=topk=1"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "IwThQkjaRwF7"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"### 推理计算\n",
|
|||
|
"\n",
|
|||
|
"我们用下面的命令进行推理计算并保存计算结果。\n",
|
|||
|
"\n",
|
|||
|
"```shell\n",
|
|||
|
"python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}]\n",
|
|||
|
"```\n",
|
|||
|
"\n",
|
|||
|
"可选参数:\n",
|
|||
|
"\n",
|
|||
|
"- `RESULT_FILE`: 输出结果的文件名。如果不指定,计算结果不会被保存。支持的格式包括json, pkl 和 yml\n",
|
|||
|
"\n",
|
|||
|
"这里依然使用示例 `MobileNetV2`."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 18,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "6GVKloPHR0Fn",
|
|||
|
"outputId": "01fb8ddc-a413-4e20-beb5-31a28b581e6d"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:27: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
|
|||
|
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n",
|
|||
|
"Use load_from_local loader\n",
|
|||
|
"[>>] 2023/2023, 289.1 task/s, elapsed: 7s, ETA: 0s\n",
|
|||
|
"dumping results to results.json\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"!python tools/test.py configs/mobilenet_v2/mobilenet_v2_1x_cats_dogs.py work_dirs/mobilenet_v2_1x_cats_dogs/latest.pth --out=results.json"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "G0NJI1s6e3FD"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"导出的json 文件中保存了所有样本的推理结果、分类结果和分类得分"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 19,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/",
|
|||
|
"height": 372
|
|||
|
},
|
|||
|
"id": "HJdJeLUafFhX",
|
|||
|
"outputId": "0b9d3ef1-cc14-474f-d4df-4472611b0df2"
|
|||
|
},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"class_scores: [0.9999998807907104, 1.671358518251509e-07]\n",
|
|||
|
"pred_class: cats\n",
|
|||
|
"pred_label: 0\n",
|
|||
|
"pred_score: 0.9999998807907104\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAASwAAAEYCAIAAABp9FyZAAEAAElEQVR4nJT96ZMkWXIfCKrqe88uv+KOvCuzju7qrmo00A00QZAcoZAznPmwIkP+p/NhZWVnZTAHSBALAmig+u6ursrKMzIyDj/teofqflB3S89qckTWpCTKM8Lc3ezZ0+unP1XF737+KQAgGEQDQCAoIgD04MGDul7f3Nx0fTsalbPZJCa/XC5dOQ4hGGMAoGka55y1tu/7LMu89865zWbz+PHj+Xy+2WxGo1G3rokIAJgZEa3N9L3C4Jwzxngfu64TkbIsq6pqu40xJqXU922MkYiMMSKS56Uh61zubGmtM5QjGkTMDZRl2fd9CKGoSkTx3gthSgkRBCGllFIKnAAAEbGXPM/1q7vOM3OM0ZBLKcWYUkrM/PHH33n58mWe5yklnlW3t7e5M6fHJ7PRaLOY3769zAxNsvz06PDJwwe5ofnV1eNHD2+uru7du3O7aeu6Nha97/7n//l/fvHimQ8dIi4Wi/V6vVk34/H0+voa0Xzy8XdfvXrVtt39+/c//vjjt2/fPn32TZ7nDBKZY4x37997+/btn/zpj9++fftX//E//vCHP1yv1yWS98FaK2i+//3vv3nztutDluWcQMjEwNfXN/cfPHrz5o0IVlVlQvfpp5+mlH72s5+NZ9MY48nZ6Xy5XK1Wn37/e3fu3Hnx6uVvf/vbPoQ7d+6cn5/317ezg8lqteq6phqVV1eXm83q0aNHTdOklFKSlASBrM2yLHM2ny9eiAgB5nleFBWhjZG7kMAUybibTXdxs1yHaMspOtf7aAiMMdZaRERhRHTOOUNd13FKIfQxRokJUay1xpibxTzGKCLWWiJCkizL8jy/vb1JKT169Kiu133fHx8f397etm07nU67rqubdex9nueTycQY5JjeXF4cTmfHx8fG4Pz2drVaTSaTO3furNb1YrHw3mdZllKq6xoRx+NxWZbM7JyLMW42G+dcVVUppc1mlee5916/6/T0tGmaZ8+e3bt3DxHLsgSAvu+dcyIyn8/J8PHxIVFs13Mi/z/+9/9qOrWzcX7n7lEKTQi9QLKICAAg8K2j7/uUEiKKSAih67rEIYRQ94uiKEII0+n06urq5OQkhAAAMUa9iJQSAFRVZa1dr9ciwszGGN36RDal5L1HIGYGABHRaxARXe5vXwrAVoaBmVl2B4CIiE/RGOO9DyGQNYMQhhCICA3q240xiGj0f4jW2rqui6Lquq4sS9/HoiiOjo5fvXo1mUx+/OMf379//2/+5m8ePnz4y9fPfvSjH/3t3/y1M/aPP//8F/+0qKpqeXuTCSBi0zTZeDSbzbquy/P86uqqmB5WVeUyU9dS1/XV1VVMvqqqs7Oz5XK5XC6rajydTlerze3tbUppPB4bY4qi0KUmIu9749xqtdKlOz8/J6I/+7M/CyE450Lb6foj4dOnT7OsCCGEEIu8stYSwsHBASBba5m5qoo337yq6/WdO3em0/Hd+3frun79+uXZ2Vni8PTpVwcH0zx3JydHTdN88+zrsspzgL7vjTHT6bQo87atQ+ibpgkhWGuzLBPBGFIIqWkakM7lGcfEzMzsvUdIPiYfOJ9UQGSMIWsNCxElkARCAqr4AACFRcR7j8LMzCnF6FNKFinLbFEUWZYdnZ40TdN1XQghhOBD1/e9bjYVifV6HUIYjUZElBeubVuBlGWZI6OLEH3o+x4AyrIUkfl8EUI4Pj4moouLCxa01qrCreu6russy0RkPB7Xda17O8syIkopxRiZue97773esu7YsixVjJ1zRDRsY2stS4wxWgPWWmt1JwMiJh9i4pQSS6Rhl8v7h/ceEVVy9r9YHwYijkYjABiNRnoPw14XEbWBjx8/zvNcf6NXP0ia3ttwV2ru9CnuiyUi6tbUW8L3D/3GuDtCivsHMydhERARRgAAIkLEBNLH4FNcbtYhJUHMigKtuV0uNm
2zbuou+Murq+9//rnNsuvb2zzPz8/PP//88+l0aoxxzumNq6LRVTo6Ouq67ujoaL1e3717dzwep5SyLBuNRrPZjJlXq9VqtXLO6T7I81y1qd71crns+17dAedc13UoYoy5vLw8OTl5/vz5D37wg/Ozs6PDwzzPmVkfZmbp5cuXKSUU0MtAxLars8xJ4uOTQ/VN7t69+0//9E+67Z49e6ZKJ8Z4cHDw4ZMnP//5z5tNrdvx0f0Hq/miqde31zd92xhCjoEADdJ6ueqaNoVokMo8K/LMEkYfurZGREEQER9DCMHHwMyM0HVt27ad72OMMaXOt13X6T3GGL33IXh1Urz3XddlWWYt7R7rdqsw87Nnz54/f/7ixYu3b99u6lVKyVqbZS7LsizLdFPpNiADRJQ4OOemo/F0Oi2KAoD188fjMQC0bRtDyLJMDd1yuby+vo4xZllmjCGi8Xh8dnZ2584dAPDe13UdY9RtrIrDWqsbTLdTCAERp9Op6ik9Qfe8un4JpO9DCMHaLM9LvVRh8N7HGEUQwZCezcxpdwwibq0djUZlWaoa0C/QbeScCyGMx2MRybJMnQoR6ft+vV5fXV0h4ocffnh6ejoajVSRqIx1Xafr+L6PalWwU0qDJRzEbBBLRCQilVh98U4g3700QoiI+pCstcZYEBSGFDlF7r0XgN57Y20fvMszm7nRZFxUZe/94dFRluf/5e//7uPvfPKTP/9n63pzfHz8/Pnzf/2v//WDBw9+9atf6dqpsiyKIqWkznnTNKenp9PpVNVn0zSbzeby8lK90BDCb37zG2vt2dlZ3/d1XRdFoU+rbdurq6vlcpmEp9Pp0dGRcw4AiqLo2857//vf/q5erfURTMbjsiy7rh3WIYTgMiPCSGJAdspbPnz85OT0qOub2XTcdvXzF98kDin0PnQI7H2XGTo7OyGU2/n1Yn5jCI4OZ02z6ft+uVyq8VksFl3XuMxYR71v62a92azUThZFUZalc86nkFIMKaaUYkqIaDNXlCUQCoIxJitcVhY2z4xzLs8QkZm93+p0IrKWdD/oHTnnNMbRM+/cuXN0dKSO1aCRdds4ZxCxKAqXGR+6EEJd11VejMtK/UlriQCNMS6zR7ODvu/bth6NRqPRqGmaul7nuauqSpcxxjgej+/fv3/37t3pdHp9fb1er9u2TSk553RzWmuLotDtp0ZPhVDV7mCo9F7UzAij97Hvg6AGVsAhMbP3UZIYpNwVpHel4qcirnKoFl9XRKVFhTDP89VqhYjz+Xwymcznc0TUB6OGmJnrum6aRgVYTZBzTkVCpUtvYJCbQa50lQch3JdD3h3f0hr6sIczhRCAEI2xmTHGGAN7QktEYMgVeQKZHh4kEJtnPsWsKNQe5lXJCJfXV6t644ocDLVtu1gs1Dm/vr6u63qz2YiIOkjOOdU7RHRwcPDJJ58Q0ZMnTz755JOyLL/88stvvvkGAL773e/OZjO9AFVn6l8YY5yzIty2zXg8un//3tnZ6cHBjAhR+PDo4JunX2eZ+9//978cFcXLZ8/KLJtOp23b6uecHh2v18uUUgpxvV7H5HNrEofet5PJ+N6dO9Px+M2b13/ywz/65puvr64u//zPf3JyeHB5eVGV+Xqz/Onf/5c//qPPLSGBfP797/3iZ19Yg9aStSTAXdfU9TrGOBqNzs/PsyyLMa5Wq+Vy7n3nMluWeVnmiCiEYBAIhRhIyFrnXDmqqqoqRuVkNj04OJhOp7OD6eHxgbWGCNWl6vs+hF631mJ5u9lsvPcAbK3VHQIAR0dHBwcHk+moKDPdIQCit8/MbdsmDjHGuq5VIZZlqVYhRh99UIMxLisAcM6URWGtTSl436FAmRcff/zxeDxWpVMUxWQySSldXl7e3Nzoe3UXqQEnIjUY+vtB8FRpMrPamGG3ExGA+rEphMBJhJETxMgxcIoCQIho33dHh+BQ+r4vikz3SkoJgMkYRHRZpjAMEZ2cnCyXyyzLxuNxjH
Ew6HVdLxYL9bJ0rVXwjDEpSQhBRPKsGELBwdskIpb0LXdUzwkhgKAxKaWESAxRBEXEUhRJzFFEEggyi0hiFo0JrUE
|
|||
|
"text/plain": [
|
|||
|
"<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x280 at 0x7FB6DAD5F550>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 19,
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import json\n",
|
|||
|
"\n",
|
|||
|
"with open(\"./results.json\", 'r') as f:\n",
|
|||
|
" results = json.load(f)\n",
|
|||
|
"\n",
|
|||
|
"# 展示第一张图片的结果信息\n",
|
|||
|
"print('class_scores:', results['class_scores'][0])\n",
|
|||
|
"print('pred_class:', results['pred_class'][0])\n",
|
|||
|
"print('pred_label:', results['pred_label'][0])\n",
|
|||
|
"print('pred_score:', results['pred_score'][0])\n",
|
|||
|
"Image.open('data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg')"
|
|||
|
]
|
|||
|
}
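,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The per-sample lists in `results.json` can also be summarised as a whole. The cell below is a minimal sketch, assuming the file has the keys printed above (`pred_class` and `pred_score`), each holding one entry per test sample; the 0.9 threshold is an arbitrary illustrative value, not something prescribed by MMClassification."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from collections import Counter\n",
"\n",
"with open('results.json') as f:\n",
"    results = json.load(f)\n",
"\n",
"# Count how many samples were assigned to each predicted class.\n",
"print(Counter(results['pred_class']))\n",
"\n",
"# Count the predictions whose top score reaches an (arbitrary) threshold.\n",
"threshold = 0.9  # illustrative value only\n",
"scores = results['pred_score']\n",
"confident = sum(score >= threshold for score in scores)\n",
"print(f'{confident}/{len(scores)} predictions have pred_score >= {threshold}')"
]
}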
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"accelerator": "GPU",
|
|||
|
"colab": {
|
|||
|
"collapsed_sections": [],
|
|||
|
"name": "MMClassification_tools_cn.ipynb",
|
|||
|
"provenance": []
|
|||
|
},
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.8.8"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 4
|
|||
|
}
|