{
"cells": [
{
"cell_type": "markdown",
"id": "a39e979f-2534-41b2-98c1-57b030ef1228",
"metadata": {},
"source": [
"# EasyCV Self-Supervised Image Training: DINO\n",
"\n",
"This tutorial shows how to use EasyCV to train a self-supervised image model with the [DINO](https://arxiv.org/abs/2104.14294) algorithm."
]
},
{
"cell_type": "markdown",
"id": "4343b52b-12af-484c-8fe9-d6be0c67587f",
"metadata": {
"tags": []
},
"source": [
"## Environment Requirements\n",
"\n",
"A PAI-PyTorch image or a native PyTorch 1.5+ environment, on a GPU machine with at least 32 GB of memory."
]
},
{
"cell_type": "markdown",
"id": "9d5e07d2-eef6-44ce-86c9-da432b4f5e5d",
"metadata": {
"tags": []
},
"source": [
"## Installing Dependencies\n",
"\n",
"Note: In the PAI-DSW docker image the dependencies are already installed, so you can skip this section. Run it only in a local notebook environment.\n"
]
},
{
"cell_type": "markdown",
"id": "74565abe-f1d4-47b7-99f5-1e3ec9fda250",
"metadata": {},
"source": [
"1. First, install PyTorch and the matching version of torchvision. PyTorch 1.5.1 and later are supported."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "32ba9b44-2ec8-4f3e-9f64-df4c21c5b474",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# install PyTorch and torchvision\n",
"! conda install --yes pytorch==1.10.0 torchvision==0.11.0 -c pytorch"
]
},
{
"cell_type": "markdown",
"id": "705a64b9-920e-4f13-9109-eea263d988ca",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
"2. Get the torch and CUDA versions, then install the matching versions of mmcv and nvidia-dali."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49490272-9866-475b-8be6-8bb15efc8f02",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"outputs": [],
"source": [
"import torch\n",
"import os\n",
"# expose the CUDA and torch versions as environment variables for the shell commands\n",
"os.environ['CUDA']='cu' + torch.version.cuda.replace('.', '')\n",
"os.environ['Torch']='torch'+torch.version.__version__.replace('+PAI', '')\n",
"!echo \"cuda version: $CUDA\"\n",
"!echo \"pytorch version: $Torch\""
]
},
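{
"cell_type": "markdown",
"id": "sketch-mmcv-index-url",
"metadata": {},
"source": [
"The `CUDA` and `Torch` variables set above are substituted into the mmcv wheel index URL in the install command that follows. As a minimal sketch of that substitution, the helper below rebuilds the same URL from plain version strings. `mmcv_index_url` is a hypothetical name used only for illustration, not part of mmcv or EasyCV, and the version strings in the example are made up.\n",
"\n",
"```python\n",
"def mmcv_index_url(cuda_version, torch_version):\n",
"    # Hypothetical helper mirroring the CUDA/Torch environment variables above.\n",
"    cuda_tag = 'cu' + cuda_version.replace('.', '')\n",
"    # PAI builds carry a '+PAI' suffix, which the notebook strips.\n",
"    torch_tag = 'torch' + torch_version.replace('+PAI', '')\n",
"    base = 'https://download.openmmlab.com/mmcv/dist'\n",
"    return '{}/{}/{}/index.html'.format(base, cuda_tag, torch_tag)\n",
"\n",
"\n",
"print(mmcv_index_url('10.1', '1.8.0+PAI'))\n",
"```\n",
"\n",
"Printing the URL before running `pip install` makes it easy to check that a wheel index actually exists for your CUDA and PyTorch combination."
]
},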
{
"cell_type": "code",
"execution_count": null,
"id": "26e987b4-9a67-4c6b-880b-2a37bf6b4441",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# install some python deps\n",
"! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/${CUDA}/${Torch}/index.html\n",
"! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl"
]
},
{
"cell_type": "markdown",
"id": "835e09c5-c89f-48a0-92a0-8b736d900ff7",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
"3. Install the EasyCV package."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "805b2f81-cd75-4693-b03d-ba63768a48f2",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"outputs": [],
"source": [
"pip install pai-easycv"
]
},
{
"cell_type": "markdown",
"id": "1c72ede0-31cd-40ae-9ccb-799875a83e3a",
"metadata": {
"tags": []
},
"source": [
"4. Quick verification."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26170872-8caa-4bfc-85cf-c10add80f35e",
"metadata": {},
"outputs": [],
"source": [
"from easycv.apis import *"
]
},
{
"cell_type": "markdown",
"id": "97cc7722-1a6d-4dea-a815-be76029a7faf",
"metadata": {},
"source": [
"## Self-Supervised Training with EasyCV"
]
},
{
"cell_type": "markdown",
"id": "c640853a-c698-4eba-833f-b11f562bdf99",
"metadata": {},
"source": [
"### Data Preparation\n",
"\n",
"Self-supervised training only needs unlabeled images. You can download [ImageNet](http://www.image-net.org/download-images) or use your own images. You need to provide the path `p` of a folder containing the images, plus a file list in which each line is the path of one image relative to the image directory `p`.\n",
"\n",
"An example image folder layout is shown below; the folder path is `./images`.\n",
"\n",
"```shell\n",
"images/\n",
"├── 0001.jpg\n",
"├── 0002.jpg\n",
"├── 0003.jpg\n",
"|...\n",
"└── 9999.jpg\n",
"```\n",
"\n",
"An example of the file list contents:\n",
"\n",
"```text\n",
"0001.jpg\n",
"0002.jpg\n",
"0003.jpg\n",
"...\n",
"9999.jpg\n",
"```\n",
"\n",
"To get through the workflow quickly, we also provide a small demo dataset. Run the following commands to download and extract it."
]
},
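{
"cell_type": "markdown",
"id": "sketch-write-file-list",
"metadata": {},
"source": [
"If you bring your own images, the file list described above can be generated rather than written by hand. The following is a minimal sketch using only the standard library; `write_file_list` and the extension filter are hypothetical choices for illustration, not an EasyCV API.\n",
"\n",
"```python\n",
"import os\n",
"\n",
"\n",
"def write_file_list(image_dir, list_path, exts=('.jpg', '.jpeg', '.png')):\n",
"    # Collect image paths relative to image_dir, one entry per image.\n",
"    rel_paths = []\n",
"    for root, _, files in os.walk(image_dir):\n",
"        for name in files:\n",
"            if name.lower().endswith(exts):\n",
"                full_path = os.path.join(root, name)\n",
"                rel_paths.append(os.path.relpath(full_path, image_dir))\n",
"    # Sort so that the list is stable across runs.\n",
"    rel_paths.sort()\n",
"    with open(list_path, 'w') as f:\n",
"        for rel_path in rel_paths:\n",
"            print(rel_path, file=f)\n",
"    return rel_paths\n",
"```\n",
"\n",
"For the layout shown above, `write_file_list('./images', 'file_list.txt')` would write one relative path per line; the output file name here is only an example."
]
},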
{
"cell_type": "code",
"execution_count": null,
"id": "4feca89e-8d09-4393-a24c-50038689605a",
"metadata": {},
"outputs": [],
"source": [
"! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/imagenet_raw_demo/imagenet_raw_demo.tar.gz && tar -zxf imagenet_raw_demo.tar.gz"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88471c98-bcc2-4e2c-8783-a92cdfc42d4f",
"metadata": {},
"outputs": [],
"source": [
"# rename the folder\n",
"! mv imagenet_raw_demo imagenet_raw"
]
},
{
"cell_type": "markdown",
"id": "8d951808-6488-4d95-85a3-f917b9dfae29",
"metadata": {},
"source": [
"### Model Training\n",
"\n",
"In this demo we use the [DINO](https://arxiv.org/abs/2104.14294) self-supervised algorithm to train a DeiT-small backbone. Download the example config file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f0045e5-7f2f-4c48-80be-277f002b548f",
"metadata": {},
"outputs": [],
"source": [
"! wget https://raw.githubusercontent.com/alibaba/EasyCV/master/configs/selfsup/dino/dino_deit_small_p16_8xb32_100e_jpg.py"
]
},
{
"cell_type": "markdown",
"id": "20a6b241-6d4e-4526-a533-b100a3cebb7f",
"metadata": {},
"source": [
"To shorten training time, open the config file `dino_deit_small_p16_8xb32_100e_jpg.py`, set `total_epochs` to 20, and print a log entry every iteration.\n",
"\n",
"```python\n",
"# runtime settings\n",
"total_epochs = 20\n",
"\n",
"# log config\n",
"log_config=dict(interval=1)\n",
"```\n",
"\n",
"For real training, we recommend using this config on a single machine with 8 GPUs. If you train on a single machine with a single GPU, reduce the initial learning rate `optimizer.lr`.\n",
"\n"
]
},
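{
"cell_type": "markdown",
"id": "sketch-linear-lr-scaling",
"metadata": {},
"source": [
"One common way to choose the reduced learning rate for fewer GPUs is the linear scaling heuristic, which keeps the learning rate proportional to the total batch size. The sketch below illustrates it. It is an assumption, not a rule stated by EasyCV; reading `8xb32` in the config name as 8 GPUs with 32 samples each is also an assumption, and `scale_lr` and the base learning rate are hypothetical.\n",
"\n",
"```python\n",
"def scale_lr(base_lr, gpus, samples_per_gpu=32, base_gpus=8, base_samples_per_gpu=32):\n",
"    # Linear scaling heuristic: lr stays proportional to the total batch size.\n",
"    base_batch = base_gpus * base_samples_per_gpu\n",
"    new_batch = gpus * samples_per_gpu\n",
"    return base_lr * new_batch / base_batch\n",
"\n",
"\n",
"# Hypothetical base value; read the real one from optimizer.lr in the config.\n",
"print(scale_lr(base_lr=0.0008, gpus=1))\n",
"```\n",
"\n",
"With one GPU and an unchanged per-GPU batch size, the total batch is one eighth of the 8-GPU setup, so this heuristic divides the learning rate by 8."
]
},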
{
"cell_type": "code",
"execution_count": null,
"id": "af199c86-a0a5-4fac-96c6-8f5a357a6271",
"metadata": {},
"outputs": [],
"source": [
"# show where easycv is installed\n",
"import easycv\n",
"print(easycv.__file__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4aa1091-2c31-429a-b516-961b03356a22",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"!python -m torch.distributed.launch --nproc_per_node=1 --master_port=29930 \\\n",
"/home/pai/lib/python3.6/site-packages/easycv/tools/train.py dino_deit_small_p16_8xb32_100e_jpg.py --work_dir work_dir/selfsup/jpg/dino_deit_small_p16 --launcher pytorch"
]
},
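{
"cell_type": "markdown",
"id": "sketch-train-script-path",
"metadata": {},
"source": [
"The training command above hard-codes the PAI-DSW location of `train.py`. In a local environment the path can instead be derived from the printed `easycv.__file__`. The sketch below assumes, as the hard-coded path suggests, that `tools/train.py` sits inside the installed `easycv` package; `train_script_path` is a hypothetical helper name.\n",
"\n",
"```python\n",
"import os\n",
"\n",
"\n",
"def train_script_path(package_init_file):\n",
"    # package_init_file is the value of easycv.__file__, i.e. .../easycv/__init__.py\n",
"    package_dir = os.path.dirname(package_init_file)\n",
"    return os.path.join(package_dir, 'tools', 'train.py')\n",
"\n",
"\n",
"print(train_script_path('/home/pai/lib/python3.6/site-packages/easycv/__init__.py'))\n",
"```\n",
"\n",
"In the notebook you would pass `easycv.__file__` and substitute the result for the hard-coded path in the launch command."
]
},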
{
"cell_type": "markdown",
"id": "bc8fb5d9-5afa-4bd2-9bdf-949c45034279",
"metadata": {},
"source": [
"### Feature Extraction with the Self-Supervised Model"
]
},
{
"cell_type": "markdown",
"id": "a881a8d3-f767-4b70-a734-fca88fb75298",
"metadata": {
"tags": []
},
"source": [
"#### Model Export"
]
},
{
"cell_type": "markdown",
"id": "efdd36dd-5d58-42e0-bf5c-c9b2f1e41d88",
"metadata": {},
"source": [
"Exporting trims the self-supervised model, keeping only the backbone and head needed for feature extraction."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b566b5c-818d-42a8-814f-5288c1814ab0",
"metadata": {},
"outputs": [],
"source": [
"# list the checkpoint files produced by training\n",
"!ls work_dir/selfsup/jpg/dino_deit_small_p16/*.pth"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6e7d1fa-1620-4268-a6f3-fb75b683a011",
"metadata": {},
"outputs": [],
"source": [
"! python -m easycv.tools.export dino_deit_small_p16_8xb32_100e_jpg.py work_dir/selfsup/jpg/dino_deit_small_p16/epoch_10.pth work_dir/selfsup/jpg/dino_deit_small_p16/epoch_10_export.pth"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd80ddd2-33a6-4972-88fe-ebb442c6c629",
"metadata": {},
"outputs": [],
"source": [
"# download a test image\n",
"! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/product_detection/248347732153_1040.jpg"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "84150726-c90c-4ab4-9a5d-c46d223d403d",
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"from easycv.predictors.feature_extractor import TorchFeatureExtractor\n",
"\n",
"# point output_ckpt at your exported checkpoint\n",
"output_ckpt = 'work_dir/selfsup/jpg/dino_deit_small_p16/epoch_10_export.pth'\n",
"fe = TorchFeatureExtractor(output_ckpt)\n",
"\n",
"img = cv2.imread('248347732153_1040.jpg')\n",
"img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
"feature = fe.predict([img])\n",
"print(feature[0]['feature'].shape)\n",
"print(feature[0])"
]
},
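{
"cell_type": "markdown",
"id": "sketch-cosine-similarity",
"metadata": {},
"source": [
"Extracted features are typically compared with cosine similarity, for example to find visually similar images. The sketch below shows that comparison on plain Python sequences. It assumes each feature has first been flattened to a list of floats; `cosine_similarity` is a hypothetical helper, not part of EasyCV, and the vectors in the example are made up.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def cosine_similarity(a, b):\n",
"    # a and b are flat sequences of floats with the same length.\n",
"    if len(a) != len(b):\n",
"        raise ValueError('feature vectors must have the same length')\n",
"    dot = sum(x * y for x, y in zip(a, b))\n",
"    norm_a = math.sqrt(sum(x * x for x in a))\n",
"    norm_b = math.sqrt(sum(y * y for y in b))\n",
"    # Treat an all-zero vector as dissimilar to everything.\n",
"    if norm_a == 0.0 or norm_b == 0.0:\n",
"        return 0.0\n",
"    return dot / (norm_a * norm_b)\n",
"\n",
"\n",
"print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))\n",
"```\n",
"\n",
"Values close to 1 indicate features pointing in nearly the same direction, and values close to 0 indicate unrelated features."
]
},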
{
"cell_type": "markdown",
"id": "6dee5721-ea78-4177-b74c-f00006e9a542",
"metadata": {},
"source": [
"### Self-Supervised Pretraining + Image Classification Fine-Tuning\n",
"Following the EasyCV image classification demo, add the `--load_from` argument at training time to use the self-supervised pretrained weights. Note that here you do not need to use"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1e6ae4f-7a1c-416a-ae65-3e756c1d8551",
"metadata": {},
"outputs": [],
"source": [
"! python -m easycv.tools.train r50.py --work_dir work_dirs/classification/cifar10/dino_deit_small_p16 --load_from work_dir/selfsup/jpg/dino_deit_small_p16/epoch_10.pth"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}