mirror of https://github.com/JDAI-CV/fast-reid.git
optimize jaccard distance computation and the ranking
parent
10a5f38aaa
commit
8969f1bd3a
|
@ -81,8 +81,7 @@ def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
||||||
# compute AP
|
# compute AP
|
||||||
num_rel = raw_cmc.sum()
|
num_rel = raw_cmc.sum()
|
||||||
tmp_cmc = raw_cmc.cumsum()
|
tmp_cmc = raw_cmc.cumsum()
|
||||||
tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
|
tmp_cmc = (tmp_cmc/np.arange(1,len(tmp_cmc)+1)) * raw_cmc
|
||||||
tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
|
|
||||||
AP = tmp_cmc.sum() / num_rel
|
AP = tmp_cmc.sum() / num_rel
|
||||||
all_AP.append(AP)
|
all_AP.append(AP)
|
||||||
num_valid_q += 1.
|
num_valid_q += 1.
|
||||||
|
|
|
@ -50,8 +50,8 @@ def build_dist(feat_1: torch.Tensor, feat_2: torch.Tensor, metric: str = "euclid
|
||||||
|
|
||||||
elif metric == "jaccard":
|
elif metric == "jaccard":
|
||||||
feat = torch.cat((feat_1, feat_2), dim=0)
|
feat = torch.cat((feat_1, feat_2), dim=0)
|
||||||
dist = compute_jaccard_distance(feat, k1=kwargs["k1"], k2=kwargs["k2"], search_option=0)
|
dist = compute_jaccard_distance(feat,feat_1.shape[0],feat_2.shape[0], k1=kwargs["k1"], k2=kwargs["k2"], search_option=0)
|
||||||
return dist[: feat_1.size(0), feat_1.size(0):]
|
return dist
|
||||||
|
|
||||||
|
|
||||||
def k_reciprocal_neigh(initial_rank, i, k1):
|
def k_reciprocal_neigh(initial_rank, i, k1):
|
||||||
|
@ -62,7 +62,7 @@ def k_reciprocal_neigh(initial_rank, i, k1):
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def compute_jaccard_distance(features, k1=20, k2=6, search_option=0, fp16=False):
|
def compute_jaccard_distance(features, N_feat_1, N_feat_2, k1=20, k2=6, search_option=0, fp16=False):
|
||||||
if search_option < 3:
|
if search_option < 3:
|
||||||
# torch.cuda.empty_cache()
|
# torch.cuda.empty_cache()
|
||||||
features = features.cuda()
|
features = features.cuda()
|
||||||
|
@ -153,17 +153,18 @@ def compute_jaccard_distance(features, k1=20, k2=6, search_option=0, fp16=False)
|
||||||
for i in range(N):
|
for i in range(N):
|
||||||
invIndex.append(np.where(V[:, i] != 0)[0]) # len(invIndex)=all_num
|
invIndex.append(np.where(V[:, i] != 0)[0]) # len(invIndex)=all_num
|
||||||
|
|
||||||
jaccard_dist = np.zeros((N, N), dtype=mat_type)
|
jaccard_dist = np.zeros((N_feat_1, N_feat_2), dtype=mat_type)
|
||||||
for i in range(N):
|
for i in range(N_feat_1):
|
||||||
temp_min = np.zeros((1, N), dtype=mat_type)
|
temp_min = np.zeros((1, N), dtype=mat_type)
|
||||||
indNonZero = np.where(V[i, :] != 0)[0]
|
indNonZero = np.where(V[i, :] != 0)[0]
|
||||||
indImages = [invIndex[ind] for ind in indNonZero]
|
indImages = [invIndex[ind] for ind in indNonZero]
|
||||||
for j in range(len(indNonZero)):
|
for j in range(len(indNonZero)):
|
||||||
temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(
|
if indImages[j]>=N_feat_1:
|
||||||
V[i, indNonZero[j]], V[indImages[j], indNonZero[j]]
|
temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(
|
||||||
)
|
V[i, indNonZero[j]], V[indImages[j], indNonZero[j]]
|
||||||
|
)
|
||||||
|
|
||||||
jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
|
jaccard_dist[i] = (1 - temp_min / (2 - temp_min))[:,N_feat_1:]
|
||||||
|
|
||||||
del invIndex, V
|
del invIndex, V
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue