commit 4091592cfb
@@ -20,9 +20,10 @@ import numpy as np
 import paddle
 from paddle import ParamAttr
 import paddle.nn as nn
-from paddle.nn import Conv2D, BatchNorm, Linear
+from paddle.nn import Conv2D, BatchNorm, Linear, BatchNorm2D
 from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
 from paddle.nn.initializer import Uniform
 from paddle.regularizer import L2Decay
 import math

 from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
@@ -132,11 +133,12 @@ class ConvBNLayer(TheseusLayer):
             weight_attr=ParamAttr(learning_rate=lr_mult),
             bias_attr=False,
             data_format=data_format)
-        self.bn = BatchNorm(
-            num_filters,
-            param_attr=ParamAttr(learning_rate=lr_mult),
-            bias_attr=ParamAttr(learning_rate=lr_mult),
-            data_layout=data_format)
+
+        weight_attr = ParamAttr(learning_rate=lr_mult, trainable=True)
+        bias_attr = ParamAttr(learning_rate=lr_mult, trainable=True)
+
+        self.bn = BatchNorm2D(
+            num_filters, weight_attr=weight_attr, bias_attr=bias_attr)
         self.relu = nn.ReLU()

     def forward(self, x):
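The hunk above swaps the legacy `BatchNorm` layer (v1-style `param_attr`/`data_layout` arguments) for `paddle.nn.BatchNorm2D`, which takes `weight_attr`/`bias_attr` directly, so the per-layer learning-rate multiplier survives the migration. A minimal sketch of the new construction; the `lr_mult` and channel values are illustrative, not taken from the diff:

```python
import paddle
import paddle.nn as nn
from paddle import ParamAttr

# Illustrative values; in the diff these come from ConvBNLayer's arguments.
lr_mult, num_filters = 0.5, 64

bn = nn.BatchNorm2D(
    num_filters,
    weight_attr=ParamAttr(learning_rate=lr_mult, trainable=True),
    bias_attr=ParamAttr(learning_rate=lr_mult, trainable=True))

x = paddle.randn([1, num_filters, 8, 8])
print(bn(x).shape)  # [1, 64, 8, 8]
```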
@@ -192,6 +194,7 @@ class BottleneckBlock(TheseusLayer):
                 is_vd_mode=False if if_first else True,
+                lr_mult=lr_mult,
                 data_format=data_format)

         self.relu = nn.ReLU()
         self.shortcut = shortcut
@@ -312,7 +315,7 @@ class ResNet(TheseusLayer):
             [[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
         }

-        self.stem = nn.Sequential(*[
+        self.stem = nn.Sequential(* [
             ConvBNLayer(
                 num_channels=in_c,
                 num_filters=out_c,
@@ -0,0 +1,114 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: "./output/"
+  device: "gpu"
+  save_interval: 5
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 30
+  print_batch_step: 20
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 256, 192]
+  save_inference_dir: "./inference"
+  use_multilabel: True
+
+# model architecture
+Arch:
+  name: "ResNet50"
+  pretrained: True
+  class_num: 26
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - MultiLabelLoss:
+        weight: 1.0
+        weight_ratio: True
+        size_sum: True
+  Eval:
+    - MultiLabelLoss:
+        weight: 1.0
+        weight_ratio: True
+        size_sum: True
+
+Optimizer:
+  name: Adam
+  lr:
+    name: Piecewise
+    decay_epochs: [12, 18, 24, 28]
+    values: [0.0001, 0.00001, 0.000001, 0.0000001]
+  regularizer:
+    name: 'L2'
+    coeff: 0.0005
+  clip_norm: 10
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: MultiLabelDataset
+      image_root: "dataset/attribute/data/"
+      cls_label_path: "dataset/attribute/trainval.txt"
+      label_ratio: True
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            size: [192, 256]
+        - Padv2:
+            size: [212, 276]
+            pad_mode: 1
+            fill_value: 0
+        - RandomCropImage:
+            size: [192, 256]
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: True
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+  Eval:
+    dataset:
+      name: MultiLabelDataset
+      image_root: "dataset/attribute/data/"
+      cls_label_path: "dataset/attribute/test.txt"
+      label_ratio: True
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            size: [192, 256]
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+
+Metric:
+  Eval:
+    - ATTRMetric:
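For reference, one plausible reading of the Piecewise schedule in this config: `values[i]` applies until epoch `decay_epochs[i]` is reached, and the last value holds afterwards. A standalone sketch in plain Python, not the ppcls implementation, whose exact boundary handling may differ by one epoch:

```python
# Hedged sketch of the Piecewise schedule from the config above.
decay_epochs = [12, 18, 24, 28]
values = [0.0001, 0.00001, 0.000001, 0.0000001]

def lr_at(epoch):
    for boundary, value in zip(decay_epochs, values):
        if epoch < boundary:
            return value
    return values[-1]  # after the last boundary, keep the smallest lr

for epoch in (0, 12, 18, 29):
    print(epoch, lr_at(epoch))  # 0.0001, 1e-05, 1e-06, 1e-07
```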
@@ -44,11 +44,11 @@ def create_operators(params):


 class CommonDataset(Dataset):
-    def __init__(
-            self,
-            image_root,
-            cls_label_path,
-            transform_ops=None, ):
+    def __init__(self,
+                 image_root,
+                 cls_label_path,
+                 transform_ops=None,
+                 label_ratio=False):
         self._img_root = image_root
         self._cls_path = cls_label_path
         if transform_ops:
@@ -56,7 +56,10 @@ class CommonDataset(Dataset):

         self.images = []
         self.labels = []
-        self._load_anno()
+        if label_ratio:
+            self.label_ratio = self._load_anno(label_ratio=label_ratio)
+        else:
+            self._load_anno()

     def _load_anno(self):
         pass
@@ -25,7 +25,7 @@ from .common_dataset import CommonDataset


 class MultiLabelDataset(CommonDataset):
-    def _load_anno(self):
+    def _load_anno(self, label_ratio=False):
         assert os.path.exists(self._cls_path)
         assert os.path.exists(self._img_root)
         self.images = []
@@ -41,6 +41,8 @@ class MultiLabelDataset(CommonDataset):

                 self.labels.append(labels)
                 assert os.path.exists(self.images[-1])
+        if label_ratio:
+            return np.array(self.labels).mean(0).astype("float32")

     def __getitem__(self, idx):
         try:
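The ratio returned here is simply the per-attribute frequency of positive labels over the whole annotation file. A toy check with made-up labels:

```python
import numpy as np

# Per-attribute positive frequency: column-wise mean of a 0/1 label matrix.
labels = np.array([[1, 0, 1],
                   [1, 1, 0],
                   [0, 0, 0],
                   [1, 0, 0]])
print(labels.mean(0).astype("float32"))  # [0.75 0.25 0.25]
```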
|
@@ -50,7 +52,10 @@ class MultiLabelDataset(CommonDataset):
             img = transform(img, self._transform_ops)
             img = img.transpose((2, 0, 1))
             label = np.array(self.labels[idx]).astype("float32")
-            return (img, label)
+            if self.label_ratio is not None:
+                return (img, np.array([label, self.label_ratio]))
+            else:
+                return (img, label)

         except Exception as ex:
             logger.error("Exception occured when parse line: {} with msg: {}".
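With `label_ratio` enabled, each sample now carries the dataset-level ratio next to its own label: stacking the two gives shape `[2, num_attrs]`, and after batching `[batch, 2, num_attrs]`, which `MultiLabelLoss` later splits back apart. Toy shapes only:

```python
import numpy as np

# Label and ratio are stacked along a new leading axis.
label = np.array([1., 0., 1.], dtype="float32")
label_ratio = np.array([0.75, 0.25, 0.25], dtype="float32")
target = np.array([label, label_ratio])
print(target.shape)  # (2, 3); batched -> (batch, 2, 3)
```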
@@ -33,6 +33,8 @@ from ppcls.data.preprocess.ops.operators import AugMix
 from ppcls.data.preprocess.ops.operators import Pad
 from ppcls.data.preprocess.ops.operators import ToTensor
 from ppcls.data.preprocess.ops.operators import Normalize
+from ppcls.data.preprocess.ops.operators import RandomCropImage
+from ppcls.data.preprocess.ops.operators import Padv2

 from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
@@ -40,6 +42,7 @@ import numpy as np
 from PIL import Image
+import random


 def transform(data, ops=[]):
     """ transform """
     for op in ops:
@@ -190,6 +190,105 @@ class CropImage(object):
         return img[h_start:h_end, w_start:w_end, :]


+class Padv2(object):
+    def __init__(self,
+                 size=None,
+                 size_divisor=32,
+                 pad_mode=0,
+                 offsets=None,
+                 fill_value=(127.5, 127.5, 127.5)):
+        """
+        Pad image to a specified size or to a multiple of size_divisor.
+        Args:
+            size (int, list): image target size; if None, pad to a multiple of size_divisor, default None
+            size_divisor (int): size divisor, default 32
+            pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2]. if -1, use specified offsets;
+                if 0, only pad to right and bottom; if 1, pad according to center; if 2, only pad left and top
+            offsets (list): [offset_x, offset_y], specifies the offset while padding, only supported when pad_mode=-1
+            fill_value (tuple): rgb value of the padded area, default (127.5, 127.5, 127.5)
+        """
+
+        if not isinstance(size, (int, list)):
+            raise TypeError(
+                "Type of size is invalid. Must be int or List, now is {}".
+                format(type(size)))
+
+        if isinstance(size, int):
+            size = [size, size]
+
+        assert pad_mode in [
+            -1, 0, 1, 2
+        ], 'currently only supports four modes [-1, 0, 1, 2]'
+        if pad_mode == -1:
+            assert offsets, 'if pad_mode is -1, offsets should not be None'
+
+        self.size = size
+        self.size_divisor = size_divisor
+        self.pad_mode = pad_mode
+        self.fill_value = fill_value
+        self.offsets = offsets
+
+    def apply_image(self, image, offsets, im_size, size):
+        x, y = offsets
+        im_h, im_w = im_size
+        h, w = size
+        canvas = np.ones((h, w, 3), dtype=np.float32)
+        canvas *= np.array(self.fill_value, dtype=np.float32)
+        canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
+        return canvas
+
+    def __call__(self, img):
+        im_h, im_w = img.shape[:2]
+        if self.size:
+            w, h = self.size
+            assert (
+                im_h <= h and im_w <= w
+            ), '(h, w) of target size should be greater than (im_h, im_w)'
+        else:
+            h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
+            w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
+
+        if h == im_h and w == im_w:
+            return img.astype(np.float32)
+
+        if self.pad_mode == -1:
+            offset_x, offset_y = self.offsets
+        elif self.pad_mode == 0:
+            offset_y, offset_x = 0, 0
+        elif self.pad_mode == 1:
+            offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2
+        else:
+            offset_y, offset_x = h - im_h, w - im_w
+
+        offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w]
+
+        return self.apply_image(img, offsets, im_size, size)
+
+
+class RandomCropImage(object):
+    """Random crop image only
+    """
+
+    def __init__(self, size):
+        super(RandomCropImage, self).__init__()
+        if isinstance(size, int):
+            size = [size, size]
+        self.size = size
+
+    def __call__(self, img):
+
+        h, w = img.shape[:2]
+        tw, th = self.size
+        i = random.randint(0, h - th)
+        j = random.randint(0, w - tw)
+
+        img = img[i:i + th, j:j + tw, :]
+        if img.shape[0] != th or img.shape[1] != tw:
+            raise ValueError('sample: ', h, w, i, j, th, tw, img.shape)
+
+        return img
+
+
 class RandCropImage(object):
     """ random crop image """
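A usage sketch of the two new ops, mirroring the train-time pipeline in the config above (resize to 192x256, center-pad to 212x276, random-crop back to 192x256); note both ops take `size` as `[w, h]`:

```python
import numpy as np
from ppcls.data.preprocess.ops.operators import Padv2, RandomCropImage

# A 256x192 (h, w) image, as produced by ResizeImage with size: [192, 256].
img = np.random.randint(0, 256, (256, 192, 3)).astype(np.uint8)

pad = Padv2(size=[212, 276], pad_mode=1, fill_value=0)  # pad around center
crop = RandomCropImage(size=[192, 256])                 # crop back down

padded = pad(img)
print(padded.shape)   # (276, 212, 3), float32 canvas
cropped = crop(padded)
print(cropped.shape)  # (256, 192, 3)
```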
@@ -463,8 +562,8 @@ class Pad(object):
         # Process fill color for affine transforms
         major_found, minor_found = (int(v)
                                     for v in PILLOW_VERSION.split('.')[:2])
-        major_required, minor_required = (
-            int(v) for v in min_pil_version.split('.')[:2])
+        major_required, minor_required = (int(v) for v in
+                                          min_pil_version.split('.')[:2])
         if major_found < major_required or (major_found == major_required and
                                             minor_found < minor_required):
             if fill is None:
@@ -82,6 +82,7 @@ def classification_eval(engine, epoch_id=0):
         # gather Tensor when distributed
         if paddle.distributed.get_world_size() > 1:
             label_list = []
+
             paddle.distributed.all_gather(label_list, batch[1])
             labels = paddle.concat(label_list, 0)
@@ -123,6 +124,7 @@ def classification_eval(engine, epoch_id=0):
                     output_info[key] = AverageMeter(key, '7.5f')
                 output_info[key].update(loss_dict[key].numpy()[0],
                                         current_samples)
+
         # calc metric
         if engine.eval_metric_func is not None:
             engine.eval_metric_func(preds, labels)
@@ -137,11 +139,14 @@ def classification_eval(engine, epoch_id=0):
             ips_msg = "ips: {:.5f} images/sec".format(
                 batch_size / time_info["batch_cost"].avg)

-            metric_msg = ", ".join([
-                "{}: {:.5f}".format(key, output_info[key].val)
-                for key in output_info
-            ])
-            metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
+            if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
+                metric_msg = ""
+            else:
+                metric_msg = ", ".join([
+                    "{}: {:.5f}".format(key, output_info[key].val)
+                    for key in output_info
+                ])
+                metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
             logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                 epoch_id, iter_id,
                 len(engine.eval_dataloader), metric_msg, time_msg, ips_msg))
@@ -149,14 +154,29 @@ def classification_eval(engine, epoch_id=0):
         tic = time.time()
     if engine.use_dali:
         engine.eval_dataloader.reset()
-    metric_msg = ", ".join([
-        "{}: {:.5f}".format(key, output_info[key].avg) for key in output_info
-    ])
-    metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
-    logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
-
-    # do not try to save best eval.model
-    if engine.eval_metric_func is None:
-        return -1
-    # return 1st metric in the dict
-    return engine.eval_metric_func.avg
+
+    if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
+        metric_msg = ", ".join([
+            "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
+            format(*engine.eval_metric_func.attr_res())
+        ])
+        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
+
+        # do not try to save best eval.model
+        if engine.eval_metric_func is None:
+            return -1
+        # return 1st metric in the dict
+        return engine.eval_metric_func.attr_res()[0]
+    else:
+        metric_msg = ", ".join([
+            "{}: {:.5f}".format(key, output_info[key].avg)
+            for key in output_info
+        ])
+        metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
+        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
+
+        # do not try to save best eval.model
+        if engine.eval_metric_func is None:
+            return -1
+        # return 1st metric in the dict
+        return engine.eval_metric_func.avg
@@ -3,16 +3,29 @@ import paddle.nn as nn
 import paddle.nn.functional as F


+def ratio2weight(targets, ratio):
+    pos_weights = targets * (1. - ratio)
+    neg_weights = (1. - targets) * ratio
+    weights = paddle.exp(neg_weights + pos_weights)
+
+    # for the RAP dataloader, target elements may be 2 (with or without label
+    # smoothing), so some elements can be greater than 1; zero their weights
+    weights = weights - weights * (targets > 1)
+
+    return weights
+
+
 class MultiLabelLoss(nn.Layer):
     """
     Multi-label loss
     """

-    def __init__(self, epsilon=None):
+    def __init__(self, epsilon=None, size_sum=False, weight_ratio=False):
         super().__init__()
         if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
             epsilon = None
         self.epsilon = epsilon
+        self.weight_ratio = weight_ratio
+        self.size_sum = size_sum

     def _labelsmoothing(self, target, class_num):
         if target.ndim == 1 or target.shape[-1] != class_num:
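`ratio2weight` upweights samples of the rarer state of each attribute: positives of a rare attribute and negatives of a very common one both get weights near e^1, while the frequent state stays near e^0. A toy check of the arithmetic, using made-up ratios:

```python
import paddle

# Attribute 0 is positive in 10% of the data, attribute 1 in 90%.
targets = paddle.to_tensor([[1., 0.]])  # rare positive, rare negative
ratio = paddle.to_tensor([0.1, 0.9])

pos_weights = targets * (1. - ratio)    # [[0.9, 0.0]]
neg_weights = (1. - targets) * ratio    # [[0.0, 0.9]]
weights = paddle.exp(neg_weights + pos_weights)
print(weights.numpy())  # [[2.4596 2.4596]]: both rare states upweighted
```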
|
@@ -24,13 +37,21 @@ class MultiLabelLoss(nn.Layer):
         return soft_target

     def _binary_crossentropy(self, input, target, class_num):
+        if self.weight_ratio:
+            target, label_ratio = target[:, 0, :], target[:, 1, :]
         if self.epsilon is not None:
             target = self._labelsmoothing(target, class_num)
-            cost = F.binary_cross_entropy_with_logits(
-                logit=input, label=target)
-        else:
-            cost = F.binary_cross_entropy_with_logits(
-                logit=input, label=target)
+        cost = F.binary_cross_entropy_with_logits(
+            logit=input, label=target, reduction='none')
+
+        if self.weight_ratio:
+            targets_mask = paddle.cast(target > 0.5, 'float32')
+            weight = ratio2weight(targets_mask, paddle.to_tensor(label_ratio))
+            weight = weight * (target > -1)
+            cost = cost * weight
+
+        if self.size_sum:
+            cost = cost.sum(1).mean() if self.size_sum else cost.mean()
+
+        return cost
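End to end, `weight_ratio=True` changes the expected target layout to `[batch, 2, num_attrs]` (row 0 the labels, row 1 the dataset ratio), matching what `MultiLabelDataset.__getitem__` now returns. A hedged sketch that calls the private helper directly for illustration; the import path and the class_num of 26 (from the config above) are assumptions:

```python
import paddle
from ppcls.loss import MultiLabelLoss  # path assumed

loss_fn = MultiLabelLoss(size_sum=True, weight_ratio=True)

logits = paddle.randn([4, 26])
labels = paddle.randint(0, 2, [4, 26]).astype("float32")
ratio = paddle.full([4, 26], 0.3)               # per-attribute positive ratio
target = paddle.stack([labels, ratio], axis=1)  # [4, 2, 26]

# Calling the helper directly, for illustration only.
print(loss_fn._binary_crossentropy(logits, target, 26))  # scalar loss
```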
@@ -20,6 +20,7 @@ from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
 from .metrics import DistillationTopkAcc
 from .metrics import GoogLeNetTopkAcc
 from .metrics import HammingDistance, AccuracyScore
+from .metrics import ATTRMetric
 from .metrics import TprAtFpr
@@ -55,12 +56,15 @@ class CombinedMetrics(AvgMetrics):
     def avg(self):
         return self.metric_func_list[0].avg

+    def attr_res(self):
+        return self.metric_func_list[0].attrmeter.res()
+
     def reset(self):
         for metric in self.metric_func_list:
             if hasattr(metric, "reset"):
                 metric.reset()


 def build_metrics(config):
     metrics_list = CombinedMetrics(copy.deepcopy(config))
     return metrics_list
@@ -22,8 +22,10 @@ from sklearn.metrics import accuracy_score as accuracy_metric
 from sklearn.metrics import multilabel_confusion_matrix
 from sklearn.preprocessing import binarize

+from easydict import EasyDict
+
 from ppcls.metric.avg_metrics import AvgMetrics
-from ppcls.utils.misc import AverageMeter
+from ppcls.utils.misc import AverageMeter, AttrMeter


 class TopkAcc(AvgMetrics):
@@ -36,7 +38,10 @@ class TopkAcc(AvgMetrics):
         self.reset()

     def reset(self):
-        self.avg_meters = {"top{}".format(k): AverageMeter("top{}".format(k)) for k in self.topk}
+        self.avg_meters = {
+            "top{}".format(k): AverageMeter("top{}".format(k))
+            for k in self.topk
+        }

     def forward(self, x, label):
         if isinstance(x, dict):
@@ -46,7 +51,8 @@ class TopkAcc(AvgMetrics):
         for k in self.topk:
             metric_dict["top{}".format(k)] = paddle.metric.accuracy(
                 x, label, k=k)
-            self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)].numpy()[0], x.shape[0])
+            self.avg_meters["top{}".format(k)].update(
+                metric_dict["top{}".format(k)].numpy()[0], x.shape[0])
         return metric_dict
@@ -116,7 +122,7 @@ class mINP(nn.Layer):
                                               choosen_indices)
         equal_flag = paddle.equal(choosen_label, query_img_id)
         if keep_mask is not None:
             keep_mask = paddle.index_sample(
                 keep_mask.astype('float32'), choosen_indices)
             equal_flag = paddle.logical_and(equal_flag,
                                             keep_mask.astype('bool'))
@@ -140,7 +146,7 @@ class mINP(nn.Layer):


 class TprAtFpr(nn.Layer):
-    def __init__(self, max_fpr=1/1000.):
+    def __init__(self, max_fpr=1 / 1000.):
         super().__init__()
         self.gt_pos_score_list = []
         self.gt_neg_score_list = []
@@ -178,14 +184,18 @@ class TprAtFpr(nn.Layer):
             threshold = i / 10000.
             if len(gt_pos_score_list) == 0:
                 continue
-            tpr = np.sum(gt_pos_score_list > threshold) / len(gt_pos_score_list)
+            tpr = np.sum(
+                gt_pos_score_list > threshold) / len(gt_pos_score_list)
             if len(gt_neg_score_list) == 0 and tpr > max_tpr:
                 max_tpr = tpr
-                result = "threshold: {}, fpr: {}, tpr: {:.5f}".format(threshold, fpr, tpr)
-            fpr = np.sum(gt_neg_score_list > threshold) / len(gt_neg_score_list)
+                result = "threshold: {}, fpr: {}, tpr: {:.5f}".format(
+                    threshold, fpr, tpr)
+            fpr = np.sum(
+                gt_neg_score_list > threshold) / len(gt_neg_score_list)
             if fpr <= self.max_fpr and tpr > max_tpr:
                 max_tpr = tpr
-                result = "threshold: {}, fpr: {}, tpr: {:.5f}".format(threshold, fpr, tpr)
+                result = "threshold: {}, fpr: {}, tpr: {:.5f}".format(
+                    threshold, fpr, tpr)
             self.max_tpr = max_tpr
         return result
@@ -333,7 +343,8 @@ class HammingDistance(MultiLabelMetric):
         metric_dict = dict()
         metric_dict["HammingDistance"] = paddle.to_tensor(
             hamming_loss(target, preds))
-        self.avg_meters["HammingDistance"].update(metric_dict["HammingDistance"].numpy()[0], output.shape[0])
+        self.avg_meters["HammingDistance"].update(
+            metric_dict["HammingDistance"].numpy()[0], output.shape[0])
         return metric_dict
@@ -372,5 +383,66 @@ class AccuracyScore(MultiLabelMetric):
         accuracy = (sum(tps) + sum(tns)) / (
             sum(tps) + sum(tns) + sum(fns) + sum(fps))
         metric_dict["AccuracyScore"] = paddle.to_tensor(accuracy)
-        self.avg_meters["AccuracyScore"].update(metric_dict["AccuracyScore"].numpy()[0], output.shape[0])
+        self.avg_meters["AccuracyScore"].update(
+            metric_dict["AccuracyScore"].numpy()[0], output.shape[0])
         return metric_dict
+
+
+def get_attr_metrics(gt_label, preds_probs, threshold):
+    """
+    index: evaluated label index
+    """
+    pred_label = (preds_probs > threshold).astype(int)
+
+    eps = 1e-20
+    result = EasyDict()
+
+    has_fuyi = gt_label == -1
+    pred_label[has_fuyi] = -1
+
+    ###############################
+    # label metrics
+    # TP + FN
+    result.gt_pos = np.sum((gt_label == 1), axis=0).astype(float)
+    # TN + FP
+    result.gt_neg = np.sum((gt_label == 0), axis=0).astype(float)
+    # TP
+    result.true_pos = np.sum((gt_label == 1) * (pred_label == 1),
+                             axis=0).astype(float)
+    # TN
+    result.true_neg = np.sum((gt_label == 0) * (pred_label == 0),
+                             axis=0).astype(float)
+    # FP
+    result.false_pos = np.sum(((gt_label == 0) * (pred_label == 1)),
+                              axis=0).astype(float)
+    # FN
+    result.false_neg = np.sum(((gt_label == 1) * (pred_label == 0)),
+                              axis=0).astype(float)
+
+    ################
+    # instance metrics
+    result.gt_pos_ins = np.sum((gt_label == 1), axis=1).astype(float)
+    result.true_pos_ins = np.sum((pred_label == 1), axis=1).astype(float)
+    # true positive
+    result.intersect_pos = np.sum((gt_label == 1) * (pred_label == 1),
+                                  axis=1).astype(float)
+    # IOU
+    result.union_pos = np.sum(((gt_label == 1) + (pred_label == 1)),
+                              axis=1).astype(float)
+
+    return result
+
+
+class ATTRMetric(nn.Layer):
+    def __init__(self, threshold=0.5):
+        super().__init__()
+        self.threshold = threshold
+
+    def reset(self):
+        self.attrmeter = AttrMeter(threshold=0.5)
+
+    def forward(self, output, target):
+        metric_dict = get_attr_metrics(target[:, 0, :].numpy(),
+                                       output.numpy(), self.threshold)
+        self.attrmeter.update(metric_dict)
+        return metric_dict
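A toy run of `get_attr_metrics` with hypothetical numbers (two samples, three attributes, threshold 0.5): positions labeled -1 are copied into the prediction, so they drop out of every count. The import path is assumed from the hunk headers above:

```python
import numpy as np
from ppcls.metric.metrics import get_attr_metrics  # path assumed

gt = np.array([[1, 0, -1],
               [1, 1, 0]])
probs = np.array([[0.9, 0.2, 0.7],
                  [0.4, 0.8, 0.1]])

res = get_attr_metrics(gt, probs, 0.5)
print(res.true_pos)    # per-attribute TP counts: [1. 1. 0.]
print(res.gt_pos_ins)  # per-sample positive label counts: [1. 2.]
```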
@@ -65,3 +65,87 @@ class AverageMeter(object):
     def value(self):
         return '{self.name}: {self.val:{self.fmt}}{self.postfix}'.format(
             self=self)
+
+
+class AttrMeter(object):
+    """
+    Computes and stores the average and current value
+    Code was based on https://github.com/pytorch/examples/blob/master/imagenet/main.py
+    """
+
+    def __init__(self, threshold=0.5):
+        self.threshold = threshold
+        self.reset()
+
+    def reset(self):
+        self.gt_pos = 0
+        self.gt_neg = 0
+        self.true_pos = 0
+        self.true_neg = 0
+        self.false_pos = 0
+        self.false_neg = 0
+
+        self.gt_pos_ins = []
+        self.true_pos_ins = []
+        self.intersect_pos = []
+        self.union_pos = []
+
+    def update(self, metric_dict):
+        self.gt_pos += metric_dict['gt_pos']
+        self.gt_neg += metric_dict['gt_neg']
+        self.true_pos += metric_dict['true_pos']
+        self.true_neg += metric_dict['true_neg']
+        self.false_pos += metric_dict['false_pos']
+        self.false_neg += metric_dict['false_neg']
+
+        self.gt_pos_ins += metric_dict['gt_pos_ins'].tolist()
+        self.true_pos_ins += metric_dict['true_pos_ins'].tolist()
+        self.intersect_pos += metric_dict['intersect_pos'].tolist()
+        self.union_pos += metric_dict['union_pos'].tolist()
+
+    def res(self):
+        import numpy as np
+        eps = 1e-20
+        label_pos_recall = 1.0 * self.true_pos / (
+            self.gt_pos + eps)  # true positive
+        label_neg_recall = 1.0 * self.true_neg / (
+            self.gt_neg + eps)  # true negative
+        # mean accuracy
+        label_ma = (label_pos_recall + label_neg_recall) / 2
+
+        label_pos_recall = np.mean(label_pos_recall)
+        label_neg_recall = np.mean(label_neg_recall)
+        label_prec = (self.true_pos / (self.true_pos + self.false_pos + eps))
+        label_acc = (self.true_pos /
+                     (self.true_pos + self.false_pos + self.false_neg + eps))
+        label_f1 = np.mean(2 * label_prec * label_pos_recall /
+                           (label_prec + label_pos_recall + eps))
+
+        ma = (np.mean(label_ma))
+
+        self.gt_pos_ins = np.array(self.gt_pos_ins)
+        self.true_pos_ins = np.array(self.true_pos_ins)
+        self.intersect_pos = np.array(self.intersect_pos)
+        self.union_pos = np.array(self.union_pos)
+        instance_acc = self.intersect_pos / (self.union_pos + eps)
+        instance_prec = self.intersect_pos / (self.true_pos_ins + eps)
+        instance_recall = self.intersect_pos / (self.gt_pos_ins + eps)
+        instance_f1 = 2 * instance_prec * instance_recall / (
+            instance_prec + instance_recall + eps)
+
+        instance_acc = np.mean(instance_acc)
+        instance_prec = np.mean(instance_prec)
+        instance_recall = np.mean(instance_recall)
+        instance_f1 = np.mean(instance_f1)
+
+        res = [
+            ma, label_f1, label_pos_recall, label_neg_recall, instance_f1,
+            instance_acc, instance_prec, instance_recall
+        ]
+        return res
@@ -9,3 +9,4 @@ scipy
 scikit-learn==0.23.2
 gast==0.3.3
 faiss-cpu==1.7.1.post2
+easydict