{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "XjQxmm04iTx4"
},
"source": [
"<a href=\"https://colab.research.google.com/github/open-mmlab/mmclassification/blob/master/docs_zh-CN/tutorials/MMClassification_python_cn.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UdMfIsMpiODD"
},
"source": [
"# MMClassification Python Tutorial on Colab\n",
"\n",
"This tutorial covers:\n",
"\n",
"* How to install MMCls\n",
"* How to run inference with a pretrained model\n",
"* How to fine-tune a pretrained model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iOl0X9UEiRvE"
},
"source": [
"## Install MMClassification\n",
"\n",
"Before using MMClassification, we need to set up the environment:\n",
"\n",
"- Install Python, CUDA, a C/C++ compiler and git\n",
"- Install PyTorch (CUDA version)\n",
"- Install mmcv\n",
"- Clone the mmcls GitHub repository and install it\n",
"\n",
"Since we are running on Google Colab, the basic environment is already set up, so we can skip the first two steps."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XjQxmm04iTx4"
},
"source": [
"### Check the environment"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c6MbAw10iUJI",
"outputId": "d2afec96-2fea-4dfc-bb44-79e49b979717"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/content\n"
]
}
],
"source": [
"%cd /content"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4IyFL3MaiYRu",
"outputId": "a9b9015d-8ae7-4a5b-e935-909e78033028"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/content\n"
]
}
],
"source": [
"!pwd"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DMw7QwvpiiUO",
"outputId": "0ddc0277-47a9-4698-b6f2-d1e2a5e00f78"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"nvcc: NVIDIA (R) Cuda compiler driver\n",
"Copyright (c) 2005-2020 NVIDIA Corporation\n",
"Built on Wed_Jul_22_19:09:09_PDT_2020\n",
"Cuda compilation tools, release 11.0, V11.0.221\n",
"Build cuda_11.0_bu.TC445_37.28845127_0\n"
]
}
],
"source": [
"# Check the nvcc version\n",
"!nvcc -V"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4VIBU7Fain4D",
"outputId": "3201af41-2190-40fc-cfea-688d5a551998"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0\n",
"Copyright (C) 2017 Free Software Foundation, Inc.\n",
"This is free software; see the source for copying conditions. There is NO\n",
"warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n",
"\n"
]
}
],
"source": [
"# Check the GCC version\n",
"!gcc --version"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "24lDLCqFisZ9",
"outputId": "db2e6bf1-cefa-4aa9-c303-15d48477d03b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.9.0+cu102\n",
"True\n"
]
}
],
"source": [
"# Check the PyTorch installation\n",
"import torch, torchvision\n",
"print(torch.__version__)\n",
"print(torch.cuda.is_available())"
]
},
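{
"cell_type": "markdown",
"metadata": {},
"source": [
"The prebuilt MMCV wheels are indexed by CUDA and PyTorch version. As a sketch, the cell below derives such an index URL from the running PyTorch build. It assumes the URL pattern used in the install cell further down (`.../mmcv/dist/{cu_version}/{torch_version}/index.html`), and `mmcv_index_url` is a hypothetical helper, not part of MMCV. Note that `torch.version.cuda` reports the CUDA version PyTorch was built with (the `+cu102` suffix above), which can differ from the `nvcc` version; prebuilt mmcv-full wheels are generally meant to match the PyTorch build. The exact index naming may differ between MMCV releases, so check the MMCV installation docs to confirm that an index exists for your combination."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper (not an MMCV API): build the wheel index URL for the running PyTorch build\n",
"import torch\n",
"\n",
"def mmcv_index_url(torch_version, cuda_version):\n",
"    # '1.9.0+cu102' -> 'torch1.9.0' (drop the local '+cuXXX' suffix)\n",
"    torch_tag = 'torch' + torch_version.split('+')[0]\n",
"    # '10.2' -> 'cu102'; CPU-only builds report None\n",
"    cuda_tag = 'cpu' if cuda_version is None else 'cu' + cuda_version.replace('.', '')\n",
"    return f'https://download.openmmlab.com/mmcv/dist/{cuda_tag}/{torch_tag}/index.html'\n",
"\n",
"print(mmcv_index_url(str(torch.__version__), torch.version.cuda))"
]
},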
{
"cell_type": "markdown",
"metadata": {
"id": "R2aZNLUwizBs"
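},
"source": [
"### Install MMCV\n",
"\n",
"MMCV is the foundational library for the OpenMMLab codebases. Prebuilt wheels for Linux are available and can be downloaded and installed directly.\n",
"\n",
"Pay attention to the PyTorch and CUDA versions so that the installation succeeds.\n",
"\n",
"In the previous steps we printed the CUDA and PyTorch versions of the environment, 11.0 and 1.9.0 respectively, so we need to choose the matching MMCV build.\n",
"\n",
"Alternatively, you can install the full version, MMCV-full, which includes all features and a rich set of ready-to-use CUDA ops. Note that the full version may take longer to compile."
]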
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nla40LrLi7oo",
"outputId": "61e90f5e-023b-4a03-ddb8-760add09e868"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in links: https://download.openmmlab.com/mmcv/dist/cu110/torch1.9.0/index.html\n",
"Collecting mmcv\n",
" Downloading mmcv-1.3.9.tar.gz (313 kB)\n",
"\u001b[K |████████████████████████████████| 313 kB 7.8 MB/s \n",
"\u001b[?25hCollecting addict\n",
" Downloading addict-2.4.0-py3-none-any.whl (3.8 kB)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mmcv) (1.19.5)\n",
"Requirement already satisfied: Pillow in /usr/local/lib/python3.7/dist-packages (from mmcv) (7.1.2)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from mmcv) (3.13)\n",
"Collecting yapf\n",
" Downloading yapf-0.31.0-py2.py3-none-any.whl (185 kB)\n",
"\u001b[K |████████████████████████████████| 185 kB 13.6 MB/s \n",
"\u001b[?25hBuilding wheels for collected packages: mmcv\n",
" Building wheel for mmcv (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for mmcv: filename=mmcv-1.3.9-py2.py3-none-any.whl size=451832 sha256=21838bb360585d5fb846ab545d87b70c262e22bcf79c680ab48bfc922ab5b5b3\n",
" Stored in directory: /root/.cache/pip/wheels/88/48/bf/655e136aea5534d7a9a85fe247fee7957178fc19cf79dda602\n",
"Successfully built mmcv\n",
"Installing collected packages: yapf, addict, mmcv\n",
"Successfully installed addict-2.4.0 mmcv-1.3.9 yapf-0.31.0\n"
]
}
],
"source": [
"# Install mmcv\n",
"!pip install mmcv -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.9.0/index.html\n",
"# !pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.9.0/index.html"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GDTUrYvXjlRb"
},
"source": [
"### Clone and install MMCls\n",
"\n",
"Next, we clone the latest mmcls code from GitHub and install it."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Bwme6tWHjl5s",
"outputId": "4094e1d1-0a3d-4a44-b0f4-7a75e953aa82"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cloning into 'mmclassification'...\n",
"remote: Enumerating objects: 3161, done.\u001b[K\n",
"remote: Counting objects: 100% (12/12), done.\u001b[K\n",
"remote: Compressing objects: 100% (12/12), done.\u001b[K\n",
"remote: Total 3161 (delta 2), reused 5 (delta 0), pack-reused 3149\u001b[K\n",
"Receiving objects: 100% (3161/3161), 2.81 MiB | 29.39 MiB/s, done.\n",
"Resolving deltas: 100% (2039/2039), done.\n"
]
}
],
"source": [
"# Clone the mmcls repository\n",
"!git clone https://github.com/open-mmlab/mmclassification.git"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iJ45llP7jr5a",
"outputId": "c158c88c-5453-409a-aed9-1f6e219a6d17"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/content/mmclassification\n"
]
}
],
"source": [
"%cd mmclassification/"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7c9K8ZVCjuFy",
"outputId": "d651cdbd-4457-4ffa-dc5f-f274570c0fd9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"configs docs_zh-CN model-index.yml requirements.txt tests\n",
"demo\t LICENSE README.md resources\t tools\n",
"docker\t MANIFEST.in README_zh-CN.md setup.cfg\n",
"docs\t mmcls\t requirements setup.py\n"
]
}
],
"source": [
"!ls"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "djjeq0I3jwOL",
"outputId": "0df874d1-4052-43b6-878c-0ea3c36fd7fc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Obtaining file:///content/mmclassification\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from mmcls==0.13.0) (3.2.2)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mmcls==0.13.0) (1.19.5)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.13.0) (2.4.7)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.13.0) (2.8.1)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.13.0) (1.3.1)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mmcls==0.13.0) (0.10.0)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from cycler>=0.10->matplotlib->mmcls==0.13.0) (1.15.0)\n",
"Installing collected packages: mmcls\n",
" Running setup.py develop for mmcls\n",
"Successfully installed mmcls-0.13.0\n"
]
}
],
"source": [
"# Install MMClassification from source\n",
"!pip install -e ."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hFg_oSG4j3zB",
"outputId": "3614ff26-ff64-45e9-ae30-b3d91dccea2c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.13.0\n"
]
}
],
"source": [
"# Check the MMClassification installation\n",
"import mmcls\n",
"print(mmcls.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4Mi3g6yzj96L"
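},
"source": [
"## Use MMCls pretrained models\n",
"\n",
"MMCls provides many pretrained models; see the [model zoo](https://github.com/open-mmlab/mmclassification/blob/master/docs/model_zoo.md) for the full list.\n",
"These models are trained on the ImageNet dataset and achieve strong results on it (the model zoo lists their accuracies).\n",
"We can use them directly for inference.\n",
"\n",
"Before using a pretrained model, we need to:\n",
"\n",
"- Prepare the model\n",
"  - Prepare the config file\n",
"  - Prepare the checkpoint (weights) file\n",
"- Build the model\n",
"- Run inference"
]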
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nDQchz8CkJaT",
"outputId": "8b7a96a8-2f8a-468e-98cf-9e5b27371f91"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2021-07-27 03:22:18-- https://www.dropbox.com/s/k5fsqi6qha09l1v/banana.png?dl=0\n",
"Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6031:18::a27d:5112\n",
"Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.\n",
"HTTP request sent, awaiting response... 301 Moved Permanently\n",
"Location: /s/raw/k5fsqi6qha09l1v/banana.png [following]\n",
"--2021-07-27 03:22:18-- https://www.dropbox.com/s/raw/k5fsqi6qha09l1v/banana.png\n",
"Reusing existing connection to www.dropbox.com:443.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://ucb1b407c7a3eea2252b62b37824.dl.dropboxusercontent.com/cd/0/inline/BTF_U7peGq3OToFIhecJylaLgb14wf6IggUQEZYQ4Ri10lADalCUOY9UFDHwcjQtktPhZybKPJgx9AZ1mtacFOBdsiTI8nUMo12G4-3QVYlxjbi68cg2gv0N7zE8ckI8Avchd-ZAws7xMqliw3ePJx6o/file# [following]\n",
"--2021-07-27 03:22:18-- https://ucb1b407c7a3eea2252b62b37824.dl.dropboxusercontent.com/cd/0/inline/BTF_U7peGq3OToFIhecJylaLgb14wf6IggUQEZYQ4Ri10lADalCUOY9UFDHwcjQtktPhZybKPJgx9AZ1mtacFOBdsiTI8nUMo12G4-3QVYlxjbi68cg2gv0N7zE8ckI8Avchd-ZAws7xMqliw3ePJx6o/file\n",
"Resolving ucb1b407c7a3eea2252b62b37824.dl.dropboxusercontent.com (ucb1b407c7a3eea2252b62b37824.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f\n",
"Connecting to ucb1b407c7a3eea2252b62b37824.dl.dropboxusercontent.com (ucb1b407c7a3eea2252b62b37824.dl.dropboxusercontent.com)|162.125.1.15|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 297299 (290K) [image/png]\n",
"Saving to: demo/banana.png\n",
"\n",
"demo/banana.png 100%[===================>] 290.33K --.-KB/s in 0.07s \n",
"\n",
"2021-07-27 03:22:18 (4.17 MB/s) - demo/banana.png saved [297299/297299]\n",
"\n"
]
}
],
"source": [
"# Download an example image\n",
"!wget https://www.dropbox.com/s/k5fsqi6qha09l1v/banana.png?dl=0 -O demo/banana.png"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 420
},
"id": "o2eiitWnkQq_",
"outputId": "9de00c3e-e70a-40f1-d68c-6a973ad964b8"
},
"outputs": [
{
"data": {
"text/plain": [
"<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=393x403 at 0x7F93A2CF9090>"
]
},
"execution_count": 13,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"from PIL import Image\n",
"Image.open('demo/banana.png')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sRfAui8EkTDX"
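},
"source": [
"### Prepare the model files\n",
"\n",
"A pretrained model is defined by a config file and a checkpoint file: the config file defines the model architecture, and the checkpoint file stores the trained weights.\n",
"\n",
"On GitHub, MMCls lists its pretrained models on a separate page for each model family.\n",
"For example, the config files and checkpoints for MobileNetV2 are available at this [link](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2).\n",
"\n",
"Cloning the mmcls repository already gave us the config files locally, but the checkpoint files must be downloaded manually. For convenience, we save all checkpoints in the `checkpoints` folder."
]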
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "03RvRFuykb0C",
"outputId": "70385b7b-2cd9-4673-fbca-6637da435bb3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2021-07-27 03:22:19-- https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\n",
"Resolving download.openmmlab.com (download.openmmlab.com)... 47.88.36.78\n",
"Connecting to download.openmmlab.com (download.openmmlab.com)|47.88.36.78|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 14206911 (14M) [application/octet-stream]\n",
"Saving to: checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\n",
"\n",
"mobilenet_v2_batch2 100%[===================>] 13.55M 11.0MB/s in 1.2s \n",
"\n",
"2021-07-27 03:22:21 (11.0 MB/s) - checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth saved [14206911/14206911]\n",
"\n"
]
}
],
"source": [
"!mkdir checkpoints\n",
"!wget https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth -P checkpoints"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VvRoZpBGkgpC",
"outputId": "1078d0df-e4d7-4947-e848-44c15b8ed444"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py\n",
"checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\n"
]
}
],
"source": [
"# Make sure both the config file and the checkpoint file exist\n",
"!ls configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py\n",
"!ls checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eiYdsHoIkpD1"
},
"source": [
"### Image classification\n",
"\n",
"MMCls provides high-level APIs for inference.\n",
"\n",
"First, we build the model."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KwJWlR2QkpiV",
"outputId": "4e897cde-f3f8-4c7e-de7b-4e04b5f22db8"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:27: UserWarning: Fail to import ``MultiScaleDeformableAttention`` from ``mmcv.ops.multi_scale_deform_attn``, You should install ``mmcv-full`` if you need this module. \n",
" warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Use load_from_local loader\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/content/mmclassification/mmcls/apis/inference.py:44: UserWarning: Class names are not saved in the checkpoint's meta data, use imagenet by default.\n",
" warnings.warn('Class names are not saved in the checkpoint\\'s '\n"
]
}
],
"source": [
"from mmcls.apis import inference_model, init_model, show_result_pyplot\n",
"\n",
"# Paths to the config file and the checkpoint file\n",
"config_file = 'configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py'\n",
"checkpoint_file = 'checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth'\n",
"# Device to run on; if no GPU is enabled, use the CPU with `device='cpu'`\n",
"device = 'cuda:0'\n",
"# device = 'cpu'\n",
"# Build the model from the config file and the checkpoint file\n",
"model = init_model(config_file, checkpoint_file, device=device)"
]
},
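{
"cell_type": "markdown",
"metadata": {},
"source": [
"`init_model` warned above that the class names are not saved in the checkpoint's meta data and fell back to the ImageNet class names. To see what the checkpoint file actually contains, it can be loaded directly with `torch.load`. This is a sketch that assumes the usual MMCV checkpoint layout (a dict holding `state_dict` and optionally `meta`); otherwise it treats the file as a bare state dict."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the downloaded checkpoint (assumes an MMCV-style dict checkpoint)\n",
"import torch\n",
"\n",
"ckpt = torch.load('checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth', map_location='cpu')\n",
"# MMCV checkpoints usually wrap the weights in 'state_dict'; otherwise treat the file as a bare state dict\n",
"state_dict = ckpt.get('state_dict', ckpt)\n",
"print('top-level keys:', list(ckpt.keys()))\n",
"print('number of tensors:', len(state_dict))\n",
"# The warning above appears when meta has no CLASSES entry\n",
"meta = ckpt.get('meta') or {}\n",
"print('CLASSES saved in meta:', 'CLASSES' in meta)"
]
},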
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GiSACYFgkvNE",
"outputId": "38445366-5bca-469a-a151-052510fe5d7d"
},
"outputs": [
{
"data": {
"text/plain": [
"(mmcls.models.classifiers.image.ImageClassifier,\n",
" mmcls.models.classifiers.base.BaseClassifier,\n",
" mmcv.runner.base_module.BaseModule,\n",
" torch.nn.modules.module.Module,\n",
" object)"
]
},
"execution_count": 17,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# The model's class hierarchy (method resolution order)\n",
"model.__class__.__mro__"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FyjY7hP9k0_D",
"outputId": "2f9c7e76-4350-44d4-a647-db5d26f8e2d7"
},
"outputs": [
{
"data": {
"text/plain": [
"{'pred_class': 'banana', 'pred_label': 954, 'pred_score': 0.9999284744262695}"
]
},
"execution_count": 18,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# Run the classifier on a single image\n",
"import mmcv\n",
"\n",
"img = 'demo/banana.png'\n",
"img_array = mmcv.imread(img)\n",
"result = inference_model(model, img_array)\n",
"result"
]
},
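{
"cell_type": "markdown",
"metadata": {},
"source": [
"`inference_model` accepts an image path as well as a loaded array, so classifying several images is a simple loop. This sketch reuses the `model` built above and the `pred_class` / `pred_score` keys shown in the result. The `demo/*.png` pattern is only an example, so point it at your own images. It stores each prediction in `res` so that `result` stays available for the visualization step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Classify several images with the model built above (the glob pattern is just an example)\n",
"import glob\n",
"\n",
"for path in sorted(glob.glob('demo/*.png')):\n",
"    res = inference_model(model, path)\n",
"    pred_class = res['pred_class']\n",
"    pred_score = res['pred_score']\n",
"    print(f'{path}: {pred_class} ({pred_score:.4f})')"
]
},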
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 633
},
"id": "ndwdD8eUk96g",
"outputId": "abb722ef-4e71-4932-ed20-7876976bc378"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/content/mmclassification/mmcls/models/classifiers/base.py:221: UserWarning: show==False and out_file is not specified, only result image will be returned\n",
" warnings.warn('show==False and out_file is not specified, only '\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1080x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"# Visualize the classification result\n",
"show_result_pyplot(model, img, result)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oDMr3Bx_lESy"
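},
"source": [
"## Fine-tuning\n",
"\n",
"Fine-tuning carefully adjusts the parameters of a pretrained model on a specific dataset so that the model adapts to the new dataset and its task. Compared with training from scratch, fine-tuning greatly reduces training time and mitigates the overfitting that small datasets are prone to.\n",
"\n",
"The basic steps of fine-tuning are:\n",
"\n",
"1. Prepare the new dataset\n",
"2. Convert the dataset into a format MMCls supports\n",
"3. Create a config file for the dataset\n",
"4. Train and validate\n",
"\n",
"See the [documentation](https://github.com/open-mmlab/mmclassification/blob/master/docs/tutorials/new_dataset.md) for more details."
]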
},
{
"cell_type": "markdown",
"metadata": {
"id": "TJtKKwAvlHX_"
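},
"source": [
"### Prepare the dataset in an MMCls-compatible format\n",
"\n",
"Here we download a cats-vs-dogs classification dataset; see the MMClassification tutorial for the detailed steps."
]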
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3vBfU8GGlFPS",
"outputId": "431863a6-d7da-4642-a6a0-95167f696832"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2021-07-27 03:22:41-- https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0\n",
"Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6031:18::a27d:5112\n",
"Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.\n",
"HTTP request sent, awaiting response... 301 Moved Permanently\n",
"Location: /s/raw/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip [following]\n",
"--2021-07-27 03:22:41-- https://www.dropbox.com/s/raw/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip\n",
"Reusing existing connection to www.dropbox.com:443.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://ucb4542ecde226f32d4315b1fbc8.dl.dropboxusercontent.com/cd/0/inline/BTGvHQfXo4i1Zuda0OfZIn3M0sLD9saAggO_ol8huDdBTN96R7KGhRMvQ1Qt7efcr-jQjKGu0jIht-yFUenpW8NxVNOfLZgAixsh7Of02gUkBZRBzOZkRTw9ZlePKnWLBFLX72WtD04FdxKOt4xG8Jp6/file# [following]\n",
"--2021-07-27 03:22:41-- https://ucb4542ecde226f32d4315b1fbc8.dl.dropboxusercontent.com/cd/0/inline/BTGvHQfXo4i1Zuda0OfZIn3M0sLD9saAggO_ol8huDdBTN96R7KGhRMvQ1Qt7efcr-jQjKGu0jIht-yFUenpW8NxVNOfLZgAixsh7Of02gUkBZRBzOZkRTw9ZlePKnWLBFLX72WtD04FdxKOt4xG8Jp6/file\n",
"Resolving ucb4542ecde226f32d4315b1fbc8.dl.dropboxusercontent.com (ucb4542ecde226f32d4315b1fbc8.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f\n",
"Connecting to ucb4542ecde226f32d4315b1fbc8.dl.dropboxusercontent.com (ucb4542ecde226f32d4315b1fbc8.dl.dropboxusercontent.com)|162.125.1.15|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: /cd/0/inline2/BTG9WUVj6lUWIAp6GbSW-0BY6CLVIUaQD4UYVCKD2JSr0Ar8-dInWmMT3bXsc9xaHzLND2a2W3FrjDrM3aCBiRHIrpLSNqF9KdzroBiNmQF11eUUPszcj9GdSmTht-W9NMJfCYbBf5HOg9ldLUtex9mMEEV3LBMqT-qvGCNRYfXh9LWv8VOcrZM8JnebcmMFALgKdBl8FqbycGb0FkAhzHXshOhQvcWF1tdwE7VxyrVe2wT-B5RsuU8ClOuz0bY7nWyBIbyMFNNh1V28Qy3DSSTU3c74ULwTRMxlCHSN5dtZf3xvV99Kb57vkiTF8a888gyhO3C7F4TsGERtZxs9FXvMyKx990HfO0ORj-iTVw07akfIJN2jAyP6qmB3AyDPlSk/file [following]\n",
"--2021-07-27 03:22:42-- https://ucb4542ecde226f32d4315b1fbc8.dl.dropboxusercontent.com/cd/0/inline2/BTG9WUVj6lUWIAp6GbSW-0BY6CLVIUaQD4UYVCKD2JSr0Ar8-dInWmMT3bXsc9xaHzLND2a2W3FrjDrM3aCBiRHIrpLSNqF9KdzroBiNmQF11eUUPszcj9GdSmTht-W9NMJfCYbBf5HOg9ldLUtex9mMEEV3LBMqT-qvGCNRYfXh9LWv8VOcrZM8JnebcmMFALgKdBl8FqbycGb0FkAhzHXshOhQvcWF1tdwE7VxyrVe2wT-B5RsuU8ClOuz0bY7nWyBIbyMFNNh1V28Qy3DSSTU3c74ULwTRMxlCHSN5dtZf3xvV99Kb57vkiTF8a888gyhO3C7F4TsGERtZxs9FXvMyKx990HfO0ORj-iTVw07akfIJN2jAyP6qmB3AyDPlSk/file\n",
"Reusing existing connection to ucb4542ecde226f32d4315b1fbc8.dl.dropboxusercontent.com:443.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 228802825 (218M) [application/zip]\n",
"Saving to: cats_dogs_dataset.zip\n",
"\n",
"cats_dogs_dataset.z 100%[===================>] 218.20M 122MB/s in 1.8s \n",
"\n",
"2021-07-27 03:22:45 (122 MB/s) - cats_dogs_dataset.zip saved [228802825/228802825]\n",
"\n"
]
}
],
"source": [
"# Download the classification dataset\n",
"!wget https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0 -O cats_dogs_dataset.zip\n",
"!mkdir data\n",
"!unzip -q cats_dogs_dataset.zip -d ./data/"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "15iKNG0SlV9y"
},
"source": [
"### Create a config file for the dataset\n",
"\n",
"See the MMClassification tutorial for a detailed explanation; here we directly set up the config file for fine-tuning."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "WCfnDavFlWrK"
},
"outputs": [],
"source": [
"# 载入已经存在的配置文件\n",
"from mmcv import Config\n",
"cfg = Config.fromfile('configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py')\n",
"\n",
"# 修改模型分类头中的类别数目\n",
"cfg.model.head.num_classes = 2\n",
"cfg.model.head.topk = (1, )\n",
"\n",
"# 加载预训练权重\n",
"cfg.model.backbone.init_cfg = dict(type='Pretrained', checkpoint='checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth', prefix='backbone')\n",
"\n",
"# 根据你的电脑情况设置 sample size 和 workers \n",
"cfg.data.samples_per_gpu = 32\n",
"cfg.data.workers_per_gpu = 2\n",
"\n",
"# 指定训练集路径\n",
"cfg.data.train.data_prefix = 'data/cats_dogs_dataset/training_set/training_set'\n",
"cfg.data.train.classes = 'data/cats_dogs_dataset/classes.txt'\n",
"\n",
"# 指定验证集路径\n",
"cfg.data.val.data_prefix = 'data/cats_dogs_dataset/val_set/val_set'\n",
"cfg.data.val.ann_file = 'data/cats_dogs_dataset/val.txt'\n",
"cfg.data.val.classes = 'data/cats_dogs_dataset/classes.txt'\n",
"\n",
"# 指定测试集路径\n",
"cfg.data.test.data_prefix = 'data/cats_dogs_dataset/test_set/test_set'\n",
"cfg.data.test.ann_file = 'data/cats_dogs_dataset/test.txt'\n",
"cfg.data.test.classes = 'data/cats_dogs_dataset/classes.txt'\n",
"\n",
"# 设定数据集归一化参数\n",
"normalize_cfg = dict(type='Normalize', mean=[124.508, 116.050, 106.438], std=[58.577, 57.310, 57.437], to_rgb=True)\n",
"cfg.data.train.pipeline[3] = normalize_cfg\n",
"cfg.data.val.pipeline[3] = normalize_cfg\n",
"cfg.data.test.pipeline[3] = normalize_cfg\n",
"\n",
"# 修改评价指标选项\n",
"cfg.evaluation['metric_options']={'topk': (1, )}\n",
"\n",
"# 设置优化器\n",
"cfg.optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)\n",
"cfg.optimizer_config = dict(grad_clip=None)\n",
"\n",
"# 设置学习率策略\n",
"cfg.lr_config = dict(policy='step', step=[1])\n",
"cfg.runner = dict(type='EpochBasedRunner', max_epochs=2)\n",
"\n",
"# 设置工作目录以保存模型和日志\n",
"cfg.work_dir = './work_dirs/cats_dogs_dataset'\n",
"\n",
"from mmcls.apis import set_random_seed\n",
"# Set the random seed and enable cudnn's deterministic mode for reproducible results\n",
"cfg.seed = 0\n",
"set_random_seed(0, deterministic=True)\n",
"\n",
"# Train on a single GPU\n",
"cfg.gpu_ids = range(1)"
]
},
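{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Normalize` values in the config above are presumably the per-channel mean and standard deviation of the cats/dogs images on the 0-255 scale. The notebook does not show how they were obtained. The next cell is a minimal sketch of that computation. It assumes the images are available as `H x W x 3` `uint8` arrays in **RGB** order; with `to_rgb=True`, the BGR images read by `mmcv.imread` are converted to RGB before normalization. `channel_mean_std` is a hypothetical helper, not an MMClassification API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def channel_mean_std(images):\n",
"    # Hypothetical helper: pools the pixels of all images, so larger images weigh more.\n",
"    # Returns the per-channel mean and population std on the 0-255 scale.\n",
"    total = np.zeros(3, dtype=np.float64)\n",
"    total_sq = np.zeros(3, dtype=np.float64)\n",
"    count = 0\n",
"    for img in images:\n",
"        pixels = np.asarray(img, dtype=np.float64).reshape(-1, 3)\n",
"        total += pixels.sum(axis=0)\n",
"        total_sq += (pixels ** 2).sum(axis=0)\n",
"        count += pixels.shape[0]\n",
"    mean = total / count\n",
"    # Clip tiny negative values caused by floating-point rounding before the sqrt.\n",
"    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))\n",
"    return mean, std\n",
"\n",
"\n",
"# Usage sketch (assumes `paths` is a list of training image files):\n",
"# mean, std = channel_mean_std(mmcv.imread(p, channel_order='rgb') for p in paths)"
]
},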
{
"cell_type": "markdown",
"metadata": {
"id": "HDerVUPFmNR0"
},
"source": [
"### Fine-tune the model\n",
"\n",
"Using the modified config, we fine-tune the model on our dataset by calling the `train_model` API."
]
},
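{
"cell_type": "markdown",
"metadata": {},
"source": [
"`lr_config = dict(policy='step', step=[1])` decays the learning rate at the listed epoch milestones. Each time a milestone is reached, the rate is multiplied by `gamma`, which is 0.1 by default in mmcv's step policy. The next cell is a plain-Python sketch of that rule for illustration only; `step_lr` is a hypothetical helper, not the mmcv implementation. With the settings in our config it gives `1.000e-02` for epoch 1 and `1.000e-03` for epoch 2, the same values that appear in the training log below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def step_lr(base_lr, epoch, steps, gamma=0.1):\n",
"    # Illustrative re-implementation of a step schedule, not mmcv's code.\n",
"    # epoch is 0-based here; the training log prints it 1-based.\n",
"    # The LR is multiplied by gamma once for every milestone already reached.\n",
"    passed = sum(1 for s in steps if epoch >= s)\n",
"    return base_lr * gamma ** passed\n",
"\n",
"\n",
"# lr=0.01, step=[1], max_epochs=2 as in the config above\n",
"for epoch in range(2):\n",
"    print(f'Epoch [{epoch + 1}] lr: {step_lr(0.01, epoch, [1]):.3e}')"
]
},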
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "P7unq5cNmN8G",
"outputId": "0d7d9c26-842c-4e24-e958-4fdc8438c3e2"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-07-27 03:22:48,408 - mmcv - INFO - load backbone in model from: checkpoints/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth\n",
"2021-07-27 03:22:48,497 - mmcls - INFO - Start running, host: root@2fbef59c1bbe, work_dir: /content/mmclassification/work_dirs/cats_dogs_dataset\n",
"2021-07-27 03:22:48,499 - mmcls - INFO - Hooks will be executed in the following order:\n",
"before_run:\n",
"(VERY_HIGH ) StepLrUpdaterHook \n",
"(NORMAL ) CheckpointHook \n",
"(NORMAL ) EvalHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_train_epoch:\n",
"(VERY_HIGH ) StepLrUpdaterHook \n",
"(NORMAL ) EvalHook \n",
"(LOW ) IterTimerHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_train_iter:\n",
"(VERY_HIGH ) StepLrUpdaterHook \n",
"(NORMAL ) EvalHook \n",
"(LOW ) IterTimerHook \n",
" -------------------- \n",
"after_train_iter:\n",
"(ABOVE_NORMAL) OptimizerHook \n",
"(NORMAL ) CheckpointHook \n",
"(NORMAL ) EvalHook \n",
"(LOW ) IterTimerHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"after_train_epoch:\n",
"(NORMAL ) CheckpointHook \n",
"(NORMAL ) EvalHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_val_epoch:\n",
"(LOW ) IterTimerHook \n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"before_val_iter:\n",
"(LOW ) IterTimerHook \n",
" -------------------- \n",
"after_val_iter:\n",
"(LOW ) IterTimerHook \n",
" -------------------- \n",
"after_val_epoch:\n",
"(VERY_LOW ) TextLoggerHook \n",
" -------------------- \n",
"2021-07-27 03:22:48,503 - mmcls - INFO - workflow: [('train', 1)], max: 2 epochs\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Use load_from_local loader\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-07-27 03:23:07,686 - mmcls - INFO - Epoch [1][100/201]\tlr: 1.000e-02, eta: 0:00:57, time: 0.191, data_time: 0.108, memory: 1709, loss: 0.7473\n",
"2021-07-27 03:23:24,069 - mmcls - INFO - Epoch [1][200/201]\tlr: 1.000e-02, eta: 0:00:35, time: 0.164, data_time: 0.077, memory: 1709, loss: 0.4259\n",
"2021-07-27 03:23:24,098 - mmcls - INFO - Saving checkpoint at 1 epochs\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[>>>>>>>>>>>>>>>>>>>>>>>>>>] 1601/1601, 141.2 task/s, elapsed: 11s, ETA: 0s"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-07-27 03:23:35,521 - mmcls - INFO - Epoch(val) [1][51]\taccuracy_top-1: 91.0056\n",
"2021-07-27 03:23:55,139 - mmcls - INFO - Epoch [2][100/201]\tlr: 1.000e-03, eta: 0:00:18, time: 0.196, data_time: 0.107, memory: 1709, loss: 0.2794\n",
"2021-07-27 03:24:12,312 - mmcls - INFO - Epoch [2][200/201]\tlr: 1.000e-03, eta: 0:00:00, time: 0.172, data_time: 0.078, memory: 1709, loss: 0.2882\n",
"2021-07-27 03:24:12,338 - mmcls - INFO - Saving checkpoint at 2 epochs\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[>>>>>>>>>>>>>>>>>>>>>>>>>>] 1601/1601, 145.5 task/s, elapsed: 11s, ETA: 0s"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-07-27 03:24:23,427 - mmcls - INFO - Epoch(val) [2][51]\taccuracy_top-1: 93.6290\n"
]
}
],
"source": [
"import time\n",
"import mmcv\n",
"import os.path as osp\n",
"\n",
"from mmcls.datasets import build_dataset\n",
"from mmcls.models import build_classifier\n",
"from mmcls.apis import train_model\n",
"\n",
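"# Create the work directory\n",
"mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))\n",
"# Build the classifier and apply init_cfg, which loads the pretrained backbone weights\n",
"model = build_classifier(cfg.model)\n",
"model.init_weights()\n",
"# Build the training dataset\n",
"datasets = [build_dataset(cfg.data.train)]\n",
"# Attach the class names to the model for convenient visualization\n",
"model.CLASSES = datasets[0].CLASSES\n",
"# Start fine-tuning\n",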
"train_model(\n",
" model,\n",
" datasets,\n",
" cfg,\n",
" distributed=False,\n",
" validate=True,\n",
" timestamp=time.strftime('%Y%m%d_%H%M%S', time.localtime()),\n",
" meta=dict())"
]
},
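{
"cell_type": "markdown",
"metadata": {},
"source": [
"The validation log reports `accuracy_top-1` as a percentage over the 1601 validation images. The next cell is a minimal NumPy sketch of top-k accuracy. It assumes `scores` is an `N x C` array of class scores and `labels` holds the ground-truth class indices. `topk_accuracy` is a hypothetical helper that illustrates the metric; it is not MMClassification's own implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def topk_accuracy(scores, labels, k=1):\n",
"    # Percentage of samples whose true label is among the k highest scores.\n",
"    scores = np.asarray(scores)\n",
"    labels = np.asarray(labels)\n",
"    topk = np.argsort(scores, axis=1)[:, -k:]  # indices of the k largest scores\n",
"    hit = (topk == labels[:, None]).any(axis=1)\n",
"    return 100.0 * hit.mean()\n",
"\n",
"\n",
"# Toy example with 2 classes; the third sample is misclassified.\n",
"scores = np.array([[0.9, 0.1], [0.3, 0.7], [0.6, 0.4], [0.2, 0.8]])\n",
"labels = np.array([0, 1, 1, 1])\n",
"print(topk_accuracy(scores, labels))  # prints 75.0"
]
},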
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 652
},
"id": "HsoGBZA3miui",
"outputId": "c1a42eaa-588b-4431-fc02-bfedfed32c4f"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/content/mmclassification/mmcls/models/classifiers/base.py:221: UserWarning: show==False and out_file is not specified, only result image will be returned\n",
" warnings.warn('show==False and out_file is not specified, only '\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 576x432 with 0 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 1080x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
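"# Check the fine-tuned model on a sample image\n",
"import matplotlib.pyplot as plt\n",
"from mmcls.apis import inference_model, show_result_pyplot\n",
"\n",
"img = mmcv.imread('data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg')\n",
"\n",
"# inference_model builds the test pipeline from model.cfg\n",
"model.cfg = cfg\n",
"result = inference_model(model, img)\n",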
"\n",
"plt.figure(figsize=(8, 6))\n",
"show_result_pyplot(model, img, result)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "MMClassification_python_cn.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}