{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "FVmnaxFJvsb8"
},
"source": [
"# MMClassification Tutorial\n",
"Welcome to MMClassification!\n",
"\n",
"In this tutorial, we demo\n",
"* How to install MMCls\n",
"* How to do inference and feature extraction with MMCls trained weight\n",
"* How to train on your own dataset and visualize the results. \n",
"* How to use command line tools"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "QS8YHrEhbpas"
},
"source": [
"## Install MMClassification\n",
"This step may take several minutes.\n",
"\n",
"We use PyTorch 1.5.0 and CUDA 10.1 for this tutorial. You may install other versions by change the version number in pip install command."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 170
},
"colab_type": "code",
"id": "UWyLrLYaNEaL",
"outputId": "35b19c63-d6f3-49e1-dcaa-aed3ecd85ed7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"nvcc: NVIDIA (R) Cuda compiler driver\n",
"Copyright (c) 2005-2019 NVIDIA Corporation\n",
"Built on Wed_Oct_23_19:24:38_PDT_2019\n",
"Cuda compilation tools, release 10.2, V10.2.89\n",
"gcc (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609\n",
"Copyright (C) 2015 Free Software Foundation, Inc.\n",
"This is free software; see the source for copying conditions. There is NO\n",
"warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n",
"\n"
]
}
],
"source": [
"# Check nvcc version\n",
"!nvcc -V\n",
"# Check GCC version\n",
"!gcc --version"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 340
},
"colab_type": "code",
"id": "Ki3WUBjKbutg",
"outputId": "69f42fab-3f44-44d0-bd62-b73836f90a3d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://mirrors.aliyun.com/pypi/simple\n",
"Looking in links: https://download.pytorch.org/whl/torch_stable.html\n",
"Requirement already up-to-date: torch==1.5.0+cu101 in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (1.5.0+cu101)\n",
"Requirement already up-to-date: torchvision==0.6.0+cu101 in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (0.6.0+cu101)\n",
"Requirement already satisfied, skipping upgrade: numpy in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from torch==1.5.0+cu101) (1.19.2)\n",
"Requirement already satisfied, skipping upgrade: future in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from torch==1.5.0+cu101) (0.18.2)\n",
"Requirement already satisfied, skipping upgrade: pillow>=4.1.1 in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from torchvision==0.6.0+cu101) (8.0.1)\n",
"Looking in indexes: https://mirrors.aliyun.com/pypi/simple\n",
"Looking in links: https://download.openmmlab.com/mmcv/dist/cu101/torch1.5.0/index.html\n",
"Requirement already satisfied: mmcv-full in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (1.2.7)\n",
"Requirement already satisfied: Pillow in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from mmcv-full) (8.0.1)\n",
"Requirement already satisfied: opencv-python>=3 in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from mmcv-full) (4.5.1.48)\n",
"Requirement already satisfied: yapf in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from mmcv-full) (0.30.0)\n",
"Requirement already satisfied: numpy in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from mmcv-full) (1.19.2)\n",
"Requirement already satisfied: addict in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from mmcv-full) (2.4.0)\n",
"Requirement already satisfied: pyyaml in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from mmcv-full) (5.3.1)\n"
]
}
],
"source": [
"# Install PyTorch\n",
"!pip install -U torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html\n",
"# Install mmcv\n",
"# !pip install mmcv-full\n",
"# !pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html\n",
"!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.5.0/index.html"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 374
},
"colab_type": "code",
"id": "nR-hHRvbNJJZ",
"outputId": "ca6d9c48-0034-47cf-97b5-f31f529cc31c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fatal: destination path 'mmclassification' already exists and is not an empty directory.\n",
"/home/SENSETIME/shaoyidi/VirtualenvProjects/add_tutorials/MMCls_Tutorials/mmclassification\n",
"Looking in indexes: https://mirrors.aliyun.com/pypi/simple\n",
"Obtaining file:///home/SENSETIME/shaoyidi/VirtualenvProjects/add_tutorials/MMCls_Tutorials/mmclassification\n",
"Requirement already satisfied: matplotlib in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from mmcls==0.9.0) (3.3.2)\n",
"Requirement already satisfied: numpy in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from mmcls==0.9.0) (1.19.2)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from matplotlib->mmcls==0.9.0) (2.8.1)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from matplotlib->mmcls==0.9.0) (1.3.0)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from matplotlib->mmcls==0.9.0) (2.4.7)\n",
"Requirement already satisfied: cycler>=0.10 in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from matplotlib->mmcls==0.9.0) (0.10.0)\n",
"Requirement already satisfied: certifi>=2020.06.20 in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from matplotlib->mmcls==0.9.0) (2020.6.20)\n",
"Requirement already satisfied: pillow>=6.2.0 in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from matplotlib->mmcls==0.9.0) (8.0.1)\n",
"Requirement already satisfied: six>=1.5 in /home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages (from python-dateutil>=2.1->matplotlib->mmcls==0.9.0) (1.15.0)\n",
"Installing collected packages: mmcls\n",
" Attempting uninstall: mmcls\n",
" Found existing installation: mmcls 0.9.0\n",
" Uninstalling mmcls-0.9.0:\n",
" Successfully uninstalled mmcls-0.9.0\n",
" Running setup.py develop for mmcls\n",
"Successfully installed mmcls\n"
]
}
],
"source": [
"# Install mmcls\n",
"!git clone https://github.com/open-mmlab/mmclassification.git\n",
"%cd mmclassification\n",
"\n",
"!pip install -e ."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"colab_type": "code",
"id": "mAE_h7XhPT7d",
"outputId": "912ec9be-4103-40b8-91cc-4d31e9415f60"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.5.0+cu101 True\n",
"0.9.0\n"
]
}
],
"source": [
"# Check Pytorch installation\n",
"import torch, torchvision\n",
"print(torch.__version__, torch.cuda.is_available())\n",
"\n",
"# Check MMClassification installation\n",
"import mmcls\n",
"print(mmcls.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "eUcuC3dUv32I"
},
"source": [
"## Use MMCls pretrained models\n",
"\n",
"MMCls provides many pretrained models in the [model zoo](https://github.com/open-mmlab/mmclassification/blob/master/docs/model_zoo.md).\n",
"These models are already trained to state-of-the-art accuracy on ImageNet dataset.\n",
"We can use pretrained models to classify images or extract image features for downstream tasks.\n",
"\n",
"To use a pretrained model, we need to:\n",
"\n",
"- Prepare the model\n",
" - Prepare the config file \n",
" - Prepare the parameter file\n",
"- Build the model in Python\n",
"- Perform inference tasks, such as classification or feature extraction. \n",
"\n",
"### Prepare Model Files\n",
"\n",
"A pretrained model is defined with a config file and a parameter file. The config file defines the model structure and the parameter file stores all parameters. \n",
"\n",
"MMCls provides pretrained models in separated pages on GitHub. \n",
"For example, config and parameter files for ResNet50 is listed in [this page](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet).\n",
"\n",
"As we already clone the config file along with the repo, what we need more is to download the parameter file manually. By convention, we store the parameter files into the `checkpoints` folder. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mkdir: cannot create directory checkpoints: File exists\n",
"--2021-03-11 17:14:56-- https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth\n",
"Connecting to 172.16.1.135:3128... connected.\n",
"Proxy request sent, awaiting response... 200 OK\n",
"Length: 102491894 (98M) [application/octet-stream]\n",
"Saving to: checkpoints/resnet50_batch256_imagenet_20200708-cfb998bf.pth.2\n",
"\n",
"resnet50_batch256_i 100%[===================>] 97.74M 9.98MB/s in 9.7s \n",
"\n",
"2021-03-11 17:15:07 (10.1 MB/s) - checkpoints/resnet50_batch256_imagenet_20200708-cfb998bf.pth.2 saved [102491894/102491894]\n",
"\n"
]
}
],
"source": [
"!mkdir checkpoints\n",
"!wget https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth -P checkpoints"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Perform inference\n",
"\n",
"MMCls provides high level APIs for inference. \n",
"\n",
"First, we need to build the model."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "H8Fxg8i-wHJE"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Use load_from_local loader\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/SENSETIME/shaoyidi/VirtualenvProjects/add_tutorials/MMCls_Tutorials/mmclassification/mmcls/apis/inference.py:44: UserWarning: Class names are not saved in the checkpoint's meta data, use imagenet by default.\n",
" warnings.warn('Class names are not saved in the checkpoint\\'s '\n"
]
}
],
"source": [
"from mmcls.apis import inference_model, init_model, show_result_pyplot\n",
"\n",
"# Specify the path to config file and checkpoint file\n",
"config_file = 'configs/resnet/resnet50_b32x8_imagenet.py'\n",
"checkpoint_file = 'checkpoints/resnet50_batch256_imagenet_20200708-cfb998bf.pth'\n",
"# Specify the device. You may also use cpu by `device='cpu'`.\n",
"device = 'cuda:0'\n",
"# Build the model from a config file and a checkpoint file\n",
"model = init_model(config_file, checkpoint_file, device=device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we use the model to classify the sample image. "
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "izFv6pSRujk9"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"# Test a single image\n",
"img = 'demo/demo.JPEG'\n",
"result = inference_model(model, img)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's checkout the result!"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 504
},
"colab_type": "code",
"id": "bDcs9udgunQK",
"outputId": "8221fdb1-92af-4d7c-e65b-c7adf0f5a8af"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/SENSETIME/shaoyidi/VirtualenvProjects/add_tutorials/MMCls_Tutorials/mmclassification/mmcls/models/classifiers/base.py:216: UserWarning: show==False and out_file is not specified, only result image will be returned\n",
" warnings.warn('show==False and out_file is not specified, only '\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAv0AAAJCCAYAAABTWni0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9d7h1yVEfjP6qe629T3jT5DwaSTMaZQmwCQZ92L7mMQZMzrYvfDafuWSEDIhoy5hgbINIQiKZ9NmAEEkY7M/kLAshUBxpRhKjMDm84cS91+q6f1To6rX3OTPC6N55eE5L75xz9l6rQ3V11a+qq6uJmXFSTspJOSkn5aSclJNyUk7KSfmbW9L/vztwUk7KSTkpJ+WknJSTclJOykn5wJYT0H9STspJOSkn5aSclJNyUk7K3/ByAvpPykk5KSflpJyUk3JSTspJ+RteTkD/STkpJ+WknJSTclJOykk5KX/DywnoPykn5aSclJNyUk7KSTkpJ+VveDkB/SflpJyUk3JSTspJOSkn5aT8DS8fMNBPRB9LRG8joruI6MUfqHZOykk5KSflpJyUk3JSTspJOSnHF/pA5Oknogzg7QA+BsB7AbwWwOcw81v+2hs7KSflpJyUk3JSTspJOSkn5aQcWz5Qnv4PBXAXM7+TmRcAfgbAJ32A2jopJ+WknJSTclJOykk5KSflpBxTug9QvTcAeE/4+70APuyoh8+cPctXX3Ot/000fYIe4zt+zGenGxrUPvrXUuiYCpse0mp/2ifW1KxVy3us9ZB+xivP/bUVZqwj1JG0PbZ9bh9aMyfH7jutq/vxbFTZe3+FTa1mnEc99LjHP32B1ldKWCGVveYfP0Y70x08av+ztisrz/LjJ9l6Ek86v/JOHSiHT4+rZX1Nj7N/zfi54Vtplb1X07biq/aeP3kkfz1WT+nIv/hxfXFsdfJ4lBfTh5sPphVPGH/ah+P478iqjmL2Y75+f8pjrfNmDKuLtvLAce+vl132YcvH9fN13WtIGAT8ceJk5TsiXaexZZq889e/m1/brx07bnaPEGfvZzlabq4bYysj18wbxV/Wz9Fq2/LU2nG4YGhlTFvnMTqa649jxfsRTHEU99lfzfdHzNtfTVXyX0EmEJhZ+/G/D1oeEzsc9d5aLPa/2ZeJnljzxGPU8Pg61Dw1aezOt73tIWa+at17HyjQ/5jwjIj+JYB/CQBXXX01/tMPvAwAkFICEfk/YQwK/7QyZn82ApyUkv8df6ew0ogIxAxZBuuFvyzOozdChFki2F4HjCOSMpAxEe5T4EvyXnzUxm7jno6plOK/T/sR/2YdMwGgsMmzAhBTGB8zgBSAjoi8I9mWCCNY26VAJwYw+ufMBIwM4tU+14aEDmy0mmCsKT0ez/iPe7Z5T4V1SoFOYDCK942Qvb51fFbfAwgslYZ+pJTBRQQmufKO/BloAHmngIX+tKrchd7S8XEcmnVg/ZI1QQBTmBeAUhSbjJwzSilHjiuOO/InWOY3dKnOoT6bdL2QEkeoWuelGjakpI7WTuBXBmyzkgODkiwigEsdn60PpDrusL4AoJShtksU5Cg1YyxjgS2NlJPPxVhGHb/yDReZxxUlTV6ndMXmKQM2z1x8HBxx8Vp5ZiQnXysNjZgBKv5eB3LNxKnKzxLlqM8nUAqQc16Z87ZPhJRI6gh8ZXNSOzMhBchlEZcWtFD83edjtY4ymcepPKvrLVWe837CPycCxlKgy2O1y1QNpeTLWdEOA0Wkwxp4yGCMLktMXq/oJCKM49j0/ah1V99Jvpan8zIdP1n/iVFKQaPjuBrBRCpx1qCWaZ9X2vFxVzrFfvu4lVfi58eGGrOBVpWTE73RyKqk8pGmvJpatnQdRWCu6wOT8a/TuQQgoeW3ZixSA4ZxRM45jC95m3Fs8Q+fG6NjrJcITOx6gpkBXzcVz7RYWsZduACo4zxqvXhPpusoYAbRURDJzcXpdxygrbhK3i+TsZHKKlKZV8Vvlh8TXmJmxypE1PDASttxrEdgpfh+1JnraOHth7qma9P76agpvqvfHINBWI3G1lgLfSp1vqf9+4cf+ZF3H1XvBwr0vxfATeHvGwHcEx9g5h8C8EMAcOvTnsYOCFQwRtDPzEiUwsQVq2Ol4fhZu1gBojiRETisI+qq4mi/P1oY23OsKs24laxP9vmKAKrtJhWMjyV4134+8RaFTh63JhE71OCqthFVmKvt+k+uwLSqvSg0DAlqNbRujAKmGijouPTYQUzqWV2M0+/XvndUKwGEGRiJ9TRtaYfJ3lvhM5sPar+PX0EBiOo7M7hMhDSiJICrKNRX15Q9U18ikEIWrIwjlsdacwRqaF6xOIkyDn2uU0lA0bU6Jfp0/lYRdAMG1aQI79ZnnA8DmSmlFXAQDVanT5BHZnhU+WIguoU8hOQAovnODDxbY2rcMDFKYTAXlxU2jkqO1d8N9ItYVJA5VUhxDIHs/j1XGgk9gJRIjZuqDNnHEg16pT8DFJ4wgOFG7dpSBUH7iNVnNKsT19bFYV5b2R/1hyjQCZhqkZEq2HZdcZwoqr8CDCYG8VHjOnq0K58dB3YfowTbrIoSVFAR5ac/g8rjTV1Nf1YdJMcZINaH6QfW5nHgJrZxlKxmAMTBCUctL0wdfDK82KOwVsK45e8p37R/rytKIdT1PNH/bbWVfsEpEnnJ3modKLWy2u2KZ6rhyu26PrJEXl+l83F8yLEvVKUwuexLa98/0kgkQnIsscpnJm1Wex7refzrpukbT2oPhsDa5yfvxfmeypl17zsGdNikYH6NkT7lTzHGi3/ZjDzI3uOiTKblAwX6XwvgNiJ6MoD3AfhsAJ971MNE5NZwJGC0qgFUizYoq2k9q3VHQWAKpk5Yq5AZQLY3vY51yuSxCnNx0OpqhgCgAMxIOSN6HqzOcVRv+BomrIIjMrB6HAzNOPYQnxNzhYk+RGrVQUuwtk39pf3i2HGr994bqoKp2jnyHSeEhTDtRmhzjVHgzwVF8XgB/eMq1C4lBhqPKxgoKA5yj6lGx8grJBSBd7RQtApI/2Ze9RkcJf5MCa5T1jR5/ijv4HF/H12Omasj2Kjx0PD0BfZnAPE4r7TnOwHUfGwKqq6BdixESYC5ywD5nfmxw+RUz/m7vGYebb5agwCr64kAtncgv6f4Xfh9pWMRhKIgypQEEk9YAkqRGlOwpM2LZIDfu6a/j8HzLrykDRJASXmx1dnygX449eIe7axYVbqObxhVboe+k9JexlCCI8JknQJ9FJd1hVXXRBSl9RWOxkoAx9oPrhUABIxg9dkGmjTz4ZTz38pjrKH3F/xTIyPc59cOwAGD0fGIOdABVPDS9uV/S5YeVcyd24i9Ixxd1kd7Fa0OW5FzK5MR1/hqMflihm6s6jjg7/O+BnYRETpKaJyNlBSCrNLYiJFS8v6oplRScZABBPfamxw6DrSTrYpVwHpUOXb3hdtfzGHbOnwi9lqVvzzZSW6MCV2f1ShY009q9dt09+JYI0Ye8mfSGkDd0CcaL/p31cdryBKerbt9gv2iHGBu1
5bvduocRzZxe0CxnvEIEa0lz1HlAwL6mXkgoi8F8D8gKPrHmPnNx77UAFFjhLrNnCJmnViPhIArqVaYiMBomS1ax9bGEd2BLKrHR831iycuLhtbab4nW8QqCJhkm9get4kFgHxMuJEvHlBVglq/eKuU/QwEsfZvBUjwiuxYB6qZVrF4XEBSfXGmrh4BbUN7K5qzgpxmsZpwCHMK9dqZ99otbQ27MW/K1HAETDhiFSCuoydsWqIoaL31K+801n0ADGqMrROIRykL5qItRmHT/t6oQFrdWl8naKsXZX2/fWYm3osGyEaaGsAMg0gGuFzRxLUHBU7C88X4EHW7vLJJgDHrpkvrJ65b3g4Eua6t9YBfV3hQrhwaqrq1pUfTvIJ5dmCqQ+MCcALzCIauW39fDQQyfqSqfEhBOtc+VOqsEkFCaUoD5B0IwJSg0YT9246SglwOopBQd+fWAx02eWX12y8Jvi4fq3CY09rH1ckVGoxhXrSxiYKrRkvUHS39Gn5dM6rCxWkgU9M+z+FnGAhsY8U
"text/plain": [
"<Figure size 1080x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Show the results\n",
"show_result_pyplot(model, img, result)"
]
},
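{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides plotting, we can also inspect the raw prediction. This is a minimal optional check; the exact dict keys (`pred_label`, `pred_score`, `pred_class`) returned by `inference_model` are assumed from this MMCls version, so simply printing `result` shows what is actually available."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the raw prediction dict returned by `inference_model`\n",
"# (keys such as 'pred_label', 'pred_score' and 'pred_class' are assumed here)\n",
"print(result)"
]
},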
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Feature extraction\n",
"\n",
"Feature extraction is another inference task. We can use pretrained model to extract sematic feature for downstream tasks. \n",
"MMClassifcation also provides such facilities. \n",
"\n",
"Assuming we have already built model with pretrained weights, there're more steps to do:\n",
"\n",
"1. Load the image processing pipeline. This is very important because we need to ensure data preprocessing during training and testing are the equivalent.\n",
"2. Preprocess the image. \n",
"3. Forward through the model and extract feature."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here, we load image with test pipeline. "
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from mmcls.datasets.pipelines import Compose\n",
"from mmcv.parallel import collate, scatter\n",
"\n",
"# Pack image info into a dict\n",
"data = dict(img_info=dict(filename=img), img_prefix=None)\n",
"# Parse the test pipeline\n",
"cfg = model.cfg\n",
"test_pipeline = Compose(cfg.data.test.pipeline)\n",
"# Process the image\n",
"data = test_pipeline(data)\n",
"\n",
"# Scatter to specified GPU\n",
"data = collate([data], samples_per_gpu=1)\n",
"if next(model.parameters()).is_cuda:\n",
" data = scatter(data, [device])[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we can use the API from model to get the feature."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 2048])\n"
]
}
],
"source": [
"# Forward the model\n",
"with torch.no_grad():\n",
" features = model.extract_feat(data['img'])\n",
"\n",
"# Show the feature, it is a 1280-dim vector\n",
"print(features.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Ta51clKX4cwM"
},
"source": [
"## Finetune pretrained model with customized dataset\n",
"\n",
"Finetuning is the process in which parameters of a model would be adjusted very precisely in order to fit with certain dataset. Compared with training, it can can save lots of time and reduce overfitting when the new dataset is small. \n",
"\n",
"To finetune on a customized dataset, the following steps are neccessary. \n",
"\n",
"1. Prepare a new dataset. \n",
"2. Support it in MMCls.\n",
"3. Create a config file accordingly. \n",
"4. Perform training and evaluation.\n",
"\n",
"More details can be found [here](https://github.com/open-mmlab/mmclassification/blob/master/docs/tutorials/new_dataset.md).\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "AcZg6x_K5Zs3"
},
"source": [
"### Prepare dataset\n",
"\n",
"Before we support a new dataset, we need download existing dataset first.\n",
"\n",
"We use [Cats and Dogs dataset](https://www.dropbox.com/s/wml49yrtdo53mie/cats_dogs_dataset_reorg.zip?dl=0) as an example. For simplicity, we have reorganized the directory structure for further usage. Origianl dataset can be found [here](https://www.kaggle.com/tongpython/cat-and-dog). The dataset consists of 8k images for training and 2k images for testing. There are 2 classes in total, i.e. cat and dog."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2021-03-11 17:15:22-- https://www.dropbox.com/s/ckv2398yoy4oiqy/cats_dogs_dataset.zip?dl=0\n",
"Connecting to 172.16.1.135:3128... connected.\n",
"Proxy request sent, awaiting response... 301 Moved Permanently\n",
"Location: /s/raw/ckv2398yoy4oiqy/cats_dogs_dataset.zip [following]\n",
"--2021-03-11 17:15:23-- https://www.dropbox.com/s/raw/ckv2398yoy4oiqy/cats_dogs_dataset.zip\n",
"Reusing existing connection to www.dropbox.com:443.\n",
"Proxy request sent, awaiting response... 302 Found\n",
"Location: https://uce2f1fc5c8344ac928f7a3e619f.dl.dropboxusercontent.com/cd/0/inline/BKfHBDoPAEY-QPjLw8I3a7UY8azZSDQ_wuT8ECxXciHPSimQTk-mXQFGx3I6nGOydUZWqVnJ1jQPz-lJSRTg6TFSr-n2lh3yvtC3m2wOXrZT8RhwgqXrQ_bvQwurPSIVc7XTfHBJIhyN5rzpfsXquNu6/file# [following]\n",
"--2021-03-11 17:15:23-- https://uce2f1fc5c8344ac928f7a3e619f.dl.dropboxusercontent.com/cd/0/inline/BKfHBDoPAEY-QPjLw8I3a7UY8azZSDQ_wuT8ECxXciHPSimQTk-mXQFGx3I6nGOydUZWqVnJ1jQPz-lJSRTg6TFSr-n2lh3yvtC3m2wOXrZT8RhwgqXrQ_bvQwurPSIVc7XTfHBJIhyN5rzpfsXquNu6/file\n",
"Connecting to 172.16.1.135:3128... connected.\n",
"Proxy request sent, awaiting response... 302 Found\n",
"Location: /cd/0/inline2/BKdw_s6y59fYYUAQhWUPoG4Fb4WhR2z6MK1nxmb4GDm4MIre2Yt8iwxMZh0JxGYRnYIOtIG7vs6e1HefsS-vzCp_-ab1Bfzcnon8FnmWom91NFQNPmpGRAWWrJa_VoRB_Z1iCfnrokxhECF0wQURulHHXdwLoC0Il0fh38pag8qrJOsPL5QgBFWCZO54yA6nuytf8IIJU3T76DtFE_cAPEaOIkJcx1ZfQEX0mPSDoWczuwxK9du3M1oQQTuVRKUZDleWArNaZq1xXz6xNS_vpGCVlP66E6VbfXaxCAvgGARLjUPov_9yBKpr_73VZSZr0GjHGPXVMfvHsM4-ZsQ2XlQ8Gie_Gfit4JpVyLeRhptwKpD0aeoBl2t0h6i9Wbfr_yo/file [following]\n",
"--2021-03-11 17:15:24-- https://uce2f1fc5c8344ac928f7a3e619f.dl.dropboxusercontent.com/cd/0/inline2/BKdw_s6y59fYYUAQhWUPoG4Fb4WhR2z6MK1nxmb4GDm4MIre2Yt8iwxMZh0JxGYRnYIOtIG7vs6e1HefsS-vzCp_-ab1Bfzcnon8FnmWom91NFQNPmpGRAWWrJa_VoRB_Z1iCfnrokxhECF0wQURulHHXdwLoC0Il0fh38pag8qrJOsPL5QgBFWCZO54yA6nuytf8IIJU3T76DtFE_cAPEaOIkJcx1ZfQEX0mPSDoWczuwxK9du3M1oQQTuVRKUZDleWArNaZq1xXz6xNS_vpGCVlP66E6VbfXaxCAvgGARLjUPov_9yBKpr_73VZSZr0GjHGPXVMfvHsM4-ZsQ2XlQ8Gie_Gfit4JpVyLeRhptwKpD0aeoBl2t0h6i9Wbfr_yo/file\n",
"Reusing existing connection to uce2f1fc5c8344ac928f7a3e619f.dl.dropboxusercontent.com:443.\n",
"Proxy request sent, awaiting response... 200 OK\n",
"Length: 228487605 (218M) [application/zip]\n",
"Saving to: cats_dogs_dataset.zip\n",
"\n",
"cats_dogs_dataset.z 100%[===================>] 217.90M 9.12MB/s in 26s \n",
"\n",
"2021-03-11 17:15:51 (8.37 MB/s) - cats_dogs_dataset.zip saved [228487605/228487605]\n",
"\n"
]
}
],
"source": [
"!wget https://www.dropbox.com/s/ckv2398yoy4oiqy/cats_dogs_dataset.zip?dl=0 -O cats_dogs_dataset.zip\n",
"!mkdir data\n",
"!unzip -q cats_dogs_dataset.zip -d ./data/cats_dogs_dataset/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The directory of the \"Cats and Dogs Dataset\" is as follows:\n",
"```\n",
"data/cats_dogs_dataset\n",
"├── training_set\n",
"│ ├── training_set\n",
"│ │ ├── cats\n",
"│ │ │ ├── cat.1.jpg\n",
"│ │ │ ├── cat.2.jpg\n",
"│ │ │ ├── ...\n",
"│ │ ├── dogs\n",
"│ │ │ ├── dog.1.jpg\n",
"│ │ │ ├── dog.2.jpg\n",
"│ │ │ ├── ...\n",
"├── test_set\n",
"│ ├── test_set\n",
"│ │ ├── cats\n",
"│ │ │ ├── cat.4001.jpg\n",
"│ │ │ ├── cat.4002.jpg\n",
"│ │ │ ├── ...\n",
"│ │ ├── dogs\n",
"│ │ │ ├── dog.4001.jpg\n",
"│ │ │ ├── dog.4002.jpg\n",
"│ │ │ ├── ...\n",
"```\n",
"\n",
"You may also check the structure of dataset by `tree data/cats_dogs_dataset`."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 377
},
"colab_type": "code",
"id": "78LIci7F9WWI",
"outputId": "a7f339c7-a071-40db-f30d-44028dd2ce1c"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAFoCAYAAAC4+ecUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9yY8tW5beif12Y81pvL39fS8iX0Tmy45kkgCLRZYKEKSCSkVwopkgaaJZjTQX/wT9DTXQWJoKUgEaEAIhcEKyWKKYjMzIyIgX8Zrben86M9udBmtvM3O/9wWTElN8BVy7cFz348fPsWO292q+9a1vqZQSn45Px6fj0/Hp+HQA6P/YJ/Dp+HR8Oj4dn44fzvHJKXw6Ph2fjk/Hp2M8PjmFT8en49Px6fh0jMcnp/Dp+HR8Oj4dn47x+OQUPh2fjk/Hp+PTMR6fnMKn49Px6fh0fDrG46/NKSil/qFS6udKqb9USv3jv673+XR8Oj4dn45Px3+4Q/119CkopQzwF8B/CXwL/Avgf51S+tl/8Df7dHw6Ph2fjk/Hf7DjrytT+E+Bv0wp/SqlNAD/J+B/8df0Xp+OT8en49Px6fgPdNi/ptf9DPhm9vO3wN+fP0Ep9V8D/3X+/u/WTSOPywPypDGJURhjQEGMMX8FtNZorYFEjImUorxC/nutNSGE8fuHWZHWipTkPY01+TUSMUZISd5fjeeLKueF/Lo8oZxzSuneeyil0PlxOac0PWf22uPr53OX99LTz7MnKkBpRcrnqrT8PuXXRiGPjycup5nKxVSzV0vT54oxopSeruH4fnLtnXNorVkslyilcG6g7wdsZRmco2lbhr4npURVWZqmoe86YgjEENBKYY3BWos1Jl/7hFHlvqT82eVnle/1YrEAEs57IKGVRhvNMAx451FKj/c2xoi1FTEGQKG1pqosVVWP56vyPS8XIKZEXdc45zg6PibFiHMOn9dNipEUU773abpX2qCAGOWaK7Q8J9/vGCPGWLz3498AxOBZtC3WVhy6A0ZP6zqEgLGGECIATdOglCKEgPeemCLee9qmkc+XXzelCHmt+eAJIeR1kq/lgz2VxgU8fSZZH2m25uTupyS/UUqD0iQgxIgLUT671vI8GO/bfL3Oj/Joyichr52mda0UWqvxOpIgpjjul5S/n2+dmPdTZS0hBlJKo12Q588+M9P+HPdzKns+oJTGGjPuh5RkT+RTmfb27NIVeySfXfaKtZaY4gfvq1D5fqXx9zBdB6MN1hpSghACKUW01uPrgpr+Tsk9DDECCZOfBzFfA1ivlhijMEZjjM6nLLYtAT//+a8vUkpPHt6nvy6n8LFVcc8ip5T+G+C/AVgsF+mL3/sJJJX/VAFGnhclmVmv11hr8N5xOBzYHzZorWnbmkRgGAb6wRGiYnW0xnvPo0ePePv2LVVVsVwu6fseYLyJy+WSEAIhBNbrNYvFgrZt+e677xj2YtDKza6qirqu0VrjfcA5T4wJa2xezJYQwoMFIjcjpYQ2GghicLxs3GnxJqytMKbCGou1NZVdjI+BGCBiQqVI2zZ0fY/zjrquMcYQU8R5h9IK7z1Ki/FUKjuMfN210ViliQFSTGilMMZye3vDyfEZznuauiXGSN/3eB949PgR3333ipPjY/74j/4mLz/7jM1mwz/7Z/+Mdr3ker/lP/n7f593b16xub3l+GjNT3/nC/7iz/4U13VcvHnNsqp59ugRTx6dsawbzk+OWDQt/WHHsmk5HA4kFC8//5ybuzsG1+G942//7T8hpci7d2/Y7uSen5yccHFxweXFNcYYFoslYNhutywXa5xzKKWpqorz83N+/OMf89133/HNd99ydHTEbr8To6MU+67j+OSYtm356Ze/x+eff87NzQ2vXr2i6zo2t7ds7+5YL5Y471EoYoJD1/PkyTP63uNDYNEuqesWlKbrHUop6rrBec/h0Mt11uD7njD0/L2/9/f47rvv2B0OoyM/9B3r9ZqooG1bnPdYa0Epbm9vefX6FcfHxyzbBZVS2PJlDYtli60sN9dX3G1u8d6NfqCuWxbtEmMMwScG5xmGQdZw8CzWFSEOpCAGR2mD0RalDCEkhgjatkRTMUTYHBzvN1t2fYC6JSqDD7Lu66oat7tSaoQiUkqQDW3btnjv8G4Qp5kSRinquqJtW6qqYt8d6Pqevu9xzuW1K69ljMlGEpwbGIaBk5Mj9vs9zjkWiwWLRYsPjhACw+CorMXkAFKcjPw/dB1936PzNV+0LUolDocDRmtWqxUheHa7HbvdgZhgtVqxXC5xzrHdbum6DqUUq9WKs7MzrLVsNhv6vh8NugSvjEHhMPR470aHD3B0dMTR0REAff7s3nsWiwXn5+dUVUXXdXjvx8Bjv98TomOxqKkrjVEBowOrVcV/+nf/BotWcXa65HjdYk0iRpcDv8j/6H/8v/3Nx4z3X5dT+Bb40eznz4FX3/fklKOGEj1MjoHsSiRasla8cNs2+NCPBh3iGM2YHGmY7PGrqpLo1NrsfcU8eu+JOSoszmK9XvPy5Us2mw3bCEPfj3+TUhqNfgjF8M8iwDhlJPPPFUL4IMos/5eooXzm8liJfkuUoMjRrUqkeD/Cuncdc2iijUFpNTsXiXxAHEHUKX8GiEoR8YSUcMGjjcZUFgP4GBm8Z3ADVV2jjOHNu7e8+OwzfvTjH/P8V7/iu7evWa1W9EPPj3/8Y7Z3d+x3W66urogxopUeoxxjDNZYfL6OTdOwub3m0dkZSikOeRNJJuIIweUNx3iftNa8ffsWYwyr1ZIQ4pgN1nU93gNx3I6qqjgcDoQQaJqG5XJJ13eEGFE5e9nv9qyWS96+fs3L5y9omoa6rnNU3tLbvWRS+bVjkCAkBI8xCjcExAhCiIG+72ialpQSjx894m6zYbfbQYoslwt+8+YVNzc3aKPo+45EwlZitLx3LFYrHj0659Wr12y7A1pr+q7DDT1np2dsbm9ISuFiRCPGrKqtGFrv0DnQDjlrADDa0DQt2hiqZEixIoZEjJ6YAjGWDBlUApLGWAksrNYESlQsW1NrjbGKKKnBve0ac5QLSFaYM4+yZrVR6KTQIWf6KWFm6zXGKUJOKY1G06Du7S8JxrQ4TsRZlKCsZFbOOYzR1E1Nbey4V+Va+XHntO2Cyso1TCliraWuqmwjxHlJNl72ds5Qsn2p65qjoyNWqxVXV1ccDof83uLAqqrCe4/3Pv9sSWlau8VelWCxrmtSSnRdl/dCoKqqD+xG+fLeS7ZaKSot75kSeC+ZbvABlZEErfS96/jw+OtyCv8C+FIp9RPgO+B/Bfxvvv/pZUGW1H5K8ctyKje7qiQyr11N13WyWTUZKlAkDN57qqpiGIbxQoYQRhjEGDM6FOfcaDRijDx+/JhXr17hup6UDeloULNjSGlakNMN+hjElMbNMT/mz5lfg/Khy6+nFHfaUkqrjJA9fI3JqRhjUHrKEmIUgyXnI5FuyLCCUooQI1VdS8Rb19mpaBrAx0CIkeVqhbGWN+/ecugO/N7j3+PFZy/46
pvfsDo/5eb6mt/5/DOWbcPXv97x+s0bwjCQnB9TeRBIgwKHRIFDloslbdOit1t8Nt79UDEMHTc3N8QYuLu7w3nHyckJr1+/5vHjxywWCw6HTmAhJTBQiU2dcwxDz36/49Ad0EZzfHzMyckxh+5A13ekxJhlbTdbur7j3du31E0NKdEfOonY2wX7/R6tzXjf66piGHqaZiGbOwZiCqQIPjhssMRkOT8/x1aWGAPdfo9WsFg0vHn7hrqWACchm3oYBmxlIEZCcCzahhA9h8OBYeg5OloRvCN4J7BbCGJQjWEYBoahYxh6tNE0psY5JRl0343LqziGui6LLBCTJ6VIzBG0hgmSMpZKGUiahMFEhY2Kpm2IBry2BKXQ2amoBAQm2Ewu8mj0ZQ/5cT/rDH8arbNhlJW+Xq9p2nZ8zDmXvVUJHKegSGuV4TojjjAGnBOopO97Tk9PaeqaytoMD07wqjWGWFcZkks5K0m0bYvVmq7rGIYBBSzaBmOb0Y4UO9M0DYvFYoT7bm9vcc5lSFbexxg
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Let's take a look at the dataset\n",
"import mmcv\n",
"import matplotlib.pyplot as plt\n",
"\n",
"img = mmcv.imread('data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg')\n",
"plt.figure(figsize=(8, 6))\n",
"plt.imshow(mmcv.bgr2rgb(img))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Support a new dataset\n",
"\n",
"Datasets in MMClassification require image and ground-truth labels to be placed in folders with the same perfix. To support a new dataset, there're two ways to generate a customized dataset. \n",
"\n",
"The simplest way is to convert the dataset to existing dataset formats (ImageNet). The other way is to add new Dataset class. More details can be found [here](https://github.com/open-mmlab/mmclassification/blob/master/docs/tutorials/new_dataset.md).\n",
"\n",
"In this tutorials, we'll show the details about both of the methods."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Reorganize dataset to existing format\n",
"\n",
"This is the simplest way to support a new dataset. To do this, there're two steps:\n",
"\n",
"1. Reorganize the structure of customized dataset to the existing dataset formats.\n",
"2. Generate annotation files accordingly.\n",
"\n",
"Here we take \"Cats and Dogs Dataset\" as an example. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's reorganize the structure. Before converting it into the format of ImageNet, let's have a quit look about the structure of ImageNet. \n",
"\n",
"For training, it differentiates classes by folders, i.e. images with the same label should be in the same folder and all the folders of different classes should be in one folder:\n",
"\n",
"```\n",
"imagenet\n",
"├── ...\n",
"├── train\n",
"│ ├── n01440764\n",
"│ │ ├── n01440764_10026.JPEG\n",
"│ │ ├── n01440764_10027.JPEG\n",
"│ │ ├── ...\n",
"│ ├── ...\n",
"│ ├── n15075141\n",
"│ │ ├── n15075141_999.JPEG\n",
"│ │ ├── n15075141_9993.JPEG\n",
"│ │ ├── ...\n",
"```\n",
"\n",
"\n",
"Luckily, our training dataset has similar structure and we don't have to do anything on it.\n",
"\n",
"Note: The `ImageNet` dataset class in MMCls will scan the directory of the training set and map each folders, i.e. the class name, to its label automatically. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For validating, we need to extract validation dataset from our training dataset. Here's how we split the dataset."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"import os\n",
"import os.path as osp\n",
"\n",
"\n",
"data_root = './data/cats_dogs_dataset/'\n",
"train_dir = osp.join(data_root, 'training_set/training_set/')\n",
"val_dir = osp.join(data_root, 'val_set/val_set/')\n",
"\n",
"# Split train/val set\n",
"mmcv.mkdir_or_exist(val_dir)\n",
"class_dirs = [\n",
" d for d in os.listdir(train_dir) if osp.isdir(osp.join(train_dir, d))\n",
"]\n",
"for cls_dir in class_dirs:\n",
" train_imgs = [filename for filename in mmcv.scandir(osp.join(train_dir, cls_dir), suffix='.jpg')]\n",
" # Select first 4/5 as train set and the last 1/5 as val set\n",
" train_length = int(len(train_imgs)*4/5)\n",
" val_imgs = train_imgs[train_length:]\n",
" # Move the val set into a new dir\n",
" src_dir = osp.join(train_dir, cls_dir)\n",
" tar_dir = osp.join(val_dir, cls_dir)\n",
" mmcv.mkdir_or_exist(tar_dir)\n",
" for val_img in val_imgs:\n",
" shutil.move(osp.join(src_dir, val_img), osp.join(tar_dir, val_img))"
]
},
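{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check (a minimal sketch reusing the `train_dir` and `val_dir` variables defined above), we can count the images per class in the resulting train and val splits to confirm the 4:1 ratio:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: count images per class in the train/val splits\n",
"for split_dir in [train_dir, val_dir]:\n",
"    counts = {\n",
"        d: len(list(mmcv.scandir(osp.join(split_dir, d), suffix='.jpg')))\n",
"        for d in os.listdir(split_dir) if osp.isdir(osp.join(split_dir, d))\n",
"    }\n",
"    print(split_dir, counts)"
]
},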
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For test set, there's nothing to change. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Second, we need to generate the annotations for validation and test dataset. The classes of the dataset are also needed."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"import os\n",
"import os.path as osp\n",
"\n",
"from itertools import chain\n",
"\n",
"\n",
"# Generate mapping from class_name to label\n",
"def find_folders(root_dir):\n",
" folders = [\n",
" d for d in os.listdir(root_dir) if osp.isdir(osp.join(root_dir, d))\n",
" ]\n",
" folders.sort()\n",
" folder_to_idx = {folders[i]: i for i in range(len(folders))}\n",
" return folder_to_idx\n",
"\n",
"\n",
"# Generate annotations\n",
"def gen_annotations(root_dir):\n",
" annotations = dict()\n",
" folder_to_idx = find_folders(root_dir)\n",
" \n",
" for cls_dir, label in folder_to_idx.items():\n",
" cls_to_label = [\n",
" '{} {}'.format(osp.join(cls_dir, filename), label) \n",
" for filename in mmcv.scandir(osp.join(root_dir, cls_dir), suffix='.jpg')\n",
" ]\n",
" annotations[cls_dir] = cls_to_label\n",
" return annotations\n",
"\n",
"\n",
"data_root = './data/cats_dogs_dataset/'\n",
"val_dir = osp.join(data_root, 'val_set/val_set/')\n",
"test_dir = osp.join(data_root, 'test_set/test_set/')\n",
" \n",
"# Save val annotations\n",
"with open(osp.join(data_root, 'val.txt'), 'w') as f:\n",
" annotations = gen_annotations(val_dir)\n",
" contents = chain(*annotations.values())\n",
" f.writelines('\\n'.join(contents))\n",
" \n",
"# Save test annotations\n",
"with open(osp.join(data_root, 'test.txt'), 'w') as f:\n",
" annotations = gen_annotations(test_dir)\n",
" contents = chain(*annotations.values())\n",
" f.writelines('\\n'.join(contents))\n",
"\n",
"# Generate classes\n",
"folder_to_idx = find_folders(train_dir)\n",
"classes = list(folder_to_idx.keys())\n",
"with open(osp.join(data_root, 'classes.txt'), 'w') as f:\n",
" f.writelines('\\n'.join(classes))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each line of the annotation list contains a filename and its corresponding ground-truth label. The format is as follows:\n",
"\n",
"```\n",
"...\n",
"cats/cat.3769.jpg 0\n",
"cats/cat.882.jpg 0\n",
"...\n",
"dogs/dog.3881.jpg 1\n",
"dogs/dog.3377.jpg 1\n",
"...\n",
"```"
]
},
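{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick optional check (this sketch only reuses `data_root` and `osp` defined above), we can print the first few lines of the generated `val.txt` to make sure it follows this format:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: peek at the generated validation annotation file\n",
"with open(osp.join(data_root, 'val.txt')) as f:\n",
"    for line in f.readlines()[:5]:\n",
"        print(line.strip())"
]
},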
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Implement a customized dataset\n",
"\n",
"NOTE: If you choose to use the first method, please SKIP the following codes.\n",
"\n",
"The second method to support a new dataset is to write a new Dataset class `CatsDogsDataset`. In this method, we don't have to change the structure of the dataset. The following steps are needed:\n",
"\n",
"1. Generate class list. Each line is the class name. E.g.\n",
" ```\n",
" cats\n",
" dogs\n",
" ```\n",
"2. Generate train/validation/test annotations. Each line contains a filename and its corresponding ground-truth label.\n",
" ```\n",
" ...\n",
" cats/cat.436.jpg 0\n",
" cats/cat.383.jpg 0\n",
" ...\n",
" dogs/dog.1340.jpg 1\n",
" dogs/dog.1660.jpg 1\n",
" ...\n",
" ```\n",
"3. Implement `CatsDogsDataset` inherited from `BaseDataset`, and overwrite `load_annotations(self)`,\n",
"like [CIFAR10](https://github.com/open-mmlab/mmclassification/blob/master/mmcls/datasets/cifar.py) and [ImageNet](https://github.com/open-mmlab/mmclassification/blob/master/mmcls/datasets/imagenet.py).\n",
"\n",
"First, let's generate class list and annotation files."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Generate annotations\n",
"import os\n",
"import mmcv\n",
"import os.path as osp\n",
"\n",
"from itertools import chain\n",
"\n",
"\n",
"# Generate mapping from class_name to label\n",
"def find_folders(root_dir):\n",
" folders = [\n",
" d for d in os.listdir(root_dir) if osp.isdir(osp.join(root_dir, d))\n",
" ]\n",
" folders.sort()\n",
" folder_to_idx = {folders[i]: i for i in range(len(folders))}\n",
" return folder_to_idx\n",
"\n",
"\n",
"# Generate annotations\n",
"def gen_annotations(root_dir):\n",
" annotations = dict()\n",
" folder_to_idx = find_folders(root_dir)\n",
" \n",
" for cls_dir, label in folder_to_idx.items():\n",
" cls_to_label = [\n",
" '{} {}'.format(osp.join(cls_dir, filename), label) \n",
" for filename in mmcv.scandir(osp.join(root_dir, cls_dir), suffix='.jpg')\n",
" ]\n",
" annotations[cls_dir] = cls_to_label\n",
" return annotations\n",
"\n",
"\n",
"data_root = './data/cats_dogs_dataset/'\n",
"train_dir = osp.join(data_root, 'training_set/training_set/')\n",
"test_dir = osp.join(data_root, 'test_set/test_set/')\n",
"\n",
"# Generate class list\n",
"folder_to_idx = find_folders(train_dir)\n",
"classes = list(folder_to_idx.keys())\n",
"with open(osp.join(data_root, 'classes.txt'), 'w') as f:\n",
" f.writelines('\\n'.join(classes))\n",
" \n",
"# Generate train/val set randomly\n",
"annotations = gen_annotations(train_dir)\n",
"# Select first 4/5 as train set\n",
"train_length = lambda x: int(len(x)*4/5)\n",
"train_annotations = map(lambda x:x[:train_length(x)], annotations.values())\n",
"val_annotations = map(lambda x:x[train_length(x):], annotations.values())\n",
"# Save train/val annotations\n",
"with open(osp.join(data_root, 'train.txt'), 'w') as f:\n",
" contents = chain(*train_annotations)\n",
" f.writelines('\\n'.join(contents))\n",
"with open(osp.join(data_root, 'val.txt'), 'w') as f:\n",
" contents = chain(*val_annotations)\n",
" f.writelines('\\n'.join(contents))\n",
" \n",
"# Save test annotations\n",
"test_annotations = gen_annotations(test_dir)\n",
"with open(osp.join(data_root, 'test.txt'), 'w') as f:\n",
" contents = chain(*test_annotations.values())\n",
" f.writelines('\\n'.join(contents))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we need to implement `load_annotations` function in the new dataset class `CatsDogsDataset`.\n",
"\n",
"Typically, this function returns a list, where each sample is a dict, containing necessary data informations, e.g., `img_path` and `gt_label`. These will be used by `mmcv.runner` during training to load samples. "
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "WnGZfribFHCx"
},
"outputs": [],
"source": [
"import mmcv\n",
"import numpy as np\n",
"\n",
"from mmcls.datasets import DATASETS, BaseDataset\n",
"\n",
"\n",
"# Regist model so that we can access the class through str in configs\n",
"@DATASETS.register_module()\n",
"class CatsDogsDataset(BaseDataset):\n",
"\n",
" def load_annotations(self):\n",
" assert isinstance(self.ann_file, str)\n",
"\n",
" data_infos = []\n",
" with open(self.ann_file) as f:\n",
" # The ann_file is the annotation files we generate above.\n",
" samples = [x.strip().split(' ') for x in f.readlines()]\n",
" for filename, gt_label in samples:\n",
" info = {'img_prefix': self.data_prefix}\n",
" info['img_info'] = {'filename': filename}\n",
" info['gt_label'] = np.array(gt_label, dtype=np.int64)\n",
" data_infos.append(info)\n",
" return data_infos"
]
},
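{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional check (a minimal sketch assuming the `BaseDataset` constructor signature `(data_prefix, pipeline, classes=None, ann_file=None)` in this MMCls version and the annotation files generated above), we can instantiate `CatsDogsDataset` directly and look at one loaded sample:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: instantiate the new dataset directly to verify `load_annotations`\n",
"# (constructor arguments follow the assumed `BaseDataset` signature)\n",
"dataset = CatsDogsDataset(\n",
"    data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
"    pipeline=[],\n",
"    ann_file='data/cats_dogs_dataset/train.txt',\n",
"    classes='data/cats_dogs_dataset/classes.txt')\n",
"print(len(dataset), dataset.data_infos[0])"
]
},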
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "yUVtmn3Iq3WA"
},
"source": [
"### Modify configs\n",
"\n",
"In this part, we need to modify the config for finetune. \n",
"\n",
"In MMCls, the configs usually look like this:\n",
"\n",
"```\n",
"# 'configs/resnet/resnet50_b32x8_imagenet.py'\n",
"_base_ = [\n",
" # Model config\n",
" '../_base_/models/resnet50.py',\n",
" # Dataset config\n",
" '../_base_/datasets/imagenet_bs32.py',\n",
" # Schedule config\n",
" '../_base_/schedules/imagenet_bs256.py',\n",
" # Runtime config\n",
" '../_base_/default_runtime.py'\n",
"]\n",
"```\n",
"\n",
"A standard configuration in MMCls contains four parts:\n",
"\n",
"1. Model config, which specify the basic structure of the model, e.g. number of the input channels.\n",
"2. Dataset config, which contains details about the dataset, e.g. type of the dataset.\n",
"3. Schedule config, which specify the training schedules, e.g. learning rate.\n",
"4. Runtime config, which contains the rest of details, e.g. log config.\n",
"\n",
"In this part, we'll show how to modify config in python files. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's load the existing config file."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Wwnj9tRzqX_A"
},
"outputs": [],
"source": [
"# Load the existing config file\n",
"from mmcv import Config\n",
"cfg = Config.fromfile('configs/resnet/resnet50_b32x8_imagenet.py')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "1y2oV5w97jQo"
},
"source": [
"Then, we'll modify it according to the method we support our new dataset. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Modify config after reorganization of dataset\n",
"If you reorganize the dataset and convert it into ImageNet format, there's little things to change."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First is the model configs. The classification head would be reconstructed if `num_classes` is different from the pretrained one."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Modify num classes of the model in classification head\n",
"cfg.model.head.num_classes = 2\n",
"cfg.model.head.topk = (1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Second is the dataset configs. As we reorganize the dataset into the format of ImageNet, we can use most of the configurations of it. Note that the training annotations don't need to specify, because it can find the mappings through the structure of the dataset."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Modify the number of workers according to your computer\n",
"cfg.data.samples_per_gpu = 32\n",
"cfg.data.workers_per_gpu=2\n",
"# Modify the image normalization configs \n",
"cfg.img_norm_cfg = dict(\n",
" mean=[124.508, 116.050, 106.438], std=[58.577, 57.310, 57.437], to_rgb=True)\n",
"# Specify the path to training set\n",
"cfg.data.train.data_prefix = 'data/cats_dogs_dataset/training_set/training_set'\n",
"cfg.data.train.classes = 'data/cats_dogs_dataset/classes.txt'\n",
"# Specify the path to validation set\n",
"cfg.data.val.data_prefix = 'data/cats_dogs_dataset/val_set/val_set'\n",
"cfg.data.val.ann_file = 'data/cats_dogs_dataset/val.txt'\n",
"cfg.data.val.classes = 'data/cats_dogs_dataset/classes.txt'\n",
"# Specify the path to test set\n",
"cfg.data.test.data_prefix = 'data/cats_dogs_dataset/test_set/test_set'\n",
"cfg.data.test.ann_file = 'data/cats_dogs_dataset/test.txt'\n",
"cfg.data.test.classes = 'data/cats_dogs_dataset/classes.txt'\n",
"# Modify the metric method\n",
"cfg.evaluation['metric_options']={'topk': (1)}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Third is the schedules of finetuning."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Optimizer\n",
"cfg.optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)\n",
"cfg.optimizer_config = dict(grad_clip=None)\n",
"# Learning policy\n",
"cfg.lr_config = dict(policy='step', step=[1])\n",
"cfg.runner = dict(type='EpochBasedRunner', max_epochs=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, let's modify the configuration during run time."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# Load the pretrained weights\n",
"cfg.load_from = 'checkpoints/resnet50_batch256_imagenet_20200708-cfb998bf.pth'\n",
"# Set up working dir to save files and logs.\n",
"cfg.work_dir = './work_dirs/cats_dogs_dataset'\n",
"\n",
"from mmcls.apis import set_random_seed\n",
"# Set seed thus the results are more reproducible\n",
"cfg.seed = 0\n",
"set_random_seed(0, deterministic=False)\n",
"cfg.gpu_ids = range(1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Have a look on the new configuration! "
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Config:\n",
"model = dict(\n",
" type='ImageClassifier',\n",
" backbone=dict(\n",
" type='ResNet',\n",
" depth=50,\n",
" num_stages=4,\n",
" out_indices=(3, ),\n",
" style='pytorch'),\n",
" neck=dict(type='GlobalAveragePooling'),\n",
" head=dict(\n",
" type='LinearClsHead',\n",
" num_classes=2,\n",
" in_channels=2048,\n",
" loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n",
" topk=1))\n",
"dataset_type = 'ImageNet'\n",
"img_norm_cfg = dict(\n",
" mean=[124.508, 116.05, 106.438], std=[58.577, 57.31, 57.437], to_rgb=True)\n",
"train_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
"]\n",
"test_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1)),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
"]\n",
"data = dict(\n",
" samples_per_gpu=32,\n",
" workers_per_gpu=2,\n",
" train=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
" ],\n",
" classes='data/cats_dogs_dataset/classes.txt'),\n",
" val=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/val_set/val_set',\n",
" ann_file='data/cats_dogs_dataset/val.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1)),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
" ],\n",
" classes='data/cats_dogs_dataset/classes.txt'),\n",
" test=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/test_set/test_set',\n",
" ann_file='data/cats_dogs_dataset/test.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1)),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
" ],\n",
" classes='data/cats_dogs_dataset/classes.txt'))\n",
"evaluation = dict(interval=1, metric='accuracy', metric_options=dict(topk=1))\n",
"optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)\n",
"optimizer_config = dict(grad_clip=None)\n",
"lr_config = dict(policy='step', step=[1])\n",
"runner = dict(type='EpochBasedRunner', max_epochs=2)\n",
"checkpoint_config = dict(interval=1)\n",
"log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])\n",
"dist_params = dict(backend='nccl')\n",
"log_level = 'INFO'\n",
"load_from = 'checkpoints/resnet50_batch256_imagenet_20200708-cfb998bf.pth'\n",
"resume_from = None\n",
"workflow = [('train', 1)]\n",
"work_dir = './work_dirs/cats_dogs_dataset'\n",
"seed = 0\n",
"gpu_ids = range(0, 1)\n",
"\n"
]
}
],
"source": [
"# Let's have a look at the final config used for finetuning\n",
"print(f'Config:\\n{cfg.pretty_text}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Modify config after implementing a customized dataset\n",
"\n",
"NOTE: If you choose to use the first method, please SKIP the following codes.\n",
"\n",
"As we implement a new dataset, there're something different from above:\n",
"1. The new dataset type should be specified.\n",
"2. The training annotations should be specified."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's have a look about the dataset configurations."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Specify the new dataset class\n",
"cfg.dataset_type = 'CatsDogsDataset'\n",
"cfg.data.train.type = cfg.dataset_type\n",
"cfg.data.val.type = cfg.dataset_type\n",
"cfg.data.test.type = cfg.dataset_type\n",
"\n",
"# Specify the training annotations\n",
"cfg.data.train.ann_file = 'data/cats_dogs_dataset/train.txt'\n",
"\n",
"# The followings are the same as above\n",
"cfg.data.samples_per_gpu = 32\n",
"cfg.data.workers_per_gpu=2\n",
"\n",
"cfg.img_norm_cfg = dict(\n",
" mean=[124.508, 116.050, 106.438], std=[58.577, 57.310, 57.437], to_rgb=True)\n",
"\n",
"cfg.data.train.data_prefix = 'data/cats_dogs_dataset/training_set/training_set'\n",
"cfg.data.train.classes = 'data/cats_dogs_dataset/classes.txt'\n",
"\n",
"cfg.data.val.data_prefix = 'data/cats_dogs_dataset/training_set/training_set'\n",
"cfg.data.val.ann_file = 'data/cats_dogs_dataset/val.txt'\n",
"cfg.data.val.classes = 'data/cats_dogs_dataset/classes.txt'\n",
"\n",
"cfg.data.test.data_prefix = 'data/cats_dogs_dataset/test_set/test_set'\n",
"cfg.data.test.ann_file = 'data/cats_dogs_dataset/test.txt'\n",
"cfg.data.test.classes = 'data/cats_dogs_dataset/classes.txt'\n",
"# Modify the metric method\n",
"cfg.evaluation['metric_options']={'topk': (1)}\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Other configurations are the same as those mentioned above. And we just list them here."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"colab_type": "code",
"id": "eyKnYC1Z7iCV",
"outputId": "a25241e2-431c-4944-b0b8-b9c792d5aadd",
"scrolled": true,
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Config:\n",
"model = dict(\n",
" type='ImageClassifier',\n",
" backbone=dict(\n",
" type='ResNet',\n",
" depth=50,\n",
" num_stages=4,\n",
" out_indices=(3, ),\n",
" style='pytorch'),\n",
" neck=dict(type='GlobalAveragePooling'),\n",
" head=dict(\n",
" type='LinearClsHead',\n",
" num_classes=2,\n",
" in_channels=2048,\n",
" loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n",
" topk=1))\n",
"dataset_type = 'CatsDogsDataset'\n",
"img_norm_cfg = dict(\n",
" mean=[124.508, 116.05, 106.438], std=[58.577, 57.31, 57.437], to_rgb=True)\n",
"train_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
"]\n",
"test_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1)),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
"]\n",
"data = dict(\n",
" samples_per_gpu=32,\n",
" workers_per_gpu=2,\n",
" train=dict(\n",
" type='CatsDogsDataset',\n",
" data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
" ],\n",
" ann_file='data/cats_dogs_dataset/train.txt',\n",
" classes='data/cats_dogs_dataset/classes.txt'),\n",
" val=dict(\n",
" type='CatsDogsDataset',\n",
" data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
" ann_file='data/cats_dogs_dataset/val.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1)),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
" ],\n",
" classes='data/cats_dogs_dataset/classes.txt'),\n",
" test=dict(\n",
" type='CatsDogsDataset',\n",
" data_prefix='data/cats_dogs_dataset/test_set/test_set',\n",
" ann_file='data/cats_dogs_dataset/test.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1)),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
" ],\n",
" classes='data/cats_dogs_dataset/classes.txt'))\n",
"evaluation = dict(interval=1, metric='accuracy', metric_options=dict(topk=1))\n",
"optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)\n",
"optimizer_config = dict(grad_clip=None)\n",
"lr_config = dict(policy='step', step=[1])\n",
"runner = dict(type='EpochBasedRunner', max_epochs=2)\n",
"checkpoint_config = dict(interval=1)\n",
"log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])\n",
"dist_params = dict(backend='nccl')\n",
"log_level = 'INFO'\n",
"load_from = 'checkpoints/resnet50_batch256_imagenet_20200708-cfb998bf.pth'\n",
"resume_from = None\n",
"workflow = [('train', 1)]\n",
"work_dir = './work_dirs/cats_dogs_dataset'\n",
"seed = 0\n",
"gpu_ids = range(0, 1)\n",
"\n"
]
}
],
"source": [
"# MODOL CONFIG\n",
"# Modify num classes of the model in classification head\n",
"cfg.model.head.num_classes = 2\n",
"cfg.model.head.topk = (1)\n",
"\n",
"# SCHEDULE CONFIG\n",
"# Optimizer\n",
"cfg.optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)\n",
"cfg.optimizer_config = dict(grad_clip=None)\n",
"# Learning policy\n",
"cfg.lr_config = dict(policy='step', step=[1])\n",
"cfg.runner = dict(type='EpochBasedRunner', max_epochs=2)\n",
"\n",
"# RUNTIME CONFIG\n",
"# Load the pretrained weights\n",
"cfg.load_from = 'checkpoints/resnet50_batch256_imagenet_20200708-cfb998bf.pth'\n",
"# Set up working dir to save files and logs.\n",
"cfg.work_dir = './work_dirs/cats_dogs_dataset'\n",
"from mmcls.apis import set_random_seed\n",
"# Set seed thus the results are more reproducible\n",
"cfg.seed = 0\n",
"set_random_seed(0, deterministic=False)\n",
"cfg.gpu_ids = range(1)\n",
"\n",
"# Let's have a look at the final config used for training\n",
"print(f'Config:\\n{cfg.pretty_text}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "QWuH14LYF2gQ"
},
"source": [
"### Finetune\n",
"\n",
"Now we finetune on our own dataset. As we've modified the training schedules, we can use the `train_model` API to finetune our model. "
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 953,
"referenced_widgets": [
"40a3c0b2c7a44085b69b9c741df20b3e",
"ec96fb4251ea4b8ea268a2bc62b9c75b",
"dae4b284c5a944639991d29f4e79fac5",
"c78567afd0a6418781118ac9f4ecdea9",
"32b7d27a143c41b5bb90f1d8e66a1c67",
"55d75951f51c4ab89e32045c3d6db8a4",
"9d29e2d02731416d9852e9c7c08d1665",
"1bb2b93526cd421aa5d5b86d678932ab"
]
},
"colab_type": "code",
"id": "jYKoSfdMF12B",
"outputId": "1c0b5a11-434b-4c96-a4aa-9d685fff0856"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-03-11 17:17:38,573 - mmcls - INFO - load checkpoint from checkpoints/resnet50_batch256_imagenet_20200708-cfb998bf.pth\n",
"2021-03-11 17:17:38,574 - mmcls - INFO - Use load_from_local loader\n",
"2021-03-11 17:17:38,625 - mmcls - WARNING - The model and loaded state dict do not match exactly\n",
"\n",
"size mismatch for head.fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, the shape in current model is torch.Size([2, 2048]).\n",
"size mismatch for head.fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([2]).\n",
"2021-03-11 17:17:38,626 - mmcls - INFO - Start running, host: SENSETIME\\shaoyidi@CN0014004140L, work_dir: /home/SENSETIME/shaoyidi/VirtualenvProjects/add_tutorials/MMCls_Tutorials/mmclassification/work_dirs/cats_dogs_dataset\n",
"2021-03-11 17:17:38,626 - mmcls - INFO - workflow: [('train', 1)], max: 2 epochs\n",
"/home/SENSETIME/shaoyidi/anaconda3/lib/python3.8/site-packages/mmcv/runner/hooks/logger/text.py:55: DeprecationWarning: an integer is required (got type float). Implicit conversion to integers using __int__ is deprecated, and may be removed in a future version of Python.\n",
" mem_mb = torch.tensor([mem / (1024 * 1024)],\n",
"2021-03-11 17:18:22,741 - mmcls - INFO - Epoch [1][100/201]\tlr: 1.000e-02, eta: 0:02:12, time: 0.439, data_time: 0.023, memory: 2961, loss: 0.6723, top-1: 68.5312\n",
"2021-03-11 17:19:04,455 - mmcls - INFO - Epoch [1][200/201]\tlr: 1.000e-02, eta: 0:01:26, time: 0.417, data_time: 0.004, memory: 2961, loss: 0.5848, top-1: 66.3125\n",
"2021-03-11 17:19:04,521 - mmcls - INFO - Saving checkpoint at 1 epochs\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[>>>>>>>>>>>>>>>>>>>>>>>>>>>] 1601/1601, 211.0 task/s, elapsed: 8s, ETA: 0s"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-03-11 17:19:12,501 - mmcls - INFO - Epoch(val) [1][201]\taccuracy: 64.3973\n",
"2021-03-11 17:19:56,609 - mmcls - INFO - Epoch [2][100/201]\tlr: 1.000e-03, eta: 0:00:43, time: 0.439, data_time: 0.023, memory: 2961, loss: 0.4877, top-1: 74.7188\n",
"2021-03-11 17:20:38,827 - mmcls - INFO - Epoch [2][200/201]\tlr: 1.000e-03, eta: 0:00:00, time: 0.422, data_time: 0.004, memory: 2961, loss: 0.4244, top-1: 78.0625\n",
"2021-03-11 17:20:38,893 - mmcls - INFO - Saving checkpoint at 2 epochs\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[>>>>>>>>>>>>>>>>>>>>>>>>>>>] 1601/1601, 213.9 task/s, elapsed: 7s, ETA: 0s"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-03-11 17:20:46,778 - mmcls - INFO - Epoch(val) [2][201]\taccuracy: 88.0075\n"
]
}
],
"source": [
"import time\n",
"\n",
"from mmcls.datasets import build_dataset\n",
"from mmcls.models import build_classifier\n",
"from mmcls.apis import train_model\n",
"\n",
"# Create work_dir\n",
"mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))\n",
"# Build the classifier\n",
"model = build_classifier(cfg.model)\n",
"# Build the dataset\n",
"datasets = [build_dataset(cfg.data.train)]\n",
"# Add an attribute for visualization convenience\n",
"model.CLASSES = datasets[0].CLASSES\n",
"# Begin finetuning\n",
"train_model(\n",
" model,\n",
" datasets,\n",
" cfg,\n",
" distributed=False,\n",
" validate=True,\n",
" timestamp=time.strftime('%Y%m%d_%H%M%S', time.localtime()),\n",
" meta=dict())"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "DEkWOP-NMbc_"
},
"source": [
"Let's checkout our trained model."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 645
},
"colab_type": "code",
"id": "ekG__UfaH_OU",
"outputId": "ac1eb835-19ed-48e6-8f77-e6d325b915c4"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 576x432 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAm4AAAJCCAYAAAB5xkteAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9ebCua3YXhv2e4R2+YU/nnHvOnXu4re5WS2qBAImkDTZmMCkgJHYSJCeh7IRgXIRQRSjKCVQGx+VyUkmFUOU/bMCOoQI2lRjbgAMyxDE2FEFYCElNt9S3W7fveMY9fcM7PUP+WGs9z/N+e5+rVkgnl+K8t/bdZ3/DOzzDWr/1W5OKMeLF8eJ4cbw4XhwvjhfHi+PF8ck/9P+/b+DF8eJ4cbw4XhwvjhfHi+PF8Z0dL4Dbi+PF8eJ4cbw4XhwvjhfHPyTHC+D24nhxvDheHC+OF8eL48XxD8nxAri9OF4cL44Xx4vjxfHieHH8Q3K8AG4vjhfHi+PF8eJ4cbw4Xhz/kBwvgNuL48Xx4nhxvDheHC+OF8c/JMd3DbgppX6zUurnlFJvK6X+pe/WdV4cL44Xx4vjxfHieHG8OP5ROdR3o46bUsoA+HkAvxHA+wB+AsCPxRj//v/XL/bieHG8OF4cL44Xx4vjxfGPyPHdYtx+GMDbMcZvxRhHAP8ugN/+XbrWi+PF8eJ4cbw4XhwvjhfHPxKH/S6d9zUA7xV/vw/gR573YWNtrKpq9ppK/1DzD8f5J7TRMMYgxogYAkIICDEixgBEQGkFpRS01lAAIp/AhwDECKX0/LTFtSOAGAK/oPhW+CxRbk8BMabvKwXEEBFjhLEWxmi6p0D3FCOglQKUonN/B4eCSvdw6/s8RgqKni+CxqNgU5VS6QeK3keMxfgW3+N/zwbj4H7kPHx3gFJ8n0pe4X/fPIni1+f3gnRdpYtzgN8vZuh5HHGahVj8++aH5mOmFPLcyrjlD6Z5Prh2Gtv0K0JrjaqqEUOA847Pq/K98fjLGpUpNdpAawWlNK+fgBgivHeIMYKGSUHzuBljoLRCZSukaYiA955ft4gxIngPpTSMNWk9xBgRQoQsKXmOxaJF3TTwziMEj0Mm3vNrIdAalrWrtKZrB5/GIo1DjHDOISIiBhofYw2sMbC24n3h4ZyH957Hp5iqYmwBwBoLpRWc9wCAtmmgtIZ3Dj4EdF0HAGiaBgpAkHs82LNpOwEw1qKyNUII6RkRkdZuvGUD0NgGRERYawEoeD8fM61lQdG/66oGiwr44DEOI82d1umcWmtoY0iOiTyLEVVlYYyFMQZaazjvEHzANE2zcTfWQCkNawyUAoKn5zd8DZF5MYa0FmOMfO8h3YPl9SMyIj1VEgsRSDtb8Xiq2fyFNB7FOlKy91Xad7NxhaxHei/yBokxIsQIHyK8rD+I/MjnSdtWpcvR70MdUhxZdgEyQUkGpmeOaTHSNeLs2fN5aIxjiLPPze7h4FZKuSx7Vr6b5oL3T0TM60SLXglpnEl+qLSGA8+pzKHoHdGFMz0XSd4AtH+V0jBGJzlMe9+ndamUgtIaRmtobXiO6NlDCOkZyrEXfVqqFqVJ5ikl+lnGJX9Jzi07ma6t0oflOdN88JzJs9M1FO0fpRgfRLjJsZ6m1+XCop/KsVGzm0LCGmketYbRhWCJgbED3WTbNqiriuU8+LeCNZrHKKZ1Vkrer3/9F57GGF/CwfHdAm637ZSZJlBK/W4AvxsAbGXx6c99ung3LxgCVjqdktZoXnbr9RrHx8eYphHTNGK/36Pve0zTAOcnNE0Day2WyxbWWvjgEILHbr+H9wFV1QBQ8JEWlOZJFgW73+9ZEVUJAHpPikZrAo0hBHjvYYyBMQbjOGIcR9y5cwfr9Rpd12EcRwzDAO89rLXQSmPqBwQ+D48JP2NMf8uCo3+b/D7LErkHpWgDyb1570lp8rmqqoIxBnVdwVq6Rx88CXUFVtakDOR5ynuQe5Q5qaqaXlcGWhsYY6G1hdEVCxZ6nTakoTllhGG1hlGaFZKGcy7dc4wRdVNDG5OWzTRNaflEMIiQ8eKlkEFFCVACIu9vzcAn+IgYgMoYGG2SMpSxd84jhOJcQfDtfG4AIAS5N/BaPMLrr7+BYRhwfn6ex1EpwGgErRCMggsek/eI3iE4h6P1Cou2xfH6CE3TYHt1gXHosbveYBpHaARoAMuqQWUNTo/XaJsGD+7eQ1VZ1EYjOI/tdoOmqnHv7hn85LDbbFA3DdZHR/AhYJwcJucwjCO0VtBGwbkJ3jt83/d9Hz796U/j/PwZttsNnJ8QQl6bwzDAOYftdotpmrDbdUAElsslYgS22y4pkrZd4OUHr8A5h4uLC16LHovFAqenpzg5OcH9+/ex3++x2WxwcXmJi8sLWGthrSXBisi/kdbxnbt3UDcNLq6uEGPEF774BSwWC1xeXWG73eLrX/86Qgh4/fXXoZRC39E9GrBClrkolOXJySnu3XuAvu+x3+/hXIDznte0QYgkewQkBE/Ad7PZwHuPszv3YKzF5cUlpmlCjApaKyzaBkCEHwYs2gZvvvE6AyqF6+trvPvuuzDWYrFapnVS1TWatoFzDtM0oR8HDCxHjk9OcHZ2htVqhYvLC+y7Dh999BF2u10an+V6hbquce/OXVhjMG73UABWywUQI7quQwgekxthjEZd1xjHAdfXV5imEX3fYbVa4fT0NN1D3k/gH9obCpr3tMioGlppGKsABOz3WzY8HAj+0v6zlgBoVTWsIA1iBBwbt5MPgLKArRGVRlAG/ejQjQ7XXYdN12N0wOQBXTfQtgKMAZTC5Gj9GY2ZzBL5iaxFZqDOasPPQGBomsb0nCF4BO/S2kaMQIjFc2vaSyxDuqHHxPIsFMo9yXg9N4istVBKoaoqeO+x223TelgsWty9ewfOuaQ7nHNYLBZYrVbo+w790BGA9x513aKyVZKn49jDefp3DBE1j33btqQ3eBC895jchKurSyilsF6vUVcVjlbrLBenCfv9Fs45jOOIqqpQ1zXatsVyucQ4OozThGEYMI5jenYZexmPvu/JwOT5WSwWsNaibdv0uXLuQgjp+UUXkh6r0xxN04hh6NN7pf6T8arrGkdHR7zmRzjn8PTpU4zjiNPTU9R1na7bNM0MA4iMN0knAeM0Yr/bJqTTtjVWqwUUArSKcFOHaeyhtYdSAd/7hbfwxuuvoG0N6kpjsbCoK4OzkzWauoIPA0Lw8F72HOnfX/2V3/lt3HJ8t4Db+wDeKP5+HcCH5QdijP8mgH8TANrFYm6W/RIOGVjnXGF9ZJQoAEQASYierXxa0NPkASiYukoLRxbUxAvRWoumaQBgBoZkU8rfxpj0uXLxyeaURTUMA0Y3JtYiiAVTALUMyDJrQ3IjZlAFXTBrASGoG0xJssCK8fK+ZLl+qSNOoxtCgFYaUfO/dWb4ElPCbFFkQSmWoXceAZ4NXJ3mxjmy+JTR0DwmxNrwYtZq/vwM3JTOTzBnPTQxm4qsTaM0ogYQVTINZKzHcWRQTcBWNqutiFGRexPFs
16vGZiQsGvbFnfu3MXnPvcWxnHC5eUlnj17hm9/+9tYLJc4Oj3B1W6Lq80VXnpwH3dfegkfvv8enjx6BKMVgvd45cHLuHv3LtzQYRoHVFUFDaDbbRGCh40KKlZiBmIcR8QYULUNtFZYLpYwhsYzhoCqIlbr+voaxlrUTZuAizEa2ihMk8Y00brv+x7b7QZXV1fwwSHGgLquYYzBcrmEUgp7Nni6rkMMEU3TwBiDxaJNr3sGeKKAlNJoWxL0sq7FkDhcq4qZaB88tDEwSqFnAexlf/PaODo6wtHREe3XtsX01ltp3mIIBAK9hx9dMr5ijHCFpbzf7/Ho0UeF4ePgnAfg2SKuoI2FNhWU1tCKAPtyuWTFRDx3NuLAoKhB8A6b/TmCd+i7DovFAicnJ/DeY7lcom4anJydJpnQ9T2ePXuGxWKBxWIBaJqr/X6PrutQVRUbkPQs6/Uaxpik1K8vrwAAi6pBUzdAiNBKwfGYaA1obVDVC2iteW41pmlkg5TWeN/3ae4EkNA+ycxtCEAMKAD
"text/plain": [
"<Figure size 1080x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"img = mmcv.imread('data/cats_dogs_dataset/training_set/training_set/cats/cat.1.jpg')\n",
"\n",
"model.cfg = cfg\n",
"result = inference_model(model, img)\n",
"plt.figure(figsize=(8, 6))\n",
"show_result_pyplot(model, img, result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Command line tool usages\n",
"\n",
"MMCls also provide some useful command line tools, which can help us:\n",
"\n",
"1. Train models\n",
"2. Finetune models\n",
"3. Test models\n",
"4. Do inference \n",
"\n",
"As the process of training is similar to finetuning, we'll show the details about how to finetune, test, and do inference in this tutorials. More details can be found [here](https://github.com/open-mmlab/mmclassification/blob/master/docs/getting_started.md).\n",
"\n",
"### Finetune\n",
"\n",
"To finetune via command line, several steps are needed:\n",
"1. Prepare customized dataset.\n",
"2. Support new dataset in MMCls.\n",
"3. Modify configs and write into files.\n",
"4. Finetune using command line.\n",
"\n",
"The first and second step are similar to those mentioned above. In this part, we'll show the details of the last two steps.\n",
"\n",
"#### Modify configs in files\n",
"\n",
"To reuse the common parts among different configs, we support inheriting configs from multiple existing configs. To finetune a ResNet50 model, the new config needs to inherit `configs/_base_/models/resnet50.py` to build the basic structure of the model. To use the \"Cats and Dogs Dataset\", the new config can also simply inherit `configs/_base_/datasets/cats_dogs_dataset.py`. To customize the training schedules, the new config should inherit `configs/_base_/schedules/cats_dogs_finetune.py`. For runtime settings such as training schedules, the new config needs to inherit `configs/_base_/default_runtime.py`.\n",
"\n",
"The final config file should look like this:\n",
"\n",
"```\n",
"# Save to \"configs/resnet/resnet50_cats_dogs.py\"\n",
"_base_ = [\n",
" '../_base_/models/resnet50.py',\n",
" '../_base_/datasets/imagenet_bs32.py',\n",
" '../_base_/schedules/imagenet_bs256.py',\n",
" '../_base_/default_runtime.py'\n",
"]\n",
"```\n",
"\n",
"Besides, you can also choose to write the whole contents into one config file rather than use inheritance, e.g. `configs/mnist/lenet5.py`.\n",
"\n",
"Here, we take the settings of reorganizion as an example. You can try by yourself on the case of implementing a customized dataset. All you have to do is to write new configs which will overwrite the original ones. Now, let's check the details."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's modify the model configs and save them into `configs/_base_/models/resnet50_cats_dogs.py`. The new config needs to modify the head according to the class numbers of the new datasets. By only changing `num_classes` in the head, the weights of the pre-trained models are mostly reused except the final prediction head.\n",
"\n",
"```python\n",
"_base_ = ['./resnet50.py']\n",
"model = dict(\n",
" head=dict(\n",
" num_classes=2,\n",
" topk = (1)\n",
" ))\n",
"```"
]
},
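  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To double-check the effect of this change, here is a minimal sketch (assuming the file has been saved to `configs/_base_/models/resnet50_cats_dogs.py` as above) that builds the classifier from the new config and prints the shape of the final fully connected layer. It should now be `[2, 2048]` instead of `[1000, 2048]`, which is also why the \"size mismatch\" warnings appear when the ImageNet checkpoint is loaded.\n",
    "\n",
    "```python\n",
    "from mmcv import Config\n",
    "from mmcls.models import build_classifier\n",
    "\n",
    "# Load the new model config; the `_base_` inheritance is resolved automatically.\n",
    "cfg = Config.fromfile('configs/_base_/models/resnet50_cats_dogs.py')\n",
    "model = build_classifier(cfg.model)\n",
    "# The prediction head is now a 2-way classifier on top of 2048-d features.\n",
    "print(model.head.fc.weight.shape)  # expected: torch.Size([2, 2048])\n",
    "```"
   ]
  },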
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Second is the dataset's configs. Don't forget to save them into `configs/_base_/datasets/cats_dogs_dataset.py`.\n",
"\n",
"```python\n",
"_base_ = ['./imagenet_bs32.py']\n",
"img_norm_cfg = dict(\n",
" mean=[124.508, 116.050, 106.438],\n",
" std=[58.577, 57.310, 57.437],\n",
" to_rgb=True)\n",
"\n",
"data = dict(\n",
" # Modify the number of workers according to your computer\n",
" samples_per_gpu = 32,\n",
" workers_per_gpu=2,\n",
" # Specify the path to training set\n",
" train = dict(\n",
" data_prefix = 'data/cats_dogs_dataset/training_set/training_set',\n",
" classes = 'data/cats_dogs_dataset/classes.txt'\n",
" ),\n",
" # Specify the path to validation set\n",
" val = dict(\n",
" data_prefix = 'data/cats_dogs_dataset/val_set/val_set',\n",
" ann_file = 'data/cats_dogs_dataset/val.txt',\n",
" classes = 'data/cats_dogs_dataset/classes.txt'\n",
" ),\n",
" # Specify the path to test set\n",
" test = dict(\n",
" data_prefix = 'data/cats_dogs_dataset/test_set/test_set',\n",
" ann_file = 'data/cats_dogs_dataset/test.txt',\n",
" classes = 'data/cats_dogs_dataset/classes.txt'\n",
" )\n",
")\n",
"# Modify the metric method\n",
"evaluation = dict(metric_options={'topk': (1)})\n",
"```"
]
},
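  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before launching a full run, it can be handy to verify that the dataset config resolves correctly. Here is a minimal sketch, assuming the file has been saved to `configs/_base_/datasets/cats_dogs_dataset.py` and the dataset has been prepared under `data/cats_dogs_dataset` as in the previous sections:\n",
    "\n",
    "```python\n",
    "from mmcv import Config\n",
    "from mmcls.datasets import build_dataset\n",
    "\n",
    "# Load the dataset config; its `_base_` (imagenet_bs32.py) is merged automatically.\n",
    "cfg = Config.fromfile('configs/_base_/datasets/cats_dogs_dataset.py')\n",
    "# Build the training set the same way tools/train.py does.\n",
    "train_set = build_dataset(cfg.data.train)\n",
    "print(len(train_set), train_set.CLASSES)\n",
    "```"
   ]
  },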
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Third is the training schedule. The finetuning hyperparameters vary from the default schedule. It usually requires smaller learning rate and less training epochs. Let's save it into `configs/_base_/schedules/cats_dogs_finetune.py`.\n",
"\n",
"```python\n",
"# optimizer\n",
"# lr is set for a batch size of 128\n",
"optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)\n",
"optimizer_config = dict(grad_clip=None)\n",
"# learning policy\n",
"lr_config = dict(\n",
" policy='step',\n",
" step=[1])\n",
"runner = dict(type='EpochBasedRunner', max_epochs=2)\n",
"```"
]
},
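  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `policy='step'` and `step=[1]`, the base learning rate is multiplied by the decay factor `gamma` (0.1 by default in MMCV) after the first epoch, so the two epochs run at 1e-2 and 1e-3 respectively, matching the training logs shown earlier. A purely illustrative sketch of this schedule:\n",
    "\n",
    "```python\n",
    "# Illustrative only: how the step policy scales the base lr per epoch.\n",
    "base_lr, gamma, steps = 0.01, 0.1, [1]\n",
    "for epoch in range(2):\n",
    "    lr = base_lr * gamma ** sum(epoch >= s for s in steps)\n",
    "    print(f'epoch {epoch + 1}: lr = {lr:g}')\n",
    "# epoch 1: lr = 0.01\n",
    "# epoch 2: lr = 0.001\n",
    "```"
   ]
  },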
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, for the run time configs, we can simple use the defualt one and change nothing. We can now gather all the configs into one file and save it into `configs/resnet/resnet50_cats_dogs.py`.\n",
"```python\n",
"_base_ = [\n",
" '../_base_/models/resnet50_cats_dogs.py', '../_base_/datasets/cats_dogs_dataset.py',\n",
" '../_base_/schedules/cats_dogs_finetune.py', '../_base_/default_runtime.py'\n",
"]\n",
"\n",
"# Don't forget to load pretrained model. Set it as the abosolute path. \n",
"load_from = 'XXX/mmclassification/checkpoints/resnet50_batch256_imagenet_20200708-cfb998bf.pth'\n",
"```\n"
]
},
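  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training, you can verify that the inheritance resolves as expected by loading the assembled config and printing it, just like we inspected `cfg.pretty_text` earlier in this tutorial. A minimal sketch, assuming the file has been saved to `configs/resnet/resnet50_cats_dogs.py`:\n",
    "\n",
    "```python\n",
    "from mmcv import Config\n",
    "\n",
    "cfg = Config.fromfile('configs/resnet/resnet50_cats_dogs.py')\n",
    "# The printed config should combine the model, dataset, schedule and runtime settings above.\n",
    "print(cfg.pretty_text)\n",
    "```"
   ]
  },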
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Finetune using command line\n",
"\n",
"We use `tools/train.py` to finetune the model:\n",
"\n",
"```\n",
"python tools/train.py ${CONFIG_FILE} [optional arguments]\n",
"```\n",
"\n",
"If you want to specify the working directory in the command, you can add an argument `--work_dir ${YOUR_WORK_DIR}`.\n",
"\n",
"Here, we take our `ResNet50` on `CatsDogsDataset` for example."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-03-11 17:02:24,433 - mmcls - INFO - Environment info:\n",
"------------------------------------------------------------\n",
"sys.platform: linux\n",
"Python: 3.8.5 (default, Sep 4 2020, 07:30:14) [GCC 7.3.0]\n",
"CUDA available: True\n",
"GPU 0: GeForce GTX 1060 6GB\n",
"CUDA_HOME: /usr\n",
"NVCC: Cuda compilation tools, release 10.2, V10.2.89\n",
"GCC: gcc (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609\n",
"PyTorch: 1.5.0+cu101\n",
"PyTorch compiling details: PyTorch built with:\n",
" - GCC 7.3\n",
" - C++ Version: 201402\n",
" - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications\n",
" - Intel(R) MKL-DNN v0.21.1 (Git Hash 7d2fd500bc78936d1d648ca713b901012f470dbc)\n",
" - OpenMP 201511 (a.k.a. OpenMP 4.5)\n",
" - NNPACK is enabled\n",
" - CPU capability usage: AVX2\n",
" - CUDA Runtime 10.1\n",
" - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37\n",
" - CuDNN 7.6.3\n",
" - Magma 2.5.2\n",
" - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_INTERNAL_THREADPOOL_IMPL -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF, \n",
"\n",
"TorchVision: 0.6.0+cu101\n",
"OpenCV: 4.5.1\n",
"MMCV: 1.2.7\n",
"MMCV Compiler: GCC 7.3\n",
"MMCV CUDA Compiler: 10.1\n",
"MMClassification: 0.9.0+f3b9380\n",
"------------------------------------------------------------\n",
"\n",
"2021-03-11 17:02:24,433 - mmcls - INFO - Distributed training: False\n",
"2021-03-11 17:02:24,563 - mmcls - INFO - Config:\n",
"model = dict(\n",
" type='ImageClassifier',\n",
" backbone=dict(\n",
" type='ResNet',\n",
" depth=50,\n",
" num_stages=4,\n",
" out_indices=(3, ),\n",
" style='pytorch'),\n",
" neck=dict(type='GlobalAveragePooling'),\n",
" head=dict(\n",
" type='LinearClsHead',\n",
" num_classes=2,\n",
" in_channels=2048,\n",
" loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n",
" topk=1))\n",
"dataset_type = 'ImageNet'\n",
"img_norm_cfg = dict(\n",
" mean=[124.508, 116.05, 106.438], std=[58.577, 57.31, 57.437], to_rgb=True)\n",
"train_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
"]\n",
"test_pipeline = [\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1)),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
"]\n",
"data = dict(\n",
" samples_per_gpu=32,\n",
" workers_per_gpu=2,\n",
" train=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/training_set/training_set',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', size=224),\n",
" dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='ToTensor', keys=['gt_label']),\n",
" dict(type='Collect', keys=['img', 'gt_label'])\n",
" ],\n",
" classes='data/cats_dogs_dataset/classes.txt'),\n",
" val=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/val_set/val_set',\n",
" ann_file='data/cats_dogs_dataset/val.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1)),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
" ],\n",
" classes='data/cats_dogs_dataset/classes.txt'),\n",
" test=dict(\n",
" type='ImageNet',\n",
" data_prefix='data/cats_dogs_dataset/test_set/test_set',\n",
" ann_file='data/cats_dogs_dataset/test.txt',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='Resize', size=(256, -1)),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(\n",
" type='Normalize',\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True),\n",
" dict(type='ImageToTensor', keys=['img']),\n",
" dict(type='Collect', keys=['img'])\n",
" ],\n",
" classes='data/cats_dogs_dataset/classes.txt'))\n",
"evaluation = dict(interval=1, metric='accuracy', metric_options=dict(topk=1))\n",
"optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)\n",
"optimizer_config = dict(grad_clip=None)\n",
"lr_config = dict(policy='step', step=[1])\n",
"runner = dict(type='EpochBasedRunner', max_epochs=2)\n",
"checkpoint_config = dict(interval=1)\n",
"log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])\n",
"dist_params = dict(backend='nccl')\n",
"log_level = 'INFO'\n",
"load_from = '/home/SENSETIME/shaoyidi/VirtualenvProjects/add_tutorials/MMCls_Tutorials/mmclassification/checkpoints/resnet50_batch256_imagenet_20200708-cfb998bf.pth'\n",
"resume_from = None\n",
"workflow = [('train', 1)]\n",
"work_dir = 'work_dirs/resnet50_cats_dogs'\n",
"gpu_ids = range(0, 1)\n",
"\n",
"2021-03-11 17:02:26,361 - mmcls - INFO - load checkpoint from /home/SENSETIME/shaoyidi/VirtualenvProjects/add_tutorials/MMCls_Tutorials/mmclassification/checkpoints/resnet50_batch256_imagenet_20200708-cfb998bf.pth\n",
"2021-03-11 17:02:26,362 - mmcls - INFO - Use load_from_local loader\n",
"2021-03-11 17:02:26,422 - mmcls - WARNING - The model and loaded state dict do not match exactly\n",
"\n",
"size mismatch for head.fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, the shape in current model is torch.Size([2, 2048]).\n",
"size mismatch for head.fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([2]).\n",
"2021-03-11 17:02:26,424 - mmcls - INFO - Start running, host: SENSETIME\\shaoyidi@CN0014004140L, work_dir: /home/SENSETIME/shaoyidi/VirtualenvProjects/add_tutorials/MMCls_Tutorials/mmclassification/work_dirs/resnet50_cats_dogs\n",
"2021-03-11 17:02:26,424 - mmcls - INFO - workflow: [('train', 1)], max: 2 epochs\n",
"2021-03-11 17:03:10,368 - mmcls - INFO - Epoch [1][100/201]\tlr: 1.000e-02, eta: 0:02:12, time: 0.437, data_time: 0.023, memory: 2962, loss: 0.5598, top-1: 69.3125\n",
"2021-03-11 17:03:52,698 - mmcls - INFO - Epoch [1][200/201]\tlr: 1.000e-02, eta: 0:01:26, time: 0.423, data_time: 0.004, memory: 2962, loss: 0.3681, top-1: 78.6875\n",
"2021-03-11 17:03:52,765 - mmcls - INFO - Saving checkpoint at 1 epochs\n",
"[>>>>>>>>>>>>>>>>>>>>>>>>>>>] 1601/1601, 240.7 task/s, elapsed: 7s, ETA: 0s2021-03-11 17:03:59,601 - mmcls - INFO - Epoch(val) [1][201]\taccuracy: 92.6921\n",
"2021-03-11 17:04:43,478 - mmcls - INFO - Epoch [2][100/201]\tlr: 1.000e-03, eta: 0:00:43, time: 0.437, data_time: 0.023, memory: 2962, loss: 0.2715, top-1: 85.2500\n",
"2021-03-11 17:05:25,385 - mmcls - INFO - Epoch [2][200/201]\tlr: 1.000e-03, eta: 0:00:00, time: 0.419, data_time: 0.004, memory: 2962, loss: 0.2335, top-1: 87.6875\n",
"2021-03-11 17:05:25,449 - mmcls - INFO - Saving checkpoint at 2 epochs\n",
"[>>>>>>>>>>>>>>>>>>>>>>>>>>>] 1601/1601, 239.4 task/s, elapsed: 7s, ETA: 0s2021-03-11 17:05:32,313 - mmcls - INFO - Epoch(val) [2][201]\taccuracy: 95.3154\n"
]
}
],
"source": [
"!python tools/train.py configs/resnet/resnet50_cats_dogs.py --work-dir work_dirs/resnet50_cats_dogs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test models\n",
"\n",
"We use `tools/test.py` to test models:\n",
"\n",
"```\n",
"python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments]\n",
"```\n",
"\n",
"We show several optional arguments we'll use here:\n",
"\n",
"- `--metrics`: Evaluation metrics, which depends on the dataset, e.g., accuracy.\n",
"- `--metric-options`: Custom options for evaluation, e.g. topk=1.\n",
"\n",
"Please refer to `tools.test.py` for details about optional arguments.\n",
"\n",
"Here's the example of our `ResNet50`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Use load_from_local loader\n",
"[>>>>>>>>>>>>>>>>>>>>>>>>>>>] 2023/2023, 238.7 task/s, elapsed: 8s, ETA: 0s\n",
"accuracy : 94.91\n"
]
}
],
"source": [
"!python tools/test.py configs/resnet/resnet50_cats_dogs.py work_dirs/resnet50_cats_dogs/latest.pth --metrics=accuracy --metric-options=topk=1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Do inference\n",
"\n",
"We can use the following commands to infer a dataset and save the results.\n",
"\n",
"```shell\n",
"python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}]\n",
"```\n",
"\n",
"Optional arguments:\n",
"\n",
"- `RESULT_FILE`: Filename of the output results. If not specified, the results will not be saved to a file.\n",
"\n",
"Here's the example of our `ResNet50`."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Use load_from_local loader\n",
"[>>>>>>>>>>>>>>>>>>>>>>>>>>>] 2023/2023, 240.7 task/s, elapsed: 8s, ETA: 0stools/test.py:138: UserWarning: Evaluation metrics are not specified.\n",
" warnings.warn('Evaluation metrics are not specified.')\n",
"\n",
"writing results to results.json\n"
]
}
],
"source": [
"!python tools/test.py configs/resnet/resnet50_cats_dogs.py work_dirs/resnet50_cats_dogs/latest.pth --out=results.json"
]
},
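  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dumped file can be loaded back for further analysis. The exact structure of `results.json` depends on the MMCls version, so the snippet below is only a minimal sketch assuming the dump contains one prediction entry per test image:\n",
    "\n",
    "```python\n",
    "import mmcv\n",
    "\n",
    "# mmcv.load infers the format (JSON here) from the file suffix.\n",
    "results = mmcv.load('results.json')\n",
    "print(type(results))\n",
    "# If the dump is a per-image list, its length should match the test set size (2023 images here).\n",
    "if isinstance(results, list):\n",
    "    print(len(results), results[0])\n",
    "```"
   ]
  },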
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"include_colab_link": true,
"name": "MMSegmentation Tutorial.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
},
"toc-autonumbering": true,
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"1bb2b93526cd421aa5d5b86d678932ab": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"32b7d27a143c41b5bb90f1d8e66a1c67": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"40a3c0b2c7a44085b69b9c741df20b3e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_dae4b284c5a944639991d29f4e79fac5",
"IPY_MODEL_c78567afd0a6418781118ac9f4ecdea9"
],
"layout": "IPY_MODEL_ec96fb4251ea4b8ea268a2bc62b9c75b"
}
},
"55d75951f51c4ab89e32045c3d6db8a4": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9d29e2d02731416d9852e9c7c08d1665": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"c78567afd0a6418781118ac9f4ecdea9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1bb2b93526cd421aa5d5b86d678932ab",
"placeholder": "",
"style": "IPY_MODEL_9d29e2d02731416d9852e9c7c08d1665",
"value": " 97.8M/97.8M [00:10&lt;00:00, 9.75MB/s]"
}
},
"dae4b284c5a944639991d29f4e79fac5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "100%",
"description_tooltip": null,
"layout": "IPY_MODEL_55d75951f51c4ab89e32045c3d6db8a4",
"max": 102567401,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_32b7d27a143c41b5bb90f1d8e66a1c67",
"value": 102567401
}
},
"ec96fb4251ea4b8ea268a2bc62b9c75b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
}
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}