add multilabel feature

pull/1265/head
cuicheng01 2021-09-26 07:05:13 +00:00
parent e431fe3399
commit af9aae730e
20 changed files with 524 additions and 69 deletions

View File

@@ -0,0 +1,33 @@
Global:
  infer_imgs: "./images/0517_2715693311.jpg"
  inference_model_dir: "../inference/"
  batch_size: 1
  use_gpu: True
  enable_mkldnn: False
  cpu_num_threads: 10
  enable_benchmark: True
  use_fp16: False
  ir_optim: True
  use_tensorrt: False
  gpu_mem: 8000
  enable_profile: False

PreProcess:
  transform_ops:
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 0.00392157
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
        channel_num: 3
    - ToCHWImage:

PostProcess:
  main_indicator: MultiLabelTopk
  MultiLabelTopk:
    topk: 5
    class_id_map_file: None
  SavePreLabel:
    save_dir: ./pre_label/

Binary file not shown (new image, 16 KiB)

View File

@@ -81,12 +81,14 @@ class Topk(object):
            class_id_map = None
        return class_id_map

-    def __call__(self, x, file_names=None):
+    def __call__(self, x, file_names=None, multilabel=False):
        if file_names is not None:
            assert x.shape[0] == len(file_names)
        y = []
        for idx, probs in enumerate(x):
-            index = probs.argsort(axis=0)[-self.topk:][::-1].astype("int32")
+            index = probs.argsort(axis=0)[-self.topk:][::-1].astype(
+                "int32") if not multilabel else np.where(
+                    probs >= 0.5)[0].astype("int32")
            clas_id_list = []
            score_list = []
            label_name_list = []
@@ -108,6 +110,14 @@ class Topk(object):
        return y


+class MultiLabelTopk(Topk):
+    def __init__(self, topk=1, class_id_map_file=None):
+        super().__init__()
+
+    def __call__(self, x, file_names=None):
+        return super().__call__(x, file_names, multilabel=True)
+
+
class SavePreLabel(object):
    def __init__(self, save_dir):
        if save_dir is None:
@@ -128,23 +138,24 @@ class SavePreLabel(object):
        os.makedirs(output_dir, exist_ok=True)
        shutil.copy(image_file, output_dir)


class Binarize(object):
-    def __init__(self, method = "round"):
+    def __init__(self, method="round"):
        self.method = method
        self.unit = np.array([[128, 64, 32, 16, 8, 4, 2, 1]]).T

    def __call__(self, x, file_names=None):
        if self.method == "round":
            x = np.round(x + 1).astype("uint8") - 1
        if self.method == "sign":
            x = ((np.sign(x) + 1) / 2).astype("uint8")
        embedding_size = x.shape[1]
        assert embedding_size % 8 == 0, "The Binary index only support vectors with sizes multiple of 8"
        byte = np.zeros([x.shape[0], embedding_size // 8], dtype=np.uint8)
        for i in range(embedding_size // 8):
-            byte[:, i:i+1] = np.dot(x[:, i * 8: (i + 1)* 8], self.unit)
+            byte[:, i:i + 1] = np.dot(x[:, i * 8:(i + 1) * 8], self.unit)
        return byte
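The multilabel branch above swaps the top-k selection for a fixed 0.5 probability threshold. A minimal standalone sketch of the two selection rules (NumPy only; the probabilities and top-k value are made up, not part of the commit):

```python
import numpy as np

probs = np.array([0.96, 0.10, 0.56, 0.03, 0.99])  # per-class sigmoid outputs

# single-label style: indices of the top-k scores, highest first
topk = 3
topk_index = probs.argsort(axis=0)[-topk:][::-1].astype("int32")  # -> [4 0 2]

# multilabel style: every class whose probability clears the threshold
multilabel_index = np.where(probs >= 0.5)[0].astype("int32")      # -> [0 2 4]
```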

View File

@@ -71,7 +71,6 @@ class ClsPredictor(Predictor):
        output_names = self.paddle_predictor.get_output_names()
        output_tensor = self.paddle_predictor.get_output_handle(output_names[
            0])
-
        if self.benchmark:
            self.auto_logger.times.start()
        if not isinstance(images, (list, )):
@@ -119,7 +118,6 @@ def main(config):
                ) == len(image_list):
            if len(batch_imgs) == 0:
                continue
-
            batch_results = cls_predictor.predict(batch_imgs)
            for number, result_dict in enumerate(batch_results):
                filename = batch_names[number]

View File

@@ -1,6 +1,9 @@
# classification
python3.7 python/predict_cls.py -c configs/inference_cls.yaml

+# multilabel_classification
+#python3.7 python/predict_cls.py -c configs/inference_multilabel_cls.yaml
+
# feature extractor
# python3.7 python/predict_rec.py -c configs/inference_rec.yaml

View File

@@ -25,58 +25,66 @@ tar -xf NUS-SCENE-dataset.tar
cd ../../
```

-## 2. Environment Preparation
+## 2. Model Training

-### 2.1 Download the pretrained model
-This example demonstrates a multilabel classification pipeline based on the ResNet50_vd model, so first download the ResNet50_vd pretrained weights:
-```bash
-mkdir pretrained
-cd pretrained
-wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams
-cd ../
-```
-
-## 3. Model Training
```shell
-export CUDA_VISIBLE_DEVICES=0
-python -m paddle.distributed.launch \
-    --gpus="0" \
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python3 -m paddle.distributed.launch \
+    --gpus="0,1,2,3" \
    tools/train.py \
-        -c ./configs/quick_start/ResNet50_vd_multilabel.yaml
+        -c ./ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml
```

-After training for 10 epochs, the best accuracy on the validation set should be around 0.72.
+After training for 10 epochs, the best accuracy on the validation set should be around 0.95.

-## 4. Model Evaluation
+## 3. Model Evaluation

```bash
-python tools/eval.py \
-    -c ./configs/quick_start/ResNet50_vd_multilabel.yaml \
-    -o pretrained_model="./output/ResNet50_vd/best_model/ppcls" \
-    -o load_static_weights=False
+python3 tools/eval.py \
+    -c ./ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml \
+    -o Arch.pretrained="./output/MobileNetV1/best_model"
```

-The evaluation metric is mAP; the mAP on the validation set should be around 0.57.
-
-## 5. Model Prediction
+## 4. Model Prediction

```bash
-python tools/infer/infer.py \
-    -i "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/0199_434752251.jpg" \
-    --model ResNet50_vd \
-    --pretrained_model "./output/ResNet50_vd/best_model/ppcls" \
-    --use_gpu True \
-    --load_static_weights False \
-    --multilabel True \
-    --class_num 33
+python3 tools/infer.py \
+    -c ./ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml \
+    -o Arch.pretrained="./output/MobileNetV1/best_model"
```

You will get output similar to the following:
```
-class id: 3, probability: 0.6025
-class id: 23, probability: 0.5491
-class id: 32, probability: 0.7006
+[{'class_ids': [6, 13, 17, 23, 26, 30], 'scores': [0.95683, 0.5567, 0.55211, 0.99088, 0.5943, 0.78767], 'file_name': './deploy/images/0517_2715693311.jpg', 'label_names': []}]
```
+
+## 5. Prediction with the Inference Engine
+
+### 5.1 Export the inference model
+```bash
+python3 tools/export_model.py \
+    -c ./ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml \
+    -o Arch.pretrained="./output/MobileNetV1/best_model"
+```
+By default, the inference model is saved under `./inference` in the current directory.
+
+### 5.2 Predict with the inference engine
+First enter the deploy directory:
+```bash
+cd ./deploy
+```
+Then run inference through the prediction engine:
+```
+python3 python/predict_cls.py \
+    -c configs/inference_multilabel_cls.yaml
+```
+You will get output similar to the following:
+```
+0517_2715693311.jpg: class id(s): [6, 13, 17, 23, 26, 30], score(s): [0.96, 0.56, 0.55, 0.99, 0.59, 0.79], label_name(s): []
+```

View File

@@ -0,0 +1,129 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 10
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  use_multilabel: True

# model architecture
Arch:
  name: MobileNetV1
  class_num: 33
  pretrained: True

# loss function config for training/eval process
Loss:
  Train:
    - MultiLabelLoss:
        weight: 1.0
  Eval:
    - MultiLabelLoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.1
  regularizer:
    name: 'L2'
    coeff: 0.00004

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: MultiLabelDataset
      image_root: ./dataset/NUS-SCENE-dataset/images/
      cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: MultiLabelDataset
      image_root: ./dataset/NUS-SCENE-dataset/images/
      cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_test_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: dataset/NUS-SCENE-dataset/images/0001_109549716.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: MultiLabelTopk
    topk: 5
    class_id_map_file: None

Metric:
  Train:
    - HammingDistance:
    - AccuracyScore:
  Eval:
    - HammingDistance:
    - AccuracyScore:

View File

@@ -0,0 +1,129 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 10
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  use_multilabel: True

# model architecture
Arch:
  name: MobileNetV1
  class_num: 33
  pretrained: True

# loss function config for training/eval process
Loss:
  Train:
    - MultiLabelLoss:
        weight: 1.0
  Eval:
    - MultiLabelLoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.1
  regularizer:
    name: 'L2'
    coeff: 0.00004

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: MultiLabelDataset
      image_root: ./dataset/NUS-SCENE-dataset/images/
      cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: MultiLabelDataset
      image_root: ./dataset/NUS-SCENE-dataset/images/
      cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_test_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: ./deploy/images/0517_2715693311.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: MultiLabelTopk
    topk: 5
    class_id_map_file: None

Metric:
  Train:
    - HammingDistance:
    - AccuracyScore:
  Eval:
    - HammingDistance:
    - AccuracyScore:

View File

@@ -33,7 +33,7 @@ class MultiLabelDataset(CommonDataset):
        with open(self._cls_path) as fd:
            lines = fd.readlines()
            for l in lines:
-                l = l.strip().split(" ")
+                l = l.strip().split("\t")
                self.images.append(os.path.join(self._img_root, l[0]))
                labels = l[1].split(',')
@@ -44,13 +44,14 @@ class MultiLabelDataset(CommonDataset):
    def __getitem__(self, idx):
        try:
-            img = cv2.imread(self.images[idx])
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            with open(self.images[idx], 'rb') as f:
+                img = f.read()
            if self._transform_ops:
                img = transform(img, self._transform_ops)
            img = img.transpose((2, 0, 1))
            label = np.array(self.labels[idx]).astype("float32")
            return (img, label)
        except Exception as ex:
            logger.error("Exception occured when parse line: {} with msg: {}".
                         format(self.images[idx], ex))
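The loader now splits each list line on a tab and keeps the comma-separated label column; the float cast in `__getitem__` suggests that column is already a multi-hot string. A hedged sketch of the expected format, not part of the commit (the file name and label values are made up):

```python
import numpy as np

# hypothetical entry from multilabel_train_list.txt: <image path>\t<comma-separated multi-hot labels>
line = "0517_2715693311.jpg\t0,1,0,1,0"

path, label_str = line.strip().split("\t")
label = np.array(label_str.split(",")).astype("float32")  # -> [0. 1. 0. 1. 0.]
```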

View File

@@ -16,7 +16,7 @@ import importlib
from . import topk
-from .topk import Topk
+from .topk import Topk, MultiLabelTopk

def build_postprocess(config):

View File

@@ -45,15 +45,17 @@ class Topk(object):
            class_id_map = None
        return class_id_map

-    def __call__(self, x, file_names=None):
+    def __call__(self, x, file_names=None, multilabel=False):
        assert isinstance(x, paddle.Tensor)
        if file_names is not None:
            assert x.shape[0] == len(file_names)
-        x = F.softmax(x, axis=-1)
+        x = F.softmax(x, axis=-1) if not multilabel else F.sigmoid(x)
        x = x.numpy()
        y = []
        for idx, probs in enumerate(x):
-            index = probs.argsort(axis=0)[-self.topk:][::-1].astype("int32")
+            index = probs.argsort(axis=0)[-self.topk:][::-1].astype(
+                "int32") if not multilabel else np.where(
+                    probs >= 0.5)[0].astype("int32")
            clas_id_list = []
            score_list = []
            label_name_list = []
@@ -73,3 +75,11 @@ class Topk(object):
            result["label_names"] = label_name_list
            y.append(result)
        return y
+
+
+class MultiLabelTopk(Topk):
+    def __init__(self, topk=1, class_id_map_file=None):
+        super().__init__()
+
+    def __call__(self, x, file_names=None):
+        return super().__call__(x, file_names, multilabel=True)
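For intuition on why the activation switches here: softmax makes the classes compete for a single label, while sigmoid scores each class independently. A tiny standalone sketch with toy logits (not part of the commit):

```python
import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[2.0, -1.0, 0.5]])

print(F.softmax(logits, axis=-1))  # single-label: scores compete and sum to 1
print(F.sigmoid(logits))           # multilabel: independent per-class probabilities
```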

View File

@@ -355,7 +355,8 @@ class Engine(object):
    def export(self):
        assert self.mode == "export"
-        model = ExportModel(self.config["Arch"], self.model)
+        use_multilabel = self.config["Global"].get("use_multilabel", False)
+        model = ExportModel(self.config["Arch"], self.model, use_multilabel)
        if self.config["Global"]["pretrained_model"] is not None:
            load_dygraph_pretrain(model.base_model,
                                  self.config["Global"]["pretrained_model"])
@@ -388,10 +389,9 @@ class ExportModel(nn.Layer):
    ExportModel: add softmax onto the model
    """

-    def __init__(self, config, model):
+    def __init__(self, config, model, use_multilabel):
        super().__init__()
        self.base_model = model
-
        # we should choose a final model to export
        if isinstance(self.base_model, DistillationModel):
            self.infer_model_name = config["infer_model_name"]
@@ -402,10 +402,13 @@ class ExportModel(nn.Layer):
        if self.infer_output_key == "features" and isinstance(self.base_model,
                                                              RecModel):
            self.base_model.head = IdentityHead()
-        if config.get("infer_add_softmax", True):
-            self.softmax = nn.Softmax(axis=-1)
+        if use_multilabel:
+            self.out_act = nn.Sigmoid()
        else:
-            self.softmax = None
+            if config.get("infer_add_softmax", True):
+                self.out_act = nn.Softmax(axis=-1)
+            else:
+                self.out_act = None

    def eval(self):
        self.training = False
@@ -421,6 +424,6 @@ class ExportModel(nn.Layer):
        x = x[self.infer_model_name]
        if self.infer_output_key is not None:
            x = x[self.infer_output_key]
-        if self.softmax is not None:
-            x = self.softmax(x)
+        if self.out_act is not None:
+            x = self.out_act(x)
        return x

View File

@@ -52,7 +52,8 @@ def classification_eval(evaler, epoch_id=0):
        time_info["reader_cost"].update(time.time() - tic)
        batch_size = batch[0].shape[0]
        batch[0] = paddle.to_tensor(batch[0]).astype("float32")
-        batch[1] = batch[1].reshape([-1, 1]).astype("int64")
+        if not evaler.config["Global"].get("use_multilabel", False):
+            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
        # image input
        out = evaler.model(batch[0])
        # calc loss

View File

@@ -36,8 +36,8 @@ def train_epoch(trainer, epoch_id, print_batch_step):
            paddle.to_tensor(batch[0]['label'])
        ]
        batch_size = batch[0].shape[0]
-        batch[1] = batch[1].reshape([-1, 1]).astype("int64")
+        if not trainer.config["Global"].get("use_multilabel", False):
+            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
        trainer.global_step += 1
        # image input
        if trainer.amp:

View File

@@ -20,6 +20,7 @@ from .distanceloss import DistanceLoss
from .distillationloss import DistillationCELoss
from .distillationloss import DistillationGTCELoss
from .distillationloss import DistillationDMLLoss
+from .multilabelloss import MultiLabelLoss

class CombinedLoss(nn.Layer):

View File

@@ -0,0 +1,43 @@
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class MultiLabelLoss(nn.Layer):
    """
    Multi-label loss
    """

    def __init__(self, epsilon=None):
        super().__init__()
        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
            epsilon = None
        self.epsilon = epsilon

    def _labelsmoothing(self, target, class_num):
        if target.ndim == 1 or target.shape[-1] != class_num:
            one_hot_target = F.one_hot(target, class_num)
        else:
            one_hot_target = target
        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
        return soft_target

    def _binary_crossentropy(self, input, target, class_num):
        if self.epsilon is not None:
            target = self._labelsmoothing(target, class_num)
            cost = F.binary_cross_entropy_with_logits(
                logit=input, label=target)
        else:
            cost = F.binary_cross_entropy_with_logits(
                logit=input, label=target)
        return cost

    def forward(self, x, target):
        if isinstance(x, dict):
            x = x["logits"]
        class_num = x.shape[-1]
        loss = self._binary_crossentropy(x, target, class_num)
        loss = loss.mean()
        return {"MultiLabelLoss": loss}

View File

@@ -19,6 +19,8 @@ from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
from .metrics import DistillationTopkAcc
from .metrics import GoogLeNetTopkAcc
+from .metrics import HammingDistance, AccuracyScore
+

class CombinedMetrics(nn.Layer):
    def __init__(self, config_list):
@@ -32,7 +34,8 @@ class CombinedMetrics(nn.Layer):
            metric_name = list(config)[0]
            metric_params = config[metric_name]
            if metric_params is not None:
-                self.metric_func_list.append(eval(metric_name)(**metric_params))
+                self.metric_func_list.append(
+                    eval(metric_name)(**metric_params))
            else:
                self.metric_func_list.append(eval(metric_name)())
@@ -42,6 +45,7 @@ class CombinedMetrics(nn.Layer):
            metric_dict.update(metric_func(*args, **kwargs))
        return metric_dict

+
def build_metrics(config):
    metrics_list = CombinedMetrics(copy.deepcopy(config))
    return metrics_list

View File

@@ -15,6 +15,12 @@
import numpy as np
import paddle
import paddle.nn as nn
+import paddle.nn.functional as F
+
+from sklearn.metrics import hamming_loss
+from sklearn.metrics import accuracy_score as accuracy_metric
+from sklearn.metrics import multilabel_confusion_matrix
+from sklearn.preprocessing import binarize


class TopkAcc(nn.Layer):
@@ -198,7 +204,7 @@ class Precisionk(nn.Layer):
        equal_flag = paddle.logical_and(equal_flag,
                                        keep_mask.astype('bool'))
        equal_flag = paddle.cast(equal_flag, 'float32')
        Ns = paddle.arange(gallery_img_id.shape[0]) + 1
        equal_flag_cumsum = paddle.cumsum(equal_flag, axis=1)
        Precision_at_k = (paddle.mean(equal_flag_cumsum, axis=0) / Ns).numpy()
@@ -232,3 +238,71 @@ class GoogLeNetTopkAcc(TopkAcc):
    def forward(self, x, label):
        return super().forward(x[0], label)
+
+
+class MutiLabelMetric(object):
+    def __init__(self):
+        pass
+
+    def _multi_hot_encode(self, logits, threshold=0.5):
+        return binarize(logits, threshold=threshold)
+
+    def __call__(self, output):
+        output = F.sigmoid(output)
+        preds = self._multi_hot_encode(logits=output.numpy(), threshold=0.5)
+        return preds
+
+
+class HammingDistance(MutiLabelMetric):
+    """
+    Soft metric based label for multilabel classification
+    Returns:
+        The smaller the return value is, the better model is.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, output, target):
+        preds = super().__call__(output)
+        metric_dict = dict()
+        metric_dict["HammingDistance"] = paddle.to_tensor(
+            hamming_loss(target, preds))
+        return metric_dict
+
+
+class AccuracyScore(MutiLabelMetric):
+    """
+    Hard metric for multilabel classification
+    Args:
+        base: ["sample", "label"], default="sample"
+            if "sample", return metric score based sample,
+            if "label", return metric score based label.
+    Returns:
+        accuracy:
+    """
+
+    def __init__(self, base="label"):
+        super().__init__()
+        assert base in ["sample", "label"
+                        ], 'must be one of ["sample", "label"]'
+        self.base = base
+
+    def __call__(self, output, target):
+        preds = super().__call__(output)
+        metric_dict = dict()
+        if self.base == "sample":
+            accuracy = accuracy_metric(target, preds)
+        elif self.base == "label":
+            mcm = multilabel_confusion_matrix(target, preds)
+            tns = mcm[:, 0, 0]
+            fns = mcm[:, 1, 0]
+            tps = mcm[:, 1, 1]
+            fps = mcm[:, 0, 1]
+            accuracy = (sum(tps) + sum(tns)) / (
+                sum(tps) + sum(tns) + sum(fns) + sum(fps))
+            precision = sum(tps) / (sum(tps) + sum(fps))
+            recall = sum(tps) / (sum(tps) + sum(fns))
+            F1 = 2 * (accuracy * recall) / (accuracy + recall)
+        metric_dict["AccuracyScore"] = paddle.to_tensor(accuracy)
+        return metric_dict
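A hedged, standalone sketch (not part of the commit) of the two new metrics on toy arrays, using the same sklearn helpers the diff imports:

```python
import numpy as np
from sklearn.metrics import hamming_loss, multilabel_confusion_matrix
from sklearn.preprocessing import binarize

target = np.array([[1, 0, 1, 0], [0, 1, 0, 0]])
probs = np.array([[0.9, 0.2, 0.4, 0.1], [0.3, 0.8, 0.1, 0.6]])
preds = binarize(probs, threshold=0.5)             # multi-hot predictions

print(hamming_loss(target, preds))                 # 0.25: fraction of wrong labels

mcm = multilabel_confusion_matrix(target, preds)   # one 2x2 matrix per label
tns, fns, tps, fps = mcm[:, 0, 0], mcm[:, 1, 0], mcm[:, 1, 1], mcm[:, 0, 1]
print((tps.sum() + tns.sum()) / mcm.sum())         # 0.75: label-based accuracy
```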

View File

@@ -4,4 +4,4 @@
# python3.7 tools/train.py -c ./ppcls/configs/ImageNet/ResNet/ResNet50.yaml
# for multi-cards train
python3.7 -m paddle.distributed.launch --gpus="0,1,2,3" tools/train.py -c ./ppcls/configs/ImageNet/ResNet/ResNet50.yaml

train.sh 100755 (new executable file, 7 lines)
View File

@@ -0,0 +1,7 @@
#!/usr/bin/env bash
# for single card train
# python3.7 tools/train.py -c ./ppcls/configs/ImageNet/ResNet/ResNet50.yaml
# for multi-cards train
python3.7 -m paddle.distributed.launch --gpus="0" tools/train.py -c ./MobileNetV2.yaml