mirror of https://github.com/alibaba/EasyCV.git
349 lines
8.7 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c6c6e585-6d14-4b6a-8621-6395518b463e",
   "metadata": {
    "tags": []
   },
   "source": [
    "# EasyCV Image Classification: Swin Transformer\n",
    "\n",
    "This tutorial shows how to use EasyCV to quickly train an image classification model with [Swin Transformer](https://arxiv.org/abs/2103.14030) and run inference with it."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0512498b-a900-4969-9563-67029a1009e2",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "## Requirements\n",
    "\n",
    "A GPU machine with more than 32 GB of memory, running either the PAI-PyTorch image or a native PyTorch 1.5+ environment."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cafa45c-b04b-4471-b8f0-9c88ce50cc0e",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Install dependencies\n",
    "\n",
    "Note: the PAI-DSW Docker image already includes these dependencies, so you can skip this section there. Run the steps below when working in a local notebook environment.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8280983-39b8-414d-a712-c40fd68dadfe",
   "metadata": {},
   "source": [
    "1. First, install PyTorch and the matching version of torchvision. PyTorch 1.5.1 and later are supported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9487def6-9441-435c-9da6-633cd09c517f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# install PyTorch and torchvision\n",
    "! conda install --yes pytorch==1.10.0 torchvision==0.11.0 -c pytorch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57b6b6a2-0106-43df-90a5-7b315e91cd64",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "2. Get the torch and CUDA versions, then install the matching versions of mmcv and nvidia-dali."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f93a5657-fdf1-467b-847d-e48fedb362ef",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "# e.g. '11.3' -> 'cu113'\n",
    "os.environ['CUDA']='cu' + torch.version.cuda.replace('.', '')\n",
    "# drop any local build suffix such as '+PAI' or '+cu113'\n",
    "os.environ['Torch']='torch'+torch.__version__.split('+')[0]\n",
    "!echo \"cuda version: $CUDA\"\n",
    "!echo \"pytorch version: $Torch\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c0d7f7d-bbdc-45ac-b738-e31e00d5eef4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# install some python deps\n",
    "! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/${CUDA}/${Torch}/index.html\n",
    "! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a987c04-e123-42b7-9d42-240e835aa410",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "3. Install the EasyCV package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee7590e1-c312-444e-af84-472459174d15",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "! pip install pai-easycv"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df0fc82c-1f2c-4869-92e6-76d7dfb34164",
   "metadata": {
    "tags": []
   },
   "source": [
    "4. Quick check that the installation works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b89751f-f0c0-4aad-aced-8f6c77a4a2aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from easycv.apis import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8df2d210-4bcf-4b41-a981-b264c41e91a7",
   "metadata": {
    "tags": []
   },
   "source": [
    "## CIFAR-10 classification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fce3f492-cfe0-4121-8f7d-877b31e3438b",
   "metadata": {},
   "source": [
    "The example below uses the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and a Swin Transformer (Swin-Tiny) model to walk through training, evaluation, and prediction for image classification."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e769ad6-3ae0-4d5b-bf23-55673a239aa8",
   "metadata": {},
   "source": [
    "### Data preparation\n",
    "Download the CIFAR-10 data and extract it into the `data/cifar` directory. The resulting directory structure is:\n",
    "\n",
    "```text\n",
    "data/cifar\n",
    "└── cifar-10-batches-py\n",
    "    ├── batches.meta\n",
    "    ├── data_batch_1\n",
    "    ├── data_batch_2\n",
    "    ├── data_batch_3\n",
    "    ├── data_batch_4\n",
    "    ├── data_batch_5\n",
    "    ├── readme.html\n",
    "    ├── read.py\n",
    "    └── test_batch\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62eeff3d-274d-4ff9-8123-8e11d97c0d38",
   "metadata": {},
   "outputs": [],
   "source": [
    "! mkdir -p data/cifar && wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/cifar10/cifar-10-python.tar.gz && tar -zxf cifar-10-python.tar.gz -C data/cifar/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed28f2be-1519-4d81-9d03-42891d1e9bc7",
   "metadata": {},
   "source": [
    "### Train the model\n",
    "Download the training config file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd8b6d19-652c-4a58-89d3-38394dbb9a20",
   "metadata": {},
   "outputs": [],
   "source": [
    "! rm -rf swintiny_b64_5e_jpg.py\n",
    "! wget https://raw.githubusercontent.com/alibaba/EasyCV/master/configs/classification/cifar10/swintiny_b64_5e_jpg.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d2ed375-f5ef-475e-9dc8-3386d328b8b0",
   "metadata": {},
   "source": [
    "Train on a single GPU and evaluate on the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b6ec6da-d125-4ebe-b301-0d8cc8881512",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "! python -m easycv.tools.train swintiny_b64_5e_jpg.py --work_dir work_dirs/classification/cifar10/swin_tiny"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82b04e66-ef62-4d2e-963f-34015a588786",
   "metadata": {},
   "source": [
    "### Export the model\n",
    "\n",
    "After training finishes, use the export command to export the model for inference. The exported model includes the preprocessing and postprocessing information needed at inference time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d4f7868-88aa-428c-9852-22e3eca45a07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the checkpoint files produced by training\n",
    "! ls work_dirs/classification/cifar10/swin_tiny*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5dc516f-fe35-4b1f-acbe-9ec96b72cc95",
   "metadata": {},
   "source": [
    "`ClsEvaluator_neck_top1_best.pth` is the checkpoint with the highest accuracy seen during training. Export that checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67aa8262-834a-4afc-a6cd-fd882a181371",
   "metadata": {},
   "outputs": [],
   "source": [
    "! python -m easycv.tools.export swintiny_b64_5e_jpg.py work_dirs/classification/cifar10/swin_tiny/ClsEvaluator_neck_top1_best.pth work_dirs/classification/cifar10/swin_tiny/best_export.pth"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72d78ef4-0d90-4f3a-a1d9-c532e17d065b",
   "metadata": {},
   "source": [
    "### Prediction\n",
    "Download a test image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "732fa5b4-fe0f-4a9b-ba19-d73d3ebd1736",
   "metadata": {},
   "outputs": [],
   "source": [
    "! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/cifar10/qince_data/predict/aeroplane_s_000004.png"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb3b24c2-7a19-4618-8bdf-42ebe1cdfeb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "from easycv.predictors.classifier import TorchClassifier\n",
    "\n",
    "output_ckpt = 'work_dirs/classification/cifar10/swin_tiny/best_export.pth'\n",
    "tcls = TorchClassifier(output_ckpt, topk=1)\n",
    "\n",
    "img = cv2.imread('aeroplane_s_000004.png')\n",
    "# input image should be RGB order\n",
    "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "output = tcls.predict([img])\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e617239-b034-4760-8c55-326cf3a6e475",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}