Merge pull request #956 from RainFrost1/develop

fix mAP bugs
pull/961/head
Wei Shengyu 2021-06-23 18:08:59 +08:00 committed by GitHub
commit 4fb84dabbd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 31 additions and 9 deletions

View File

@ -41,7 +41,7 @@ class mAP(nn.Layer):
super().__init__()
def forward(self, similarities_matrix, query_img_id, gallery_img_id,
*args):
keep_mask):
metric_dict = dict()
choosen_indices = paddle.argsort(
@ -55,8 +55,19 @@ class mAP(nn.Layer):
choosen_label = paddle.index_sample(gallery_labels_transpose,
choosen_indices)
equal_flag = paddle.equal(choosen_label, query_img_id)
if keep_mask is not None:
keep_mask = paddle.index_sample(
keep_mask.astype('float32'), choosen_indices)
equal_flag = paddle.logical_and(equal_flag,
keep_mask.astype('bool'))
equal_flag = paddle.cast(equal_flag, 'float32')
num_rel = paddle.sum(equal_flag, axis=1)
num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.))
num_rel_index = paddle.nonzero(num_rel.astype("int"))
num_rel_index = paddle.reshape(num_rel_index, [num_rel_index.shape[0]])
equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0)
acc_sum = paddle.cumsum(equal_flag, axis=1)
div = paddle.arange(acc_sum.shape[1]).astype("float32") + 1
precision = paddle.divide(acc_sum, div)
@ -74,7 +85,7 @@ class mINP(nn.Layer):
super().__init__()
def forward(self, similarities_matrix, query_img_id, gallery_img_id,
*args):
keep_mask):
metric_dict = dict()
choosen_indices = paddle.argsort(
@ -87,15 +98,26 @@ class mINP(nn.Layer):
])
choosen_label = paddle.index_sample(gallery_labels_transpose,
choosen_indices)
tmp = paddle.equal(choosen_label, query_img_id)
tmp = paddle.cast(tmp, 'float64')
equal_flag = paddle.equal(choosen_label, query_img_id)
if keep_mask is not None:
keep_mask = paddle.index_sample(
keep_mask.astype('float32'), choosen_indices)
equal_flag = paddle.logical_and(equal_flag,
keep_mask.astype('bool'))
equal_flag = paddle.cast(equal_flag, 'float32')
num_rel = paddle.sum(equal_flag, axis=1)
num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.))
num_rel_index = paddle.nonzero(num_rel.astype("int"))
num_rel_index = paddle.reshape(num_rel_index, [num_rel_index.shape[0]])
equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0)
#do accumulative sum
div = paddle.arange(tmp.shape[1]).astype("float64") + 2
minus = paddle.divide(tmp, div)
auxilary = paddle.subtract(tmp, minus)
hard_index = paddle.argmax(auxilary, axis=1).astype("float64")
all_INP = paddle.divide(paddle.sum(tmp, axis=1), hard_index)
div = paddle.arange(equal_flag.shape[1]).astype("float32") + 2
minus = paddle.divide(equal_flag, div)
auxilary = paddle.subtract(equal_flag, minus)
hard_index = paddle.argmax(auxilary, axis=1).astype("float32")
all_INP = paddle.divide(paddle.sum(equal_flag, axis=1), hard_index)
mINP = paddle.mean(all_INP)
metric_dict["mINP"] = mINP.numpy()[0]
return metric_dict