{ "cells": [ { "cell_type": "markdown", "id": "c6c6e585-6d14-4b6a-8621-6395518b463e", "metadata": { "tags": [] }, "source": [ "# EasyCV Image Classification - Swin Transformer\n", "\n", "This tutorial shows how to use EasyCV to quickly train an image classification model with [Swin Transformer](https://arxiv.org/abs/2103.14030) and run inference with it." ] }, { "cell_type": "markdown", "id": "0512498b-a900-4969-9563-67029a1009e2", "metadata": { "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ "## Environment Requirements\n", "\n", "A PAI-PyTorch image or a native PyTorch 1.5+ environment, on a GPU machine with at least 32 GB of memory." ] }, { "cell_type": "markdown", "id": "0cafa45c-b04b-4471-b8f0-9c88ce50cc0e", "metadata": { "tags": [] }, "source": [ "## Install Dependencies\n", "\n", "Note: The PAI-DSW docker image already includes these dependencies, so you can skip this section there. Run the steps below only in a local notebook environment.\n" ] }, { "cell_type": "markdown", "id": "c8280983-39b8-414d-a712-c40fd68dadfe", "metadata": {}, "source": [ "1. First, install PyTorch and the matching version of torchvision. PyTorch 1.5.1 and later are supported." ] }, { "cell_type": "code", "execution_count": null, "id": "9487def6-9441-435c-9da6-633cd09c517f", "metadata": { "tags": [] }, "outputs": [], "source": [ "# install pytorch and torchvision\n", "! conda install --yes pytorch==1.10.0 torchvision==0.11.0 -c pytorch" ] }, { "cell_type": "markdown", "id": "57b6b6a2-0106-43df-90a5-7b315e91cd64", "metadata": { "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ "2. Get the torch and CUDA versions, then install the matching versions of mmcv and nvidia-dali." ] }, { "cell_type": "code", "execution_count": null, "id": "f93a5657-fdf1-467b-847d-e48fedb362ef", "metadata": { "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "outputs": [], "source": [ "import torch\n", "import os\n", "# expose the CUDA and torch versions to the shell commands below, e.g. cu102 and torch1.10.0\n", "os.environ['CUDA']='cu' + torch.version.cuda.replace('.', '')\n", "# drop any local build suffix such as '+PAI' or '+cu113' from the torch version\n", "os.environ['Torch']='torch'+torch.version.__version__.split('+')[0]\n", "!echo \"cuda version: $CUDA\"\n", "!echo \"pytorch version: $Torch\"" ] }, { "cell_type": "code", "execution_count": null, "id": "9c0d7f7d-bbdc-45ac-b738-e31e00d5eef4", "metadata": { "tags": [] }, "outputs": [], "source": [ "# install some python deps\n", "! 
pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/${CUDA}/${Torch}/index.html\n", "! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl" ] }, { "cell_type": "markdown", "id": "2a987c04-e123-42b7-9d42-240e835aa410", "metadata": { "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ "3. Install the EasyCV package" ] }, { "cell_type": "code", "execution_count": null, "id": "ee7590e1-c312-444e-af84-472459174d15", "metadata": { "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "outputs": [], "source": [ "pip install pai-easycv" ] }, { "cell_type": "markdown", "id": "df0fc82c-1f2c-4869-92e6-76d7dfb34164", "metadata": { "tags": [] }, "source": [ "4. Quick verification" ] }, { "cell_type": "code", "execution_count": null, "id": "8b89751f-f0c0-4aad-aced-8f6c77a4a2aa", "metadata": {}, "outputs": [], "source": [ "from easycv.apis import *" ] }, { "cell_type": "markdown", "id": "8df2d210-4bcf-4b41-a981-b264c41e91a7", "metadata": { "tags": [] }, "source": [ "## Cifar10 Classification" ] }, { "cell_type": "markdown", "id": "fce3f492-cfe0-4121-8f7d-877b31e3438b", "metadata": {}, "source": [ "The following example shows how to use the [cifar10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset with a Swin Transformer (Swin-Tiny) model to quickly train and evaluate an image classification model and run prediction." ] }, { "cell_type": "markdown", "id": "6e769ad6-3ae0-4d5b-bf23-55673a239aa8", "metadata": {}, "source": [ "### Data Preparation\n", "Download the cifar10 data and extract it to the `data/cifar` directory. The directory structure is as follows:\n", "\n", "```text\n", "data/cifar\n", "└── cifar-10-batches-py\n", "    ├── batches.meta\n", "    ├── data_batch_1\n", "    ├── data_batch_2\n", "    ├── data_batch_3\n", "    ├── data_batch_4\n", "    ├── data_batch_5\n", "    ├── readme.html\n", "    ├── read.py\n", "    └── test_batch\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "62eeff3d-274d-4ff9-8123-8e11d97c0d38", "metadata": {}, "outputs": [], "source": [ "! 
mkdir -p data/cifar && wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/cifar10/cifar-10-python.tar.gz && tar -zxf cifar-10-python.tar.gz -C data/cifar/" ] }, { "cell_type": "markdown", "id": "ed28f2be-1519-4d81-9d03-42891d1e9bc7", "metadata": {}, "source": [ "### Train the Model\n", "Download the training config file" ] }, { "cell_type": "code", "execution_count": null, "id": "bd8b6d19-652c-4a58-89d3-38394dbb9a20", "metadata": {}, "outputs": [], "source": [ "! rm -rf swintiny_b64_5e_jpg.py\n", "!wget https://raw.githubusercontent.com/alibaba/EasyCV/master/configs/classification/cifar10/swintiny_b64_5e_jpg.py" ] }, { "cell_type": "markdown", "id": "3d2ed375-f5ef-475e-9dc8-3386d328b8b0", "metadata": {}, "source": [ "Train on a single GPU and evaluate on the validation set" ] }, { "cell_type": "code", "execution_count": null, "id": "9b6ec6da-d125-4ebe-b301-0d8cc8881512", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "! python -m easycv.tools.train swintiny_b64_5e_jpg.py --work_dir work_dirs/classification/cifar10/swin_tiny" ] }, { "cell_type": "markdown", "id": "82b04e66-ef62-4d2e-963f-34015a588786", "metadata": {}, "source": [ "### Export the Model\n", "\n", "After training completes, use the export command to export the model for inference. The exported model includes the preprocessing and postprocessing information needed at inference time." ] }, { "cell_type": "code", "execution_count": null, "id": "1d4f7868-88aa-428c-9852-22e3eca45a07", "metadata": {}, "outputs": [], "source": [ "# list the .pth files produced by training\n", "! ls work_dirs/classification/cifar10/swin_tiny*" ] }, { "cell_type": "markdown", "id": "a5dc516f-fe35-4b1f-acbe-9ec96b72cc95", "metadata": {}, "source": [ "ClsEvaluator_neck_top1_best.pth is the checkpoint with the highest accuracy produced during training. Export this model." ] }, { "cell_type": "code", "execution_count": null, "id": "67aa8262-834a-4afc-a6cd-fd882a181371", "metadata": {}, "outputs": [], "source": [ "! 
python -m easycv.tools.export swintiny_b64_5e_jpg.py work_dirs/classification/cifar10/swin_tiny/ClsEvaluator_neck_top1_best.pth work_dirs/classification/cifar10/swin_tiny/best_export.pth" ] }, { "cell_type": "markdown", "id": "72d78ef4-0d90-4f3a-a1d9-c532e17d065b", "metadata": {}, "source": [ "### Prediction\n", "Download a test image" ] }, { "cell_type": "code", "execution_count": null, "id": "732fa5b4-fe0f-4a9b-ba19-d73d3ebd1736", "metadata": {}, "outputs": [], "source": [ "! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/cifar10/qince_data/predict/aeroplane_s_000004.png" ] }, { "cell_type": "code", "execution_count": null, "id": "cb3b24c2-7a19-4618-8bdf-42ebe1cdfeb3", "metadata": {}, "outputs": [], "source": [ "import cv2\n", "from easycv.predictors.classifier import TorchClassifier\n", "\n", "# load the exported model\n", "output_ckpt = 'work_dirs/classification/cifar10/swin_tiny/best_export.pth'\n", "tcls = TorchClassifier(output_ckpt, topk=1)\n", "\n", "img = cv2.imread('aeroplane_s_000004.png')\n", "# cv2 loads images in BGR order; the input image should be RGB order\n", "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", "output = tcls.predict([img])\n", "print(output)" ] }, { "cell_type": "code", "execution_count": null, "id": "5e617239-b034-4760-8c55-326cf3a6e475", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.12" } }, "nbformat": 4, "nbformat_minor": 5 }