{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2f669eac-7d21-4647-9a78-3cdaaef0de88",
   "metadata": {},
   "source": [
    "# EasyCV Object Detection: YOLOX\n",
    "This tutorial uses the [YOLOX](https://arxiv.org/abs/2107.08430) model as an example to show how to train an object detection model and run prediction with EasyCV."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed0d2eba-5b7f-4bab-aec6-de9b29fc4c00",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "## Environment Requirements\n",
    "\n",
    "A PAI-PyTorch image or a native PyTorch 1.5+ environment, on a GPU machine with at least 32 GB of memory."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6507334-07ad-400a-9e10-1a8176108b7f",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "## Install Dependencies\n",
    "\n",
    "Note: the PAI-DSW Docker image already includes these dependencies, so you can skip this section there. Run the following steps only in a local notebook environment."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c5788b0-524b-4cf7-9628-aae6ac44f462",
   "metadata": {},
   "source": [
    "1. First, install PyTorch and the matching version of torchvision. PyTorch 1.5.1 and later are supported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "101d0cd7-64be-4234-b028-0c2f37714ac9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Install PyTorch and torchvision\n",
    "! conda install --yes pytorch==1.10.0 torchvision==0.11.0 -c pytorch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "668ce563-9c7e-4c33-8600-be715f73c1d6",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "2. Get the torch and CUDA versions, then install the matching versions of mmcv and nvidia-dali."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fa89683-efd7-459f-ae47-aeb304ead4fa",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "\n",
    "# Export the CUDA and PyTorch versions so the shell commands below can read them.\n",
    "os.environ['CUDA'] = 'cu' + torch.version.cuda.replace('.', '')\n",
    "# Drop any local build suffix such as '+PAI' or '+cu113'.\n",
    "os.environ['Torch'] = 'torch' + torch.__version__.split('+')[0]\n",
    "!echo \"cuda version: $CUDA\"\n",
    "!echo \"pytorch version: $Torch\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee4a64b7-1571-43b4-b869-b7d5bbd414ad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Install mmcv-full and nvidia-dali\n",
    "! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/${CUDA}/${Torch}/index.html\n",
    "! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "374fbe2e-c421-4aaa-a8fd-df625c060ef1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "3. Install the EasyCV package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e234b8ce-7a66-44ff-a013-0029c61d7763",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "pip install pai-easycv"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0aa1172e-c0d3-4b43-9e60-e373a3498af0",
   "metadata": {
    "tags": []
   },
   "source": [
    "4. Run a quick sanity check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad5ef3ff-1d2e-483e-afb2-1d37bc93e1c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from easycv.apis import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75ce9cce-173b-4c00-af2e-ac55e0185038",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "## Training and Prediction with an Object Detection Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "026ac411-d9e5-4cfc-b3bb-52baab3d0dd8",
   "metadata": {},
   "source": [
    "### Data Preparation\n",
    "\n",
    "You can download the [COCO2017](https://cocodataset.org/#download) dataset, or use the sample COCO data that we provide."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1095a6a5-776a-4df4-82e7-547f49edd0ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/small_coco_demo/small_coco_demo.tar.gz && tar -zxf small_coco_demo.tar.gz"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae6ae669-2ba7-440d-a3fe-4b698d48b031",
   "metadata": {},
   "source": [
    "Rename the extracted data directory so that its layout exactly matches the COCO data format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36eedef5-8a91-421e-aacb-b2f3448f8939",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p data/ && mv small_coco_demo data/coco"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47dbb359-4594-42ab-be8b-7d2be672388e",
   "metadata": {},
   "source": [
    "The layout of `data/coco` is as follows:\n",
    "\n",
    "```shell\n",
    "data/coco/\n",
    "├── annotations\n",
    "│   ├── instances_train2017.json\n",
    "│   └── instances_val2017.json\n",
    "├── train2017\n",
    "│   ├── 000000005802.jpg\n",
    "│   ├── 000000060623.jpg\n",
    "│   ├── 000000086408.jpg\n",
    "│   ├── 000000118113.jpg\n",
    "│   ├── 000000184613.jpg\n",
    "│   ├── 000000193271.jpg\n",
    "│   ├── 000000222564.jpg\n",
    "│   ...\n",
    "│   └── 000000574769.jpg\n",
    "└── val2017\n",
    "    ├── 000000006818.jpg\n",
    "    ├── 000000017627.jpg\n",
    "    ├── 000000037777.jpg\n",
    "    ├── 000000087038.jpg\n",
    "    ├── 000000174482.jpg\n",
    "    ├── 000000181666.jpg\n",
    "    ├── 000000184791.jpg\n",
    "    ├── 000000252219.jpg\n",
    "    ...\n",
    "    └── 000000522713.jpg\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9cc2e3f-faf7-4d44-841d-e3d32eaa09b8",
   "metadata": {},
   "source": [
    "### Model Training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b241a92d-eb9b-47d7-83d6-c11fef2292f9",
   "metadata": {},
   "source": [
    "Download the sample configuration file used to train the YOLOX-S model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dfc8556-ae61-4921-9748-d1169b12809a",
   "metadata": {},
   "outputs": [],
   "source": [
    "! rm -rf yolox_s_8xb16_300e_coco.py\n",
    "! wget https://raw.githubusercontent.com/alibaba/EasyCV/master/configs/detection/yolox/yolox_s_8xb16_300e_coco.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c63acf1-35ef-478f-949e-9fd45e932ae3",
   "metadata": {},
   "source": [
    "To adapt to the small dataset, modify the following fields in the configuration file `yolox_s_8xb16_300e_coco.py`: reduce the number of training epochs and log more frequently.\n",
    "\n",
    "```python\n",
    "\n",
    "total_epochs = 3\n",
    "\n",
    "# optimizer.lr -> 0.0002\n",
    "optimizer = dict(\n",
    "    type='SGD', lr=0.0002, momentum=0.9, weight_decay=5e-4, nesterov=True)\n",
    "\n",
    "# log_config.interval -> 1\n",
    "log_config = dict(interval=1)\n",
    "\n",
    "```\n",
    "\n",
    "Note: when training on the full COCO dataset, we recommend a single machine with 8 GPUs to ensure good results. If you train on a single GPU, we recommend lowering the learning rate `optimizer.lr`.\n",
    "\n",
    "To ensure good model quality, we fine-tune from a [pretrained model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox_s_bs16_lr002/epoch_300.pth). Run the following command to start training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1fcfaef-dacd-4fd8-9c6a-01d8de66e179",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "!python -m easycv.tools.train yolox_s_8xb16_300e_coco.py --work_dir work_dir/detection/yolox/yolox_s_8xb16_300e_coco --load_from http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox_s_bs16_lr002/epoch_300.pth"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d28eba0-3ee5-4348-b6b1-a85363abe411",
   "metadata": {},
   "source": [
    "### Export the Model\n",
    "Export the YOLOX model for prediction. Run the following command to list the model files produced by training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "984314b2-b076-4c46-9216-850e8a13896b",
   "metadata": {},
   "outputs": [],
   "source": [
    "! ls work_dir/detection/yolox/yolox_s_8xb16_300e_coco/*.pth"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f5e6d36-e527-4789-ad7c-581364aa4e67",
   "metadata": {},
   "source": [
    "Before exporting the model, modify the configuration file to set the score threshold used for NMS:\n",
    "\n",
    "`model.test_conf`: 0.01 -> 0.5\n",
    "\n",
    "```python\n",
    "model = dict(\n",
    "    type='YOLOX',\n",
    "    num_classes=80,\n",
    "    model_type='s',  # s m l x tiny nano\n",
    "    test_conf=0.5,\n",
    "    nms_thre=0.65)\n",
    "```\n",
    "\n",
    "Run the following commands to export the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f92e15f5-9148-4ac5-b828-962ecfff137e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# With total_epochs = 3 the final checkpoint is expected to be epoch_3.pth;\n",
    "# if the file list printed earlier shows a different name, use that one instead.\n",
    "! cp yolox_s_8xb16_300e_coco.py yolox_s_8xb16_300e_coco_export.py && sed -i 's#test_conf=0.01#test_conf=0.5#g' yolox_s_8xb16_300e_coco_export.py\n",
    "!python -m easycv.tools.export yolox_s_8xb16_300e_coco_export.py work_dir/detection/yolox/yolox_s_8xb16_300e_coco/epoch_3.pth work_dir/detection/yolox/yolox_s_8xb16_300e_coco/yolox_export.pth"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78251956-7a06-4b05-a390-259d8cf1a3be",
   "metadata": {},
   "source": [
    "### Model Prediction\n",
    "Download a test image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0409bbcd-8ea0-4341-8cb2-d7d728f4fb31",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/small_coco_demo/val2017/000000017627.jpg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c201d92-159a-46db-b152-573e3aad36c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "from easycv.predictors import TorchYoloXPredictor\n",
    "\n",
    "# Load the exported model.\n",
    "output_ckpt = 'work_dir/detection/yolox/yolox_s_8xb16_300e_coco/yolox_export.pth'\n",
    "detector = TorchYoloXPredictor(output_ckpt)\n",
    "\n",
    "# OpenCV reads images as BGR, so convert to RGB before prediction.\n",
    "img = cv2.imread('000000017627.jpg')\n",
    "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "output = detector.predict([img])\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a44b291c-bcd3-4f65-b2ab-0359dfc6e246",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the detection results\n",
    "\n",
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "image = img.copy()\n",
    "for box, cls_name in zip(output[0]['detection_boxes'], output[0]['detection_class_names']):\n",
    "    # box is [x1, y1, x2, y2]\n",
    "    box = [int(b) for b in box]\n",
    "    image = cv2.rectangle(image, tuple(box[:2]), tuple(box[2:4]), (0,255,0), 2)\n",
    "    cv2.putText(image, cls_name, (box[0], box[1]-5), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0,0,255), 2)\n",
    "plt.imshow(image)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}