add re-ranking code
parent
283ae9b327
commit
88295413f5
|
@ -12,6 +12,7 @@ Global:
|
|||
use_visualdl: False
|
||||
eval_mode: "retrieval"
|
||||
retrieval_feature_from: "backbone" # 'backbone' or 'neck'
|
||||
re_ranking: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 256, 128]
|
||||
save_inference_dir: "./inference"
|
||||
|
|
|
@ -12,6 +12,7 @@ Global:
|
|||
use_visualdl: False
|
||||
eval_mode: "retrieval"
|
||||
retrieval_feature_from: "features" # 'backbone' or 'features'
|
||||
re_ranking: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 256, 128]
|
||||
save_inference_dir: "./inference"
|
||||
|
|
|
@ -12,6 +12,7 @@ Global:
|
|||
use_visualdl: False
|
||||
eval_mode: "retrieval"
|
||||
retrieval_feature_from: "features" # 'backbone' or 'features'
|
||||
re_ranking: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 256, 128]
|
||||
save_inference_dir: "./inference"
|
||||
|
|
|
@ -16,6 +16,9 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import platform
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from ppcls.utils import logger
|
||||
|
||||
|
@ -48,34 +51,68 @@ def retrieval_eval(engine, epoch_id=0):
|
|||
if engine.eval_loss_func is None:
|
||||
metric_dict = {metric_key: 0.}
|
||||
else:
|
||||
reranking_flag = engine.config['Global'].get('re_ranking', False)
|
||||
logger.info(f"re_ranking={reranking_flag}")
|
||||
metric_dict = dict()
|
||||
for block_idx, block_fea in enumerate(fea_blocks):
|
||||
similarity_matrix = paddle.matmul(
|
||||
block_fea, gallery_feas, transpose_y=True)
|
||||
if query_query_id is not None:
|
||||
query_id_block = query_id_blocks[block_idx]
|
||||
query_id_mask = (query_id_block != gallery_unique_id.t())
|
||||
if reranking_flag:
|
||||
# set the order from small to large
|
||||
for i in range(len(engine.eval_metric_func.metric_func_list)):
|
||||
if hasattr(engine.eval_metric_func.metric_func_list[i], 'descending') \
|
||||
and engine.eval_metric_func.metric_func_list[i].descending is True:
|
||||
engine.eval_metric_func.metric_func_list[
|
||||
i].descending = False
|
||||
logger.info(
|
||||
f"set {engine.eval_metric_func.metric_func_list[i].__class__.__name__}.descending to False when re_ranking=True"
|
||||
)
|
||||
|
||||
image_id_block = image_id_blocks[block_idx]
|
||||
image_id_mask = (image_id_block != gallery_img_id.t())
|
||||
# compute distance matrix(The smaller the value, the more similar)
|
||||
distmat = re_ranking(
|
||||
query_feas, gallery_feas, k1=20, k2=6, lambda_value=0.3)
|
||||
distmat = paddle.to_tensor(distmat)
|
||||
|
||||
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
|
||||
similarity_matrix = similarity_matrix * keep_mask.astype(
|
||||
"float32")
|
||||
else:
|
||||
keep_mask = None
|
||||
# compute keep mask
|
||||
query_id_mask = (query_query_id != gallery_unique_id.t())
|
||||
image_id_mask = (query_img_id != gallery_img_id.t())
|
||||
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
|
||||
|
||||
metric_tmp = engine.eval_metric_func(similarity_matrix,
|
||||
image_id_blocks[block_idx],
|
||||
# set inf(1e9) distance to those exist in gallery
|
||||
distmat = distmat * keep_mask.astype("float32")
|
||||
inf_mat = (paddle.logical_not(keep_mask).astype("float32")) * 1e20
|
||||
distmat = distmat + inf_mat
|
||||
|
||||
# compute metric
|
||||
metric_tmp = engine.eval_metric_func(distmat, query_img_id,
|
||||
gallery_img_id, keep_mask)
|
||||
|
||||
for key in metric_tmp:
|
||||
if key not in metric_dict:
|
||||
metric_dict[key] = metric_tmp[key] * block_fea.shape[
|
||||
0] / len(query_feas)
|
||||
metric_dict[key] = metric_tmp[key]
|
||||
else:
|
||||
for block_idx, block_fea in enumerate(fea_blocks):
|
||||
similarity_matrix = paddle.matmul(
|
||||
block_fea, gallery_feas, transpose_y=True) # [n,m]
|
||||
if query_query_id is not None:
|
||||
query_id_block = query_id_blocks[block_idx]
|
||||
query_id_mask = (query_id_block != gallery_unique_id.t())
|
||||
|
||||
image_id_block = image_id_blocks[block_idx]
|
||||
image_id_mask = (image_id_block != gallery_img_id.t())
|
||||
|
||||
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
|
||||
similarity_matrix = similarity_matrix * keep_mask.astype(
|
||||
"float32")
|
||||
else:
|
||||
metric_dict[key] += metric_tmp[key] * block_fea.shape[
|
||||
0] / len(query_feas)
|
||||
keep_mask = None
|
||||
|
||||
metric_tmp = engine.eval_metric_func(
|
||||
similarity_matrix, image_id_blocks[block_idx],
|
||||
gallery_img_id, keep_mask)
|
||||
|
||||
for key in metric_tmp:
|
||||
if key not in metric_dict:
|
||||
metric_dict[key] = metric_tmp[key] * block_fea.shape[
|
||||
0] / len(query_feas)
|
||||
else:
|
||||
metric_dict[key] += metric_tmp[key] * block_fea.shape[
|
||||
0] / len(query_feas)
|
||||
|
||||
metric_info_list = []
|
||||
for key in metric_dict:
|
||||
|
@ -185,3 +222,109 @@ def cal_feature(engine, name='gallery'):
|
|||
logger.info("Build {} done, all feat shape: {}, begin to eval..".format(
|
||||
name, all_feas.shape))
|
||||
return all_feas, all_img_id, all_unique_id
|
||||
|
||||
|
||||
def re_ranking(query_feas: paddle.Tensor,
|
||||
gallery_feas: paddle.Tensor,
|
||||
k1: int=20,
|
||||
k2: int=6,
|
||||
lambda_value: int=0.5,
|
||||
local_distmat: Optional[np.ndarray]=None,
|
||||
only_local: bool=False) -> paddle.Tensor:
|
||||
"""re-ranking, most computed with numpy
|
||||
|
||||
code heavily based on
|
||||
https://github.com/michuanhaohao/reid-strong-baseline/blob/3da7e6f03164a92e696cb6da059b1cd771b0346d/utils/reid_metric.py
|
||||
|
||||
Args:
|
||||
query_feas (paddle.Tensor): query features, [num_query, num_features]
|
||||
gallery_feas (paddle.Tensor): gallery features, [num_gallery, num_features]
|
||||
k1 (int, optional): k1. Defaults to 20.
|
||||
k2 (int, optional): k2. Defaults to 6.
|
||||
lambda_value (int, optional): lambda. Defaults to 0.5.
|
||||
local_distmat (Optional[np.ndarray], optional): local_distmat. Defaults to None.
|
||||
only_local (bool, optional): only_local. Defaults to False.
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: final_dist matrix after re-ranking, [num_query, num_gallery]
|
||||
"""
|
||||
query_num = query_feas.shape[0]
|
||||
all_num = query_num + gallery_feas.shape[0]
|
||||
if only_local:
|
||||
original_dist = local_distmat
|
||||
else:
|
||||
feat = paddle.concat([query_feas, gallery_feas])
|
||||
logger.info('using GPU to compute original distance')
|
||||
|
||||
# L2 distance
|
||||
distmat = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([all_num, all_num]) + \
|
||||
paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([all_num, all_num]).t()
|
||||
distmat = distmat.addmm(x=feat, y=feat.t(), alpha=-2.0, beta=1.0)
|
||||
|
||||
original_dist = distmat.cpu().numpy()
|
||||
del feat
|
||||
if local_distmat is not None:
|
||||
original_dist = original_dist + local_distmat
|
||||
|
||||
gallery_num = original_dist.shape[0]
|
||||
original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
|
||||
V = np.zeros_like(original_dist).astype(np.float16)
|
||||
initial_rank = np.argsort(original_dist).astype(np.int32)
|
||||
logger.info('starting re_ranking')
|
||||
for i in range(all_num):
|
||||
# k-reciprocal neighbors
|
||||
forward_k_neigh_index = initial_rank[i, :k1 + 1]
|
||||
backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
|
||||
fi = np.where(backward_k_neigh_index == i)[0]
|
||||
k_reciprocal_index = forward_k_neigh_index[fi]
|
||||
k_reciprocal_expansion_index = k_reciprocal_index
|
||||
for j in range(len(k_reciprocal_index)):
|
||||
candidate = k_reciprocal_index[j]
|
||||
candidate_forward_k_neigh_index = initial_rank[candidate, :int(
|
||||
np.around(k1 / 2)) + 1]
|
||||
candidate_backward_k_neigh_index = initial_rank[
|
||||
candidate_forward_k_neigh_index, :int(np.around(k1 / 2)) + 1]
|
||||
fi_candidate = np.where(
|
||||
candidate_backward_k_neigh_index == candidate)[0]
|
||||
candidate_k_reciprocal_index = candidate_forward_k_neigh_index[
|
||||
fi_candidate]
|
||||
if len(
|
||||
np.intersect1d(candidate_k_reciprocal_index,
|
||||
k_reciprocal_index)) > 2 / 3 * len(
|
||||
candidate_k_reciprocal_index):
|
||||
k_reciprocal_expansion_index = np.append(
|
||||
k_reciprocal_expansion_index, candidate_k_reciprocal_index)
|
||||
|
||||
k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
|
||||
weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
|
||||
V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
|
||||
original_dist = original_dist[:query_num, ]
|
||||
if k2 != 1:
|
||||
V_qe = np.zeros_like(V, dtype=np.float16)
|
||||
for i in range(all_num):
|
||||
V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
|
||||
V = V_qe
|
||||
del V_qe
|
||||
del initial_rank
|
||||
invIndex = []
|
||||
for i in range(gallery_num):
|
||||
invIndex.append(np.where(V[:, i] != 0)[0])
|
||||
|
||||
jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
|
||||
for i in range(query_num):
|
||||
temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
|
||||
indNonZero = np.where(V[i, :] != 0)[0]
|
||||
indImages = [invIndex[ind] for ind in indNonZero]
|
||||
for j in range(len(indNonZero)):
|
||||
temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(
|
||||
V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
|
||||
jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
|
||||
|
||||
final_dist = jaccard_dist * (1 - lambda_value
|
||||
) + original_dist * lambda_value
|
||||
del original_dist
|
||||
del V
|
||||
del jaccard_dist
|
||||
final_dist = final_dist[:query_num, query_num:]
|
||||
final_dist = paddle.to_tensor(final_dist)
|
||||
return final_dist
|
||||
|
|
|
@ -43,15 +43,16 @@ class TopkAcc(nn.Layer):
|
|||
|
||||
|
||||
class mAP(nn.Layer):
|
||||
def __init__(self):
|
||||
def __init__(self, descending=True):
|
||||
super().__init__()
|
||||
self.descending = descending
|
||||
|
||||
def forward(self, similarities_matrix, query_img_id, gallery_img_id,
|
||||
keep_mask):
|
||||
metric_dict = dict()
|
||||
|
||||
choosen_indices = paddle.argsort(
|
||||
similarities_matrix, axis=1, descending=True)
|
||||
similarities_matrix, axis=1, descending=self.descending)
|
||||
gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
|
||||
gallery_labels_transpose = paddle.broadcast_to(
|
||||
gallery_labels_transpose,
|
||||
|
@ -87,15 +88,16 @@ class mAP(nn.Layer):
|
|||
|
||||
|
||||
class mINP(nn.Layer):
|
||||
def __init__(self):
|
||||
def __init__(self, descending=True):
|
||||
super().__init__()
|
||||
self.descending = descending
|
||||
|
||||
def forward(self, similarities_matrix, query_img_id, gallery_img_id,
|
||||
keep_mask):
|
||||
metric_dict = dict()
|
||||
|
||||
choosen_indices = paddle.argsort(
|
||||
similarities_matrix, axis=1, descending=True)
|
||||
similarities_matrix, axis=1, descending=self.descending)
|
||||
gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
|
||||
gallery_labels_transpose = paddle.broadcast_to(
|
||||
gallery_labels_transpose,
|
||||
|
@ -130,12 +132,13 @@ class mINP(nn.Layer):
|
|||
|
||||
|
||||
class Recallk(nn.Layer):
|
||||
def __init__(self, topk=(1, 5)):
|
||||
def __init__(self, topk=(1, 5), descending=True):
|
||||
super().__init__()
|
||||
assert isinstance(topk, (int, list, tuple))
|
||||
if isinstance(topk, int):
|
||||
topk = [topk]
|
||||
self.topk = topk
|
||||
self.descending = descending
|
||||
|
||||
def forward(self, similarities_matrix, query_img_id, gallery_img_id,
|
||||
keep_mask):
|
||||
|
@ -143,7 +146,7 @@ class Recallk(nn.Layer):
|
|||
|
||||
#get cmc
|
||||
choosen_indices = paddle.argsort(
|
||||
similarities_matrix, axis=1, descending=True)
|
||||
similarities_matrix, axis=1, descending=self.descending)
|
||||
gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
|
||||
gallery_labels_transpose = paddle.broadcast_to(
|
||||
gallery_labels_transpose,
|
||||
|
@ -175,12 +178,13 @@ class Recallk(nn.Layer):
|
|||
|
||||
|
||||
class Precisionk(nn.Layer):
|
||||
def __init__(self, topk=(1, 5)):
|
||||
def __init__(self, topk=(1, 5), descending=True):
|
||||
super().__init__()
|
||||
assert isinstance(topk, (int, list, tuple))
|
||||
if isinstance(topk, int):
|
||||
topk = [topk]
|
||||
self.topk = topk
|
||||
self.descending = descending
|
||||
|
||||
def forward(self, similarities_matrix, query_img_id, gallery_img_id,
|
||||
keep_mask):
|
||||
|
@ -188,7 +192,7 @@ class Precisionk(nn.Layer):
|
|||
|
||||
#get cmc
|
||||
choosen_indices = paddle.argsort(
|
||||
similarities_matrix, axis=1, descending=True)
|
||||
similarities_matrix, axis=1, descending=self.descending)
|
||||
gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
|
||||
gallery_labels_transpose = paddle.broadcast_to(
|
||||
gallery_labels_transpose,
|
||||
|
|
Loading…
Reference in New Issue