adapted dataset and loss

pull/1917/head
zhiboniu 2022-05-12 06:44:54 +00:00
parent aa8f4c16d2
commit 26d5b7d1cc
12 changed files with 42 additions and 170 deletions

View File

@ -1,4 +1,4 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -14,7 +14,6 @@ Global:
image_shape: [3, 256, 192]
save_inference_dir: "./inference"
use_multilabel: True
metric_attr: True
# model architecture
Arch:
@ -26,11 +25,15 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- BCELoss:
- MultiLabelLoss:
weight: 1.0
weight_ratio: True
size_sum: True
Eval:
- BCELoss:
- MultiLabelLoss:
weight: 1.0
weight_ratio: True
size_sum: True
Optimizer:
name: Adam
@ -47,10 +50,10 @@ Optimizer:
DataLoader:
Train:
dataset:
name: AttrDataset
name: MultiLabelDataset
image_root: "dataset/xingrenfenxi/data/"
cls_label_path: "dataset/xingrenfenxi/all_qiye.pkl"
split: 'trainval'
cls_label_path: "dataset/xingrenfenxi/trainval.txt"
label_ratio: True
transform_ops:
- DecodeImage:
to_rgb: True
@ -80,10 +83,10 @@ DataLoader:
use_shared_memory: True
Eval:
dataset:
name: AttrDataset
name: MultiLabelDataset
image_root: "dataset/xingrenfenxi/data/"
cls_label_path: "dataset/xingrenfenxi/all_qiye.pkl"
split: 'test'
cls_label_path: "dataset/xingrenfenxi/test.txt"
label_ratio: True
transform_ops:
- DecodeImage:
to_rgb: True

View File

@ -30,7 +30,6 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.attr_dataset import AttrDataset
# sampler

View File

@ -10,4 +10,3 @@ from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
from ppcls.data.dataloader.pk_sampler import PKSampler
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.attr_dataset import AttrDataset

View File

@ -1,82 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import os
import pickle
from .common_dataset import CommonDataset
from ppcls.data.preprocess import transform
class AttrDataset(CommonDataset):
def _load_anno(self, seed=None, split='trainval'):
assert os.path.exists(self._cls_path)
assert os.path.exists(self._img_root)
anno_path = self._cls_path
image_dir = self._img_root
self.images = []
self.labels = []
dataset_info = pickle.load(open(anno_path, 'rb+'))
img_id = dataset_info.image_name
attr_label = dataset_info.label
attr_label[attr_label == 2] = 0
attr_id = dataset_info.attr_name
if 'label_idx' in dataset_info.keys():
eval_attr_idx = dataset_info.label_idx.eval
attr_label = attr_label[:, eval_attr_idx]
attr_id = [attr_id[i] for i in eval_attr_idx]
attr_num = len(attr_id)
# mapping category name to class id
# first_class:0, second_class:1, ...
cname2cid = {attr_id[i]: i for i in range(attr_num)}
assert split in dataset_info.partition.keys(
), f'split {split} is not exist'
img_idx = dataset_info.partition[split]
if isinstance(img_idx, list):
img_idx = img_idx[0] # default partition 0
img_num = img_idx.shape[0]
img_id = [img_id[i] for i in img_idx]
label = attr_label[img_idx] # [:, [0, 12]]
self.label_ratio = label.mean(0)
print("label_ratio:", self.label_ratio)
for i, (img_i, label_i) in enumerate(zip(img_id, label)):
imgname = os.path.join(image_dir, img_i)
self.images.append(imgname)
self.labels.append(np.int64(label_i))
def __getitem__(self, idx):
try:
with open(self.images[idx], 'rb') as f:
img = f.read()
if self._transform_ops:
img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1))
return (img, [self.labels[idx], self.label_ratio])
except Exception as ex:
logger.error("Exception occured when parse line: {} with msg: {}".
format(self.images[idx], ex))
rnd_idx = np.random.randint(self.__len__())
return self.__getitem__(rnd_idx)

View File

@ -48,7 +48,7 @@ class CommonDataset(Dataset):
image_root,
cls_label_path,
transform_ops=None,
split='trainval'):
label_ratio=False):
self._img_root = image_root
self._cls_path = cls_label_path
if transform_ops:
@ -56,7 +56,10 @@ class CommonDataset(Dataset):
self.images = []
self.labels = []
self._load_anno(split=split)
if label_ratio:
self.label_ratio = self._load_anno(label_ratio=label_ratio)
else:
self._load_anno()
def _load_anno(self):
pass

View File

@ -25,7 +25,7 @@ from .common_dataset import CommonDataset
class MultiLabelDataset(CommonDataset):
def _load_anno(self):
def _load_anno(self, label_ratio=False):
assert os.path.exists(self._cls_path)
assert os.path.exists(self._img_root)
self.images = []
@ -41,6 +41,8 @@ class MultiLabelDataset(CommonDataset):
self.labels.append(labels)
assert os.path.exists(self.images[-1])
if label_ratio:
return np.array(self.labels).mean(0)
def __getitem__(self, idx):
try:
@ -50,7 +52,10 @@ class MultiLabelDataset(CommonDataset):
img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1))
label = np.array(self.labels[idx]).astype("float32")
return (img, label)
if self.label_ratio is not None:
return (img, [label, self.label_ratio])
else:
return (img, label)
except Exception as ex:
logger.error("Exception occured when parse line: {} with msg: {}".

View File

@ -32,8 +32,8 @@ def classification_eval(engine, epoch_id=0):
}
print_batch_step = engine.config["Global"]["print_batch_step"]
if engine.eval_metric_func is not None and engine.config["Global"][
"metric_attr"]:
if engine.eval_metric_func is not None and engine.config["Arch"][
"name"] == "StrongBaselineAttr":
output_info["attr"] = AttrMeter(threshold=0.5)
metric_key = None
@ -128,7 +128,7 @@ def classification_eval(engine, epoch_id=0):
# calc metric
if engine.eval_metric_func is not None:
if engine.config["Global"]["metric_attr"]:
if engine.config["Arch"]["name"] == "StrongBaselineAttr":
metric_dict = engine.eval_metric_func(preds, labels)
metric_key = "attr"
output_info["attr"].update(metric_dict)
@ -153,7 +153,7 @@ def classification_eval(engine, epoch_id=0):
ips_msg = "ips: {:.5f} images/sec".format(
batch_size / time_info["batch_cost"].avg)
if engine.config["Global"]["metric_attr"]:
if engine.config["Arch"]["name"] == "StrongBaselineAttr":
metric_msg = ""
else:
metric_msg = ", ".join([
@ -168,7 +168,7 @@ def classification_eval(engine, epoch_id=0):
if engine.use_dali:
engine.eval_dataloader.reset()
if engine.config["Global"]["metric_attr"]:
if engine.config["Arch"]["name"] == "StrongBaselineAttr":
metric_msg = ", ".join([
"evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
format(*output_info["attr"].res())

View File

@ -26,7 +26,6 @@ from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss
from .bceloss import BCELoss
from .deephashloss import DSHSDLoss
from .deephashloss import LCDSHLoss

View File

@ -1,59 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def ratio2weight(targets, ratio):
# print(targets)
pos_weights = targets * (1. - ratio)
neg_weights = (1. - targets) * ratio
weights = paddle.exp(neg_weights + pos_weights)
# for RAP dataloader, targets element may be 2, with or without smooth, some element must great than 1
weights = weights - weights * (targets > 1)
return weights
class BCELoss(nn.Layer):
"""BCE Loss.
Args:
"""
def __init__(self,
sample_weight=True,
size_sum=True,
smoothing=None,
weight=1.0):
super(BCELoss, self).__init__()
self.sample_weight = sample_weight
self.size_sum = size_sum
self.hyper = 0.8
self.smoothing = smoothing
def forward(self, logits, labels):
targets, ratio = labels
if self.smoothing is not None:
targets = (1 - self.smoothing) * targets + self.smoothing * (
1 - targets)
targets = paddle.cast(targets, 'float32')
loss_m = F.binary_cross_entropy_with_logits(
logits, targets, reduction='none')
targets_mask = paddle.cast(targets > 0.5, 'float32')
if self.sample_weight:
weight = ratio2weight(targets_mask, ratio[0])
weight = weight * (targets > -1)
loss_m = loss_m * weight
loss = loss_m.sum(1).mean() if self.size_sum else loss_m.sum()
return {"BCELoss": loss}

View File

@ -19,12 +19,13 @@ class MultiLabelLoss(nn.Layer):
Multi-label loss
"""
def __init__(self, epsilon=None, weight_ratio=None):
def __init__(self, epsilon=None, size_sum=False, weight_ratio=False):
super().__init__()
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None
self.epsilon = epsilon
self.weight_ratio = weight_ratio
self.size_sum = size_sum
def _labelsmoothing(self, target, class_num):
if target.ndim == 1 or target.shape[-1] != class_num:
@ -36,18 +37,21 @@ class MultiLabelLoss(nn.Layer):
return soft_target
def _binary_crossentropy(self, input, target, class_num):
if self.weight_ratio:
target, label_ratio = target
if self.epsilon is not None:
target = self._labelsmoothing(target, class_num)
cost = F.binary_cross_entropy_with_logits(logit=input, label=target)
cost = F.binary_cross_entropy_with_logits(
logit=input, label=target, reduction='none')
if self.weight_ratio is not None:
if self.weight_ratio:
targets_mask = paddle.cast(target > 0.5, 'float32')
weight = ratio2weight(targets_mask,
paddle.to_tensor(self.weight_ratio))
weight = ratio2weight(targets_mask, paddle.to_tensor(label_ratio))
weight = weight * (target > -1)
cost = cost * weight
import pdb
pdb.set_trace()
if self.size_sum:
cost = cost.sum(1).mean() if self.size_sum else cost.mean()
return cost

View File

@ -9,3 +9,4 @@ scipy
scikit-learn==0.23.2
gast==0.3.3
faiss-cpu==1.7.1.post2
easydict