Some modifications
parent
7badd5bf6c
commit
169682002d
|
@ -374,10 +374,7 @@ class FaceAttribute(object):
|
||||||
gender_list = [["Male", "男性"], ["Female", "女性"]]
|
gender_list = [["Male", "男性"], ["Female", "女性"]]
|
||||||
age_list = [["Young", "年轻人"], ["Old", "老年人"]]
|
age_list = [["Young", "年轻人"], ["Old", "老年人"]]
|
||||||
batch_res = []
|
batch_res = []
|
||||||
if self.convert_cn:
|
index = 1 if self.convert_cn else 0
|
||||||
index = 1
|
|
||||||
else:
|
|
||||||
index = 0
|
|
||||||
for idx, res in enumerate(x):
|
for idx, res in enumerate(x):
|
||||||
res = res.tolist()
|
res = res.tolist()
|
||||||
label_res = []
|
label_res = []
|
||||||
|
|
|
@ -84,7 +84,6 @@ class PersonAttribute(object):
|
||||||
if isinstance(x, dict):
|
if isinstance(x, dict):
|
||||||
x = x['logits']
|
x = x['logits']
|
||||||
assert isinstance(x, paddle.Tensor)
|
assert isinstance(x, paddle.Tensor)
|
||||||
|
|
||||||
if file_names is not None:
|
if file_names is not None:
|
||||||
assert x.shape[0] == len(file_names)
|
assert x.shape[0] == len(file_names)
|
||||||
x = F.sigmoid(x).numpy()
|
x = F.sigmoid(x).numpy()
|
||||||
|
@ -99,7 +98,6 @@ class PersonAttribute(object):
|
||||||
'Skirt&Dress'
|
'Skirt&Dress'
|
||||||
]
|
]
|
||||||
batch_res = []
|
batch_res = []
|
||||||
|
|
||||||
for idx, res in enumerate(x):
|
for idx, res in enumerate(x):
|
||||||
res = res.tolist()
|
res = res.tolist()
|
||||||
label_res = []
|
label_res = []
|
||||||
|
@ -209,10 +207,7 @@ class FaceAttribute(object):
|
||||||
gender_list = [["Male", "男性"], ["Female", "女性"]]
|
gender_list = [["Male", "男性"], ["Female", "女性"]]
|
||||||
age_list = [["Young", "年轻人"], ["Old", "老年人"]]
|
age_list = [["Young", "年轻人"], ["Old", "老年人"]]
|
||||||
batch_res = []
|
batch_res = []
|
||||||
if self.convert_cn:
|
index = 1 if self.convert_cn else 0
|
||||||
index = 1
|
|
||||||
else:
|
|
||||||
index = 0
|
|
||||||
for idx, res in enumerate(x):
|
for idx, res in enumerate(x):
|
||||||
res = res.tolist()
|
res = res.tolist()
|
||||||
label_res = []
|
label_res = []
|
||||||
|
|
|
@ -219,12 +219,17 @@ class TprAtFpr(nn.Layer):
|
||||||
|
|
||||||
|
|
||||||
class MultilabelMeanAccuracy(nn.Layer):
|
class MultilabelMeanAccuracy(nn.Layer):
|
||||||
def __init__(self, class_num=40):
|
def __init__(self,
|
||||||
|
start_threshold=0.4,
|
||||||
|
num_iterations=10,
|
||||||
|
end_threshold=0.9):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.start_threshold = start_threshold
|
||||||
|
self.num_iterations = num_iterations
|
||||||
|
self.end_threshold = end_threshold
|
||||||
self.gt_all_score_list = []
|
self.gt_all_score_list = []
|
||||||
self.gt_label_score_list = []
|
self.gt_label_score_list = []
|
||||||
self.max_acc = 0.
|
self.max_acc = 0.
|
||||||
self.class_num = class_num
|
|
||||||
|
|
||||||
def forward(self, x, label):
|
def forward(self, x, label):
|
||||||
if isinstance(x, dict):
|
if isinstance(x, dict):
|
||||||
|
@ -251,8 +256,10 @@ class MultilabelMeanAccuracy(nn.Layer):
|
||||||
result = ""
|
result = ""
|
||||||
gt_all_score_list = np.array(self.gt_all_score_list)
|
gt_all_score_list = np.array(self.gt_all_score_list)
|
||||||
gt_label_score_list = np.array(self.gt_label_score_list)
|
gt_label_score_list = np.array(self.gt_label_score_list)
|
||||||
for i in range(10):
|
for i in range(self.num_iterations):
|
||||||
threshold = 0.4 + i * 0.05
|
threshold = self.start_threshold + i * (self.end_threshold -
|
||||||
|
self.start_threshold
|
||||||
|
) / self.num_iterations
|
||||||
pred_label = (gt_all_score_list > threshold).astype(int)
|
pred_label = (gt_all_score_list > threshold).astype(int)
|
||||||
TP = np.sum(
|
TP = np.sum(
|
||||||
(gt_label_score_list == 1) * (pred_label == 1)).astype(float)
|
(gt_label_score_list == 1) * (pred_label == 1)).astype(float)
|
||||||
|
@ -262,8 +269,8 @@ class MultilabelMeanAccuracy(nn.Layer):
|
||||||
if max_acc <= acc:
|
if max_acc <= acc:
|
||||||
max_acc = acc
|
max_acc = acc
|
||||||
result = "threshold: {}, mean_acc: {}".format(
|
result = "threshold: {}, mean_acc: {}".format(
|
||||||
threshold, max_acc / self.class_num)
|
threshold, max_acc / len(gt_label_score_list[0]))
|
||||||
self.max_acc = max_acc / self.class_num
|
self.max_acc = max_acc / len(gt_label_score_list[0])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue