EasyCV/docs/source/tutorials/EasyCV图像自监督训练-MAE.ipynb

362 lines
10 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"id": "69debbc2-30e8-4390-805a-cb6ebad35f3f",
"metadata": {
"tags": []
},
"source": [
"## EasyCV图像自监督训练-MAE\n",
"本文将介绍如何利用EasyCV使用自监督算法[MAE](https://arxiv.org/pdf/2111.06377.pdf)进行图像自监督模型的训练\n",
" \n"
]
},
{
"cell_type": "markdown",
"id": "b9ef9614-cb63-487f-bafb-c62c90ae607b",
"metadata": {},
"source": [
"## 运行环境要求\n",
"\n",
"PAI-Pytorch镜像 or 原生Pytorch1.5+以上环境 GPU机器 内存32G以上"
]
},
{
"cell_type": "markdown",
"id": "f855f736-c08a-44c2-b1c1-33a8e7043864",
"metadata": {
"tags": []
},
"source": [
"## 安装依赖包\n",
"\n",
"注: 在PAI-DSW docker中无需安装相关依赖可跳过此部分 在本地notebook环境中执行\n"
]
},
{
"cell_type": "markdown",
"id": "43e60ee5-f52f-4045-bc69-4a123731a5a5",
"metadata": {},
"source": [
"1、 首先安装pytorch和对应版本的torchvision支持Pytorch1.5.1以上版本"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f97b2295-4d67-4a83-8637-ecde48fa3001",
"metadata": {},
"outputs": [],
"source": [
"# install pytorch and torch vision\n",
"! conda install --yes pytorch==1.10.0 torchvision==0.11.0 -c pytorch"
]
},
{
"cell_type": "markdown",
"id": "ff3ed9f8-703d-416e-8b70-10a8ce029e8d",
"metadata": {},
"source": [
"2、获取torch和cuda版本安装对应版本的mmcv和nvidia-dali"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c95571d5-3065-40b3-ad53-dc3063db8604",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import os\n",
"os.environ['CUDA']='cu' + torch.version.cuda.replace('.', '')\n",
"os.environ['Torch']='torch'+torch.version.__version__.replace('+PAI', '')\n",
"!echo \"cuda version: $CUDA\"\n",
"!echo \"pytorch version: $Torch\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce3dca16-9baf-42f8-94c4-5ea7f087d61e",
"metadata": {},
"outputs": [],
"source": [
"# install some python deps\n",
"! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/${CUDA}/${Torch}/index.html\n",
"! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl"
]
},
{
"cell_type": "markdown",
"id": "3dee3a99-191a-4515-97dd-80841db43775",
"metadata": {},
"source": [
"3、 安装EasyCV算法包"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f03f9e9-6029-4918-b8dc-533db9c7fae3",
"metadata": {},
"outputs": [],
"source": [
"pip install pai-easycv"
]
},
{
"cell_type": "markdown",
"id": "3483904d-91dd-44c4-a486-246aa38d4124",
"metadata": {},
"source": [
"4、 简单验证"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d382373-76a0-4fbc-b09f-5d082aab5104",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from easycv.apis import *"
]
},
{
"cell_type": "markdown",
"id": "29bb3d55-d00b-453b-9522-c686260e325c",
"metadata": {},
"source": [
"## 数据准备\n",
"\n",
"自监督训练只需要提供无标注图片即可进行, 你可以下载[ImageNet](http://www.image-net.org/download-images) 数据,或者使用你自己的图片数据。需要提供一个包含若干图片的文件夹路径`p`,以及一个文件列表,文件列表中是每个图片相对图片目录`p`的路径\n",
"\n",
"图片文件夹结构示例如下, 文件夹路径为`./images`\n",
"\n",
"```shell\n",
"images/\n",
"├── 0001.jpg\n",
"├── 0002.jpg\n",
"├── 0003.jpg\n",
"|...\n",
"└── 9999.jpg\n",
"```\n",
"\n",
"文件列表内容示例如下\n",
"```text\n",
"0001.jpg\n",
"0002.jpg\n",
"0003.jpg\n",
"...\n",
"9999.jpg\n",
"```\n",
"\n",
"为了快速走通流程,我们也提供了一个小的示例数据集,执行如下命令下载解压"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9492b324-17d9-4963-b7dd-a49f684c54c3",
"metadata": {},
"outputs": [],
"source": [
"! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/imagenet_raw_demo/imagenet_raw_demo.tar.gz && tar -zxf imagenet_raw_demo.tar.gz"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5fdc9dc3-0886-493f-be5b-f858eff72164",
"metadata": {},
"outputs": [],
"source": [
"# 重命名文件夹\n",
"! mv imagenet_raw_demo imagenet_raw"
]
},
{
"cell_type": "markdown",
"id": "c9bd7101-3072-417c-ac54-5d7d254b31b4",
"metadata": {},
"source": [
"## 模型训练\n",
"\n",
"这个Demo中我们采用[MAE](https://arxiv.org/pdf/2111.06377.pdf)自监督算法训练vit-base主干网络 下载示例配置文件"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab669385-8525-4694-8388-148ba1c2753a",
"metadata": {},
"outputs": [],
"source": [
"! rm -rf mae_vit_base_patch16_8xb64_1600e.py\n",
"! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/doc/easycv/configs/selfsup/mae/mae_vit_base_patch16_8xb64_1600e.py"
]
},
{
"cell_type": "markdown",
"id": "2890d267-6b95-47e6-9b51-83402446fa7f",
"metadata": {},
"source": [
"为了缩短训练时间,打开配置文件 `mae_vit_base_patch16_8xb64_1600e.py`,修改`total_epoch`参数为5 每隔1次迭代打印一次日志。\n",
"\n",
"```python\n",
"# runtime settings\n",
"total_epochs = 5\n",
"\n",
"# log config\n",
"log_config=dict(interval=1)\n",
"```\n",
"\n",
"正式训练时,建议使用`多机8卡`或`单机8卡`配合该配置文件使用修改update_interval参数进行梯度累积确保有效batch_size一致"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aaf88f46-a578-4dfc-a33f-afa0357f734a",
"metadata": {},
"outputs": [],
"source": [
"# 查看easycv安装位置\n",
"import easycv\n",
"print(easycv.__file__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d01f27fb-b373-4e88-9ac2-01bc35a2f5f2",
"metadata": {},
"outputs": [],
"source": [
"!python -m torch.distributed.launch --nproc_per_node=1 --master_port=29930 \\\n",
"/home/pai/lib/python3.6/site-packages/easycv/tools/train.py mae_vit_base_patch16_8xb64_1600e.py --work_dir work_dir/selfsup/jpg/mae --launcher pytorch"
]
},
{
"cell_type": "markdown",
"id": "e3aa0e8b-330c-45b2-be84-43f90260588c",
"metadata": {},
"source": [
"## 模型导出\n",
"对模型的字段进行修改,以便用于fintune任务"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db2d0dbc-b04e-4244-803b-c03561006c5c",
"metadata": {},
"outputs": [],
"source": [
"import torch \n",
"weight_path = 'work_dir/selfsup/jpg/mae/epoch_5.pth'\n",
"state_dict = torch.load(weight_path)['state_dict']\n",
"state_dict_out = {}\n",
"for key in state_dict:\n",
" state_dict_out[key.replace('encoder.','')] = state_dict[key]\n",
"torch.save(state_dict_out,weight_path)"
]
},
{
"cell_type": "markdown",
"id": "aad27bb4-e602-4d66-bba5-e379926559f3",
"metadata": {},
"source": [
"## 使用自监督模型进行图像分类fintune\n",
"下载分类任务示例配置文件"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78dc417a-2c16-4d7b-a239-e150283e9cac",
"metadata": {},
"outputs": [],
"source": [
"! rm -rf mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py\n",
"! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/doc/easycv/configs/selfsup/mae/mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py"
]
},
{
"cell_type": "markdown",
"id": "8db1ffd1-6788-41f4-a422-2709903290ad",
"metadata": {},
"source": [
"修改配置文件 `mae_vit_base_patch16_8xb64_1600e.py`,修改`model.pretrained`参数为`work_dir/selfsup/jpg/mae/epoch_5.pth`,为了缩短训练时间,修改`total_epoch`参数为5每隔1次迭代打印一次日志。\n",
"\n",
"```python\n",
"# runtime settings\n",
"total_epochs = 5\n",
"\n",
"# log config\n",
"log_config=dict(interval=1)\n",
"\n",
"# \n",
"pretrained='work_dir/selfsup/jpg/mae/epoch_5.pth'\n",
"```\n",
"\n",
"正式训练时,建议使用`多机8卡`或`单机8卡`配合该配置文件使用修改update_interval参数进行梯度累积确保有效batch_size一致"
]
},
{
"cell_type": "markdown",
"id": "68e72360-c054-4835-9643-ad0698a97d65",
"metadata": {},
"source": [
"### 分类模型训练\n",
"这里提供了单卡进行训练和验证集评估的命令"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "defc60df-e8ca-4eac-9923-2cfb1f9fa7f1",
"metadata": {},
"outputs": [],
"source": [
"!python -m torch.distributed.launch --nproc_per_node=1 --master_port=29930 \\\n",
"/home/pai/lib/python3.6/site-packages/easycv/tools/train.py mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py --work_dir work_dir/selfsup/jpg/mae --launcher pytorch"
]
},
{
"cell_type": "markdown",
"id": "fc43e194-d8e7-4796-af3a-1f64663b9744",
"metadata": {},
"source": [
"### 预测\n",
"参考EasyCV图像分类的demo对训练好的模型导出并预测"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}