{ "cells": [ { "cell_type": "markdown", "id": "a39e979f-2534-41b2-98c1-57b030ef1228", "metadata": {}, "source": [ "# EasyCV图像自监督训练-DINO\n", "\n", "本文讲介绍如何利用EasyCV使用自监督算法[DINO](https://arxiv.org/abs/2104.14294)进行图像自监督模型的训练" ] }, { "cell_type": "markdown", "id": "4343b52b-12af-484c-8fe9-d6be0c67587f", "metadata": { "tags": [] }, "source": [ "## 运行环境要求\n", "\n", "PAI-Pytorch镜像 or 原生Pytorch1.5+以上环境 GPU机器, 内存32G以上" ] }, { "cell_type": "markdown", "id": "9d5e07d2-eef6-44ce-86c9-da432b4f5e5d", "metadata": { "tags": [] }, "source": [ "## 安装依赖包\n", "\n", "注: 在PAI-DSW docker中无需安装相关依赖,可跳过此部分 在本地notebook环境中执行\n" ] }, { "cell_type": "markdown", "id": "74565abe-f1d4-47b7-99f5-1e3ec9fda250", "metadata": {}, "source": [ "1、 首先,安装pytorch和对应版本的torchvision,支持Pytorch1.5.1以上版本" ] }, { "cell_type": "code", "execution_count": null, "id": "32ba9b44-2ec8-4f3e-9f64-df4c21c5b474", "metadata": { "tags": [] }, "outputs": [], "source": [ "# install pytorch and torch vision\n", "! conda install --yes pytorch==1.10.0 torchvision==0.11.0 -c pytorch" ] }, { "cell_type": "markdown", "id": "705a64b9-920e-4f13-9109-eea263d988ca", "metadata": { "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ "2、 获取torch和cuda版本,安装对应版本的mmcv和nvidia-dali" ] }, { "cell_type": "code", "execution_count": null, "id": "49490272-9866-475b-8be6-8bb15efc8f02", "metadata": { "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "outputs": [], "source": [ "import torch\n", "import os\n", "os.environ['CUDA']='cu' + torch.version.cuda.replace('.', '')\n", "os.environ['Torch']='torch'+torch.version.__version__.replace('+PAI', '')\n", "!echo \"cuda version: $CUDA\"\n", "!echo \"pytorch version: $Torch\"" ] }, { "cell_type": "code", "execution_count": null, "id": "26e987b4-9a67-4c6b-880b-2a37bf6b4441", "metadata": { "tags": [] }, "outputs": [], "source": [ "# install some python deps\n", "! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/${CUDA}/${Torch}/index.html\n", "! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl" ] }, { "cell_type": "markdown", "id": "835e09c5-c89f-48a0-92a0-8b736d900ff7", "metadata": { "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ "3、 安装EasyCV算法包" ] }, { "cell_type": "code", "execution_count": null, "id": "805b2f81-cd75-4693-b03d-ba63768a48f2", "metadata": { "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "outputs": [], "source": [ "pip install pai-easycv" ] }, { "cell_type": "markdown", "id": "1c72ede0-31cd-40ae-9ccb-799875a83e3a", "metadata": { "tags": [] }, "source": [ "4、 简单验证" ] }, { "cell_type": "code", "execution_count": null, "id": "26170872-8caa-4bfc-85cf-c10add80f35e", "metadata": {}, "outputs": [], "source": [ "from easycv.apis import *" ] }, { "cell_type": "markdown", "id": "97cc7722-1a6d-4dea-a815-be76029a7faf", "metadata": {}, "source": [ "## EasyCV自监督训练" ] }, { "cell_type": "markdown", "id": "c640853a-c698-4eba-833f-b11f562bdf99", "metadata": {}, "source": [ "### 数据准备\n", "\n", "自监督训练只需要提供无标注图片即可进行, 你可以下载[ImageNet](http://www.image-net.org/download-images) 数据,或者使用你自己的图片数据。需要提供一个包含若干图片的文件夹路径`p`,以及一个文件列表,文件列表中是每个图片相对图片目录`p`的路径\n", "\n", "图片文件夹结构示例如下, 文件夹路径为`./images`\n", "\n", "```shell\n", "images/\n", "├── 0001.jpg\n", "├── 0002.jpg\n", "├── 0003.jpg\n", "|...\n", "└── 9999.jpg\n", "```\n", "\n", "文件列表内容示例如下\n", "```text\n", "0001.jpg\n", "0002.jpg\n", "0003.jpg\n", "...\n", "9999.jpg\n", "```\n", "\n", "为了快速走通流程,我们也提供了一个小的示例数据集,执行如下命令下载解压" ] }, { "cell_type": "code", "execution_count": null, "id": "4feca89e-8d09-4393-a24c-50038689605a", "metadata": {}, "outputs": [], "source": [ "! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/imagenet_raw_demo/imagenet_raw_demo.tar.gz && tar -zxf imagenet_raw_demo.tar.gz" ] }, { "cell_type": "code", "execution_count": null, "id": "88471c98-bcc2-4e2c-8783-a92cdfc42d4f", "metadata": {}, "outputs": [], "source": [ "# 重命名文件夹\n", "! mv imagenet_raw_demo imagenet_raw" ] }, { "cell_type": "markdown", "id": "8d951808-6488-4d95-85a3-f917b9dfae29", "metadata": {}, "source": [ "### 模型训练\n", "\n", "这个Demo中我们采用[mocov2](https://arxiv.org/abs/2003.04297)自监督算法训练ResNet50 主干网络, 下载示例配置文件" ] }, { "cell_type": "code", "execution_count": null, "id": "6f0045e5-7f2f-4c48-80be-277f002b548f", "metadata": {}, "outputs": [], "source": [ "! wget https://raw.githubusercontent.com/alibaba/EasyCV/master/configs/selfsup/dino/dino_deit_small_p16_8xb32_100e_jpg.py" ] }, { "cell_type": "markdown", "id": "20a6b241-6d4e-4526-a533-b100a3cebb7f", "metadata": {}, "source": [ "为了缩短训练时间,打开配置文件 `mocov2_rn50_8xb32_200e_jpg.py`,修改`total_epoch`参数为20, 每隔1次迭代打印一次日志。\n", "\n", "```python\n", "# runtime settings\n", "total_epochs = 20\n", "\n", "# log config\n", "log_config=dict(interval=1)\n", "```\n", "\n", "正式训练时,建议使用`单机8卡`配合该配置文件使用,如果要使用单机单卡,建议调小`optimizer.lr`初始学习率\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "af199c86-a0a5-4fac-96c6-8f5a357a6271", "metadata": {}, "outputs": [], "source": [ "# 查看easycv安装位置\n", "import easycv\n", "print(easycv.__file__)" ] }, { "cell_type": "code", "execution_count": null, "id": "d4aa1091-2c31-429a-b516-961b03356a22", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "!python -m torch.distributed.launch --nproc_per_node=1 --master_port=29930 \\\n", "/home/pai/lib/python3.6/site-packages/easycv/tools/train.py dino_deit_small_p16_8xb32_100e_jpg.py --work_dir work_dir/selfsup/jpg/dino_deit_small_p16 --launcher pytorch" ] }, { "cell_type": "markdown", "id": "bc8fb5d9-5afa-4bd2-9bdf-949c45034279", "metadata": {}, "source": [ "### 使用自监督模型进行特征抽取" ] }, { "cell_type": "markdown", "id": "a881a8d3-f767-4b70-a734-fca88fb75298", "metadata": { "tags": [] }, "source": [ "#### 模型导出" ] }, { "cell_type": "markdown", "id": "efdd36dd-5d58-42e0-bf5c-c9b2f1e41d88", "metadata": {}, "source": [ "模型导出会对自监督模型信息裁剪,保留特征抽取必要的backbone和head" ] }, { "cell_type": "code", "execution_count": null, "id": "6b566b5c-818d-42a8-814f-5288c1814ab0", "metadata": {}, "outputs": [], "source": [ "# 查看训练产生的模型文件\n", "!ls work_dir/selfsup/jpg/dino_deit_small_p16/*.pth" ] }, { "cell_type": "code", "execution_count": null, "id": "f6e7d1fa-1620-4268-a6f3-fb75b683a011", "metadata": {}, "outputs": [], "source": [ "! python -m easycv.tools.export dino_deit_small_p16_8xb32_100e_jpg.py work_dir/selfsup/jpg/dino_deit_small_p16/epoch_10.pth work_dir/selfsup/jpg/dino_deit_small_p16/epoch_10_export.pth" ] }, { "cell_type": "code", "execution_count": null, "id": "fd80ddd2-33a6-4972-88fe-ebb442c6c629", "metadata": {}, "outputs": [], "source": [ "#下载测试图片\n", "! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/product_detection/248347732153_1040.jpg" ] }, { "cell_type": "code", "execution_count": null, "id": "84150726-c90c-4ab4-9a5d-c46d223d403d", "metadata": {}, "outputs": [], "source": [ "import cv2\n", "from easycv.predictors.feature_extractor import TorchFeatureExtractor\n", "\n", "# 修改output_ckpt指向\n", "output_ckpt = 'work_dir/selfsup/jpg/dino_deit_small_p16/epoch_10_export.pth'\n", "fe = TorchFeatureExtractor(output_ckpt)\n", "\n", "img = cv2.imread('248347732153_1040.jpg')\n", "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", "feature = fe.predict([img])\n", "print(feature[0]['feature'].shape)\n", "print(feature[0])" ] }, { "cell_type": "markdown", "id": "6dee5721-ea78-4177-b74c-f00006e9a542", "metadata": {}, "source": [ "### 自监督预训练+ 图像分类finetune\n", "参考EasyCV图像分类的demo, 在训练时加上--load_from 参数,使用自监督预训练的模型权重, 注意这里不需要使用" ] }, { "cell_type": "code", "execution_count": null, "id": "b1e6ae4f-7a1c-416a-ae65-3e756c1d8551", "metadata": {}, "outputs": [], "source": [ "! python -m easycv.tools.train r50.py --work_dir work_dirs/classification/cifar10/dino_deit_small_p16 --load_from work_dir/selfsup/jpg/dino_deit_small_p16/epoch_10.pth" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.12" } }, "nbformat": 4, "nbformat_minor": 5 }