PaddleClas/docs/zh_CN/advanced_tutorials/multilabel/multilabel.md

2.1 KiB
Raw Blame History

多标签分类quick start

基于NUS-WIDE-SCENE数据集体验多标签分类的训练、评估、预测的过程该数据集是NUS-WIDE数据集的一个子集。请事先参考安装指南配置运行环境和克隆PaddleClas代码。

一、数据和模型准备

  • 进入PaddleClas目录。
cd path_to_PaddleClas
  • 创建并进入dataset/NUS-WIDE-SCENE目录下载并解压NUS-WIDE-SCENE数据集。
mkdir dataset/NUS-WIDE-SCENE
cd dataset/NUS-WIDE-SCENE
wget https://paddle-imagenet-models-name.bj.bcebos.com/data/NUS-SCENE-dataset.tar
tar -xf NUS-SCENE-dataset.tar
  • 返回PaddleClas根目录
cd ../../

二、环境准备

2.1 下载预训练模型

本例展示基于ResNet50_vd模型的多标签分类流程因此首先下载ResNet50_vd的预训练模型

mkdir pretrained
cd pretrained
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams
cd ../

三、模型训练

export CUDA_VISIBLE_DEVICES=0
python -m paddle.distributed.launch \
    --gpus="0" \
    tools/train.py \
        -c ./configs/quick_start/ResNet50_vd_multilabel.yaml

训练10epoch之后验证集最好的正确率应该在0.72左右。

四、模型评估

python tools/eval.py \
    -c ./configs/quick_start/ResNet50_vd_multilabel.yaml \
    -o pretrained_model="./output/ResNet50_vd/best_model/ppcls" \
    -o load_static_weights=False

评估指标采用mAP验证集的mAP应该在0.57左右。

五、模型预测

python tools/infer/infer.py \
    -i "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/0199_434752251.jpg" \
    --model ResNet50_vd \
    --pretrained_model "./output/ResNet50_vd/best_model/ppcls" \
    --use_gpu True \
    --load_static_weights False \
    --multilabel True \
    --class_num 33

得到类似下面的输出:

    class id: 3, probability: 0.6025
    class id: 23, probability: 0.5491
    class id: 32, probability: 0.7006