{
"cells": [
{
"cell_type": "markdown",
"id": "c6c6e585-6d14-4b6a-8621-6395518b463e",
"metadata": {
"tags": []
},
"source": [
"# EasyCV Image Classification: Swin Transformer\n",
"\n",
"This tutorial shows how to use EasyCV to quickly train a [Swin Transformer](https://arxiv.org/abs/2103.14030) image classification model and run inference with it."
]
},
{
"cell_type": "markdown",
"id": "0512498b-a900-4969-9563-67029a1009e2",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
"## Environment Requirements\n",
"\n",
"A PAI-PyTorch image or a native PyTorch 1.5+ environment, running on a GPU machine with at least 32 GB of memory."
]
},
{
"cell_type": "markdown",
"id": "0cafa45c-b04b-4471-b8f0-9c88ce50cc0e",
"metadata": {
"tags": []
},
"source": [
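"## Installing Dependencies\n",
"\n",
"Note: the PAI-DSW docker image already includes these dependencies, so you can skip this section there. Run it only in a local notebook environment.\n"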
]
},
{
"cell_type": "markdown",
"id": "c8280983-39b8-414d-a712-c40fd68dadfe",
"metadata": {},
"source": [
"1. First, install PyTorch and the matching version of torchvision (PyTorch 1.5.1 or later is supported)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9487def6-9441-435c-9da6-633cd09c517f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# install pytorch and torchvision\n",
"! conda install --yes pytorch==1.10.0 torchvision==0.11.0 -c pytorch"
]
},
{
"cell_type": "markdown",
"id": "57b6b6a2-0106-43df-90a5-7b315e91cd64",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
"2. Get the installed torch and CUDA versions, then install the matching versions of mmcv and nvidia-dali."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f93a5657-fdf1-467b-847d-e48fedb362ef",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"outputs": [],
"source": [
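"import os\n",
"import torch\n",
"\n",
"# e.g. CUDA 10.2 -> 'cu102' (torch.version.cuda is None on CPU-only builds)\n",
"os.environ['CUDA'] = 'cu' + torch.version.cuda.replace('.', '')\n",
"# drop any local version suffix such as '+PAI' or '+cu102', e.g. '1.10.0+cu102' -> 'torch1.10.0'\n",
"os.environ['Torch'] = 'torch' + torch.__version__.split('+')[0]\n",
"!echo \"cuda version: $CUDA\"\n",
"!echo \"pytorch version: $Torch\""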
]
},
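{
"cell_type": "markdown",
"id": "mmcv-index-url-sketch",
"metadata": {},
"source": [
"The `CUDA` and `Torch` environment variables set above are substituted into the find-links URL (`${CUDA}/${Torch}`) of the `pip install mmcv-full` command in the next cell. The plain-Python sketch below rebuilds that URL from version strings so you can check it by eye. `mmcv_index_url` is a hypothetical helper, not an EasyCV or mmcv API, and stripping everything after `+` is an assumption made to cover local version suffixes such as `+PAI` or `+cu102`.\n",
"\n",
"```python\n",
"# hypothetical helper (not an EasyCV/mmcv API): rebuild the find-links URL used by pip\n",
"def mmcv_index_url(cuda_version, torch_version):\n",
"    cuda_tag = 'cu' + cuda_version.replace('.', '')    # '10.2' -> 'cu102'\n",
"    torch_tag = 'torch' + torch_version.split('+')[0]  # '1.10.0+cu102' -> 'torch1.10.0'\n",
"    return 'https://download.openmmlab.com/mmcv/dist/{}/{}/index.html'.format(cuda_tag, torch_tag)\n",
"\n",
"\n",
"print(mmcv_index_url('10.2', '1.10.0'))\n",
"```"
]
},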
{
"cell_type": "code",
"execution_count": null,
"id": "9c0d7f7d-bbdc-45ac-b738-e31e00d5eef4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
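"# install Python dependencies\n",
"! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/${CUDA}/${Torch}/index.html\n",
"# note: this nvidia-dali wheel is built for CUDA 10.0 (cuda100 in the file name); it is not chosen from the detected CUDA version\n",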
"! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl"
]
},
{
"cell_type": "markdown",
"id": "2a987c04-e123-42b7-9d42-240e835aa410",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
"3. Install the EasyCV package"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee7590e1-c312-444e-af84-472459174d15",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"outputs": [],
"source": [
"! pip install pai-easycv"
]
},
{
"cell_type": "markdown",
"id": "df0fc82c-1f2c-4869-92e6-76d7dfb34164",
"metadata": {
"tags": []
},
"source": [
"4. Quick sanity check"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b89751f-f0c0-4aad-aced-8f6c77a4a2aa",
"metadata": {},
"outputs": [],
"source": [
"from easycv.apis import *"
]
},
{
"cell_type": "markdown",
"id": "8df2d210-4bcf-4b41-a981-b264c41e91a7",
"metadata": {
"tags": []
},
"source": [
"## CIFAR-10 Classification"
]
},
{
"cell_type": "markdown",
"id": "fce3f492-cfe0-4121-8f7d-877b31e3438b",
"metadata": {},
"source": [
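"The following example shows how to use the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset to quickly train and evaluate a Swin Transformer (Swin-Tiny) image classification model and then run predictions with it."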
]
},
{
"cell_type": "markdown",
"id": "6e769ad6-3ae0-4d5b-bf23-55673a239aa8",
"metadata": {},
"source": [
"### Data Preparation\n",
"Download the CIFAR-10 data and extract it to the `data/cifar` directory. The resulting directory structure is:\n",
"\n",
"```text\n",
"data/cifar\n",
"└── cifar-10-batches-py\n",
"    ├── batches.meta\n",
"    ├── data_batch_1\n",
"    ├── data_batch_2\n",
"    ├── data_batch_3\n",
"    ├── data_batch_4\n",
"    ├── data_batch_5\n",
"    ├── readme.html\n",
"    ├── read.py\n",
"    └── test_batch\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62eeff3d-274d-4ff9-8123-8e11d97c0d38",
"metadata": {},
"outputs": [],
"source": [
"! mkdir -p data/cifar && wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/cifar10/cifar-10-python.tar.gz && tar -zxf cifar-10-python.tar.gz -C data/cifar/"
]
},
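{
"cell_type": "markdown",
"id": "cifar10-batch-check-sketch",
"metadata": {},
"source": [
"Optionally, you can sanity-check the extracted files. According to the CIFAR-10 homepage, each Python batch file is a pickled dict whose `data` entry is an N x 3072 uint8 array (per image: 1024 red, then 1024 green, then 1024 blue values, each a row-major 32x32 plane) and whose `labels` entry is a list of class indices in 0-9. The sketch below assumes exactly that format. `load_cifar10_batch` is a hypothetical helper for illustration and is not part of EasyCV.\n",
"\n",
"```python\n",
"import os\n",
"import pickle\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def load_cifar10_batch(path):\n",
"    # hypothetical helper: the batches were pickled under Python 2, so load with encoding='bytes'\n",
"    with open(path, 'rb') as f:\n",
"        batch = pickle.load(f, encoding='bytes')\n",
"    # each row holds the R, G and B planes back to back: reshape to N x C x H x W, then move C last\n",
"    images = batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)\n",
"    labels = np.array(batch[b'labels'])\n",
"    return images, labels\n",
"\n",
"\n",
"batch_path = os.path.join('data/cifar/cifar-10-batches-py', 'data_batch_1')\n",
"if os.path.exists(batch_path):\n",
"    images, labels = load_cifar10_batch(batch_path)\n",
"    print(images.shape, labels.shape)\n",
"```"
]
},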
{
"cell_type": "markdown",
"id": "ed28f2be-1519-4d81-9d03-42891d1e9bc7",
"metadata": {},
"source": [
"### Model Training\n",
"Download the training configuration file."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd8b6d19-652c-4a58-89d3-38394dbb9a20",
"metadata": {},
"outputs": [],
"source": [
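"# remove any previously downloaded copy so wget does not save a duplicate (e.g. *.py.1)\n",
"! rm -rf swintiny_b64_5e_jpg.py\n",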
"!wget https://raw.githubusercontent.com/alibaba/EasyCV/master/configs/classification/cifar10/swintiny_b64_5e_jpg.py"
]
},
{
"cell_type": "markdown",
"id": "3d2ed375-f5ef-475e-9dc8-3386d328b8b0",
"metadata": {},
"source": [
"Train on a single GPU and evaluate on the validation set."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b6ec6da-d125-4ebe-b301-0d8cc8881512",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"! python -m easycv.tools.train swintiny_b64_5e_jpg.py --work_dir work_dirs/classification/cifar10/swin_tiny"
]
},
{
"cell_type": "markdown",
"id": "82b04e66-ef62-4d2e-963f-34015a588786",
"metadata": {},
"source": [
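"### Exporting the Model\n",
"\n",
"After training finishes, use the `easycv.tools.export` command to export the model for inference. The exported model includes the preprocessing and postprocessing information needed at inference time."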
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d4f7868-88aa-428c-9852-22e3eca45a07",
"metadata": {},
"outputs": [],
"source": [
"# list the checkpoint (.pth) files produced by training\n",
"! ls work_dirs/classification/cifar10/swin_tiny*"
]
},
{
"cell_type": "markdown",
"id": "a5dc516f-fe35-4b1f-acbe-9ec96b72cc95",
"metadata": {},
"source": [
"`ClsEvaluator_neck_top1_best.pth` is the checkpoint with the highest top-1 accuracy recorded during training; export this model."
]
},
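{
"cell_type": "markdown",
"id": "find-best-checkpoint-sketch",
"metadata": {},
"source": [
"Instead of copying the file name from the `ls` output, you can also locate the best checkpoint with a glob. This minimal sketch assumes only that the best checkpoint's file name ends in `_best.pth`, as `ClsEvaluator_neck_top1_best.pth` does. `find_best_checkpoints` is a hypothetical helper, not an EasyCV API.\n",
"\n",
"```python\n",
"import glob\n",
"import os\n",
"\n",
"\n",
"def find_best_checkpoints(work_dir):\n",
"    # hypothetical helper: list checkpoints whose file names end in '_best.pth', sorted by path\n",
"    return sorted(glob.glob(os.path.join(work_dir, '*_best.pth')))\n",
"\n",
"\n",
"print(find_best_checkpoints('work_dirs/classification/cifar10/swin_tiny'))\n",
"```"
]
},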
{
"cell_type": "code",
"execution_count": null,
"id": "67aa8262-834a-4afc-a6cd-fd882a181371",
"metadata": {},
"outputs": [],
"source": [
"! python -m easycv.tools.export swintiny_b64_5e_jpg.py work_dirs/classification/cifar10/swin_tiny/ClsEvaluator_neck_top1_best.pth work_dirs/classification/cifar10/swin_tiny/best_export.pth"
]
},
{
"cell_type": "markdown",
"id": "72d78ef4-0d90-4f3a-a1d9-c532e17d065b",
"metadata": {},
"source": [
"### Prediction\n",
"Download a test image."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "732fa5b4-fe0f-4a9b-ba19-d73d3ebd1736",
"metadata": {},
"outputs": [],
"source": [
"! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/cifar10/qince_data/predict/aeroplane_s_000004.png"
]
},
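{
"cell_type": "markdown",
"id": "bgr-to-rgb-sketch",
"metadata": {},
"source": [
"`cv2.imread` returns pixels in BGR channel order, while the comment in the next cell states that the predictor expects RGB, hence the `cv2.cvtColor(img, cv2.COLOR_BGR2RGB)` call. For a 3-channel image this conversion amounts to reversing the channel axis, which the NumPy sketch below illustrates. `bgr_to_rgb` is a hypothetical helper for illustration, not EasyCV code.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def bgr_to_rgb(img):\n",
"    # hypothetical helper: reverse the last (channel) axis, B, G, R -> R, G, B\n",
"    return np.ascontiguousarray(img[:, :, ::-1])\n",
"\n",
"\n",
"pixel_bgr = np.array([[[255, 0, 0]]], dtype=np.uint8)  # a single pure-blue pixel in BGR order\n",
"print(bgr_to_rgb(pixel_bgr)[0, 0])  # the same pixel, now in R, G, B order\n",
"```"
]
},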
{
"cell_type": "code",
"execution_count": null,
"id": "cb3b24c2-7a19-4618-8bdf-42ebe1cdfeb3",
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"from easycv.predictors.classifier import TorchClassifier\n",
"\n",
"output_ckpt = 'work_dirs/classification/cifar10/swin_tiny/best_export.pth'\n",
"tcls = TorchClassifier(output_ckpt, topk=1)\n",
"\n",
"img = cv2.imread('aeroplane_s_000004.png')\n",
"# the predictor expects RGB input, but cv2.imread loads images in BGR order\n",
"img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
"output = tcls.predict([img])\n",
"print(output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}