refactor(retrieval): polish retrieval.py
parent
97f99cd826
commit
c6865e255e
|
@ -16,108 +16,90 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import platform
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from ppcls.engine.train.utils import type_name
|
||||
from ppcls.utils import logger
|
||||
from ppcls.utils import all_gather
|
||||
|
||||
from ppcls.utils import all_gather, logger
|
||||
|
||||
|
||||
def retrieval_eval(engine, epoch_id=0):
|
||||
engine.model.eval()
|
||||
# step1. prepare query and gallery features
|
||||
if engine.gallery_query_dataloader is not None:
|
||||
gallery_feas, gallery_label_id, gallery_camera_id = compute_feature(
|
||||
gallery_feat, gallery_label, gallery_camera = compute_feature(
|
||||
engine, "gallery_query")
|
||||
query_feas, query_label_id, query_camera_id = gallery_feas, gallery_label_id, gallery_camera_id
|
||||
query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera
|
||||
else:
|
||||
gallery_feas, gallery_label_id, gallery_camera_id = compute_feature(
|
||||
gallery_feat, gallery_label, gallery_camera = compute_feature(
|
||||
engine, "gallery")
|
||||
query_feas, query_label_id, query_camera_id = compute_feature(engine,
|
||||
"query")
|
||||
|
||||
query_feat, query_label, query_camera = compute_feature(engine,
|
||||
"query")
|
||||
# step2. split features into feature blocks for saving memory
|
||||
num_query = len(query_feat)
|
||||
block_size = engine.config["Global"].get("sim_block_size", 64)
|
||||
sections = [block_size] * (len(query_feas) // block_size)
|
||||
if len(query_feas) % block_size > 0:
|
||||
sections.append(len(query_feas) % block_size)
|
||||
sections = [block_size] * (num_query // block_size)
|
||||
if num_query % block_size > 0:
|
||||
sections.append(num_query % block_size)
|
||||
|
||||
query_feas_blocks = paddle.split(query_feas, sections)
|
||||
query_camera_id_blocks = (paddle.split(query_camera_id, sections)
|
||||
if query_camera_id is not None else None)
|
||||
query_label_id_blocks = paddle.split(query_label_id, sections)
|
||||
query_feat_blocks = paddle.split(query_feat, sections)
|
||||
query_label_blocks = paddle.split(query_label, sections)
|
||||
query_camera_blocks = paddle.split(
|
||||
query_camera, sections) if query_camera is not None else None
|
||||
metric_key = None
|
||||
|
||||
# step3. compute metric
|
||||
if engine.eval_loss_func is None:
|
||||
metric_dict = {metric_key: 0.}
|
||||
metric_dict = {metric_key: 0.0}
|
||||
else:
|
||||
use_reranking = engine.config["Global"].get("re_ranking", False)
|
||||
logger.info(f"re_ranking={use_reranking}")
|
||||
metric_dict = {}
|
||||
if use_reranking:
|
||||
for _, metric_func in enumerate(
|
||||
engine.eval_metric_func.metric_func_list):
|
||||
if hasattr(metric_func,
|
||||
"descending") and metric_func.descending is True:
|
||||
metric_func.descending = False
|
||||
logger.warning(
|
||||
f"re_ranking=True, set {type_name(metric_func)}.descending set to False"
|
||||
)
|
||||
# compute distance matrix
|
||||
distmat = compute_re_ranking_dist(
|
||||
query_feas, gallery_feas, engine.config["Global"].get(
|
||||
query_feat, gallery_feat, engine.config["Global"].get(
|
||||
"feature_normalize", True), 20, 6, 0.3)
|
||||
|
||||
# exclude illegal distance
|
||||
camera_id_mask = query_camera_id != gallery_camera_id.t()
|
||||
image_id_mask = query_label_id != gallery_label_id.t()
|
||||
keep_mask = paddle.logical_or(image_id_mask, camera_id_mask)
|
||||
distmat = distmat * keep_mask.astype(query_feas.dtype)
|
||||
inf_mat = (
|
||||
paddle.logical_not(keep_mask).astype(query_feas.dtype)) * (
|
||||
distmat.max() + 1)
|
||||
distmat = distmat + inf_mat
|
||||
|
||||
metric_block = engine.eval_metric_func(distmat, query_label_id,
|
||||
gallery_label_id, keep_mask)
|
||||
for key in metric_block:
|
||||
metric_dict[key] = metric_block[key]
|
||||
if query_camera is not None:
|
||||
camera_mask = query_camera != gallery_camera.t()
|
||||
label_mask = query_label != gallery_label.t()
|
||||
keep_mask = label_mask | camera_mask
|
||||
distmat = keep_mask.astype(query_feat.dtype) * distmat + (
|
||||
~keep_mask).astype(query_feat.dtype) * (distmat.max() + 1)
|
||||
else:
|
||||
keep_mask = None
|
||||
# compute metric with all samples
|
||||
metric_dict = engine.eval_metric_func(-distmat, query_label,
|
||||
gallery_label, keep_mask)
|
||||
else:
|
||||
for block_idx, block_fea in enumerate(query_feas_blocks):
|
||||
metric_dict = defaultdict(float)
|
||||
for block_idx, block_feat in enumerate(query_feat_blocks):
|
||||
# compute distance matrix
|
||||
distmat = paddle.matmul(
|
||||
block_fea, gallery_feas, transpose_y=True)
|
||||
if query_camera_id is not None:
|
||||
query_camera_id_block = query_camera_id_blocks[block_idx]
|
||||
camera_id_mask = query_camera_id_block != gallery_camera_id.t(
|
||||
)
|
||||
|
||||
query_label_id_block = query_label_id_blocks[block_idx]
|
||||
image_id_mask = query_label_id_block != gallery_label_id.t(
|
||||
)
|
||||
|
||||
keep_mask = paddle.logical_or(image_id_mask,
|
||||
camera_id_mask)
|
||||
distmat = distmat * keep_mask.astype("float32")
|
||||
block_feat, gallery_feat, transpose_y=True)
|
||||
# exclude illegal distance
|
||||
if query_camera is not None:
|
||||
camera_mask = query_camera_blocks[
|
||||
block_idx] != gallery_camera.t()
|
||||
label_mask = query_label_blocks[
|
||||
block_idx] != gallery_label.t()
|
||||
keep_mask = label_mask | camera_mask
|
||||
distmat = keep_mask.astype(query_feat.dtype) * distmat
|
||||
else:
|
||||
keep_mask = None
|
||||
|
||||
# compute metric by block
|
||||
metric_block = engine.eval_metric_func(
|
||||
distmat, query_label_id_blocks[block_idx],
|
||||
gallery_label_id, keep_mask)
|
||||
|
||||
distmat, query_label_blocks[block_idx], gallery_label,
|
||||
keep_mask)
|
||||
# accumulate metric
|
||||
for key in metric_block:
|
||||
if key not in metric_dict:
|
||||
metric_dict[key] = metric_block[key] * block_fea.shape[
|
||||
0] / len(query_feas)
|
||||
else:
|
||||
metric_dict[key] += metric_block[
|
||||
key] * block_fea.shape[0] / len(query_feas)
|
||||
metric_dict[key] += metric_block[key] * block_feat.shape[
|
||||
0] / num_query
|
||||
|
||||
metric_info_list = []
|
||||
for key in metric_dict:
|
||||
metric_info_list.append(f"{key}: {metric_dict[key]:.5f}")
|
||||
for key, value in metric_dict.items():
|
||||
metric_info_list.append(f"{key}: {value:.5f}")
|
||||
if metric_key is None:
|
||||
metric_key = key
|
||||
metric_msg = ", ".join(metric_info_list)
|
||||
|
@ -127,9 +109,6 @@ def retrieval_eval(engine, epoch_id=0):
|
|||
|
||||
|
||||
def compute_feature(engine, name="gallery"):
|
||||
has_camera_id = False
|
||||
all_camera_id = None
|
||||
|
||||
if name == "gallery":
|
||||
dataloader = engine.gallery_dataloader
|
||||
elif name == "query":
|
||||
|
@ -137,13 +116,16 @@ def compute_feature(engine, name="gallery"):
|
|||
elif name == "gallery_query":
|
||||
dataloader = engine.gallery_query_dataloader
|
||||
else:
|
||||
raise RuntimeError("Only support gallery or query dataset")
|
||||
raise ValueError(
|
||||
f"Only support gallery or query or gallery_query dataset, but got {name}"
|
||||
)
|
||||
|
||||
batch_feas_list = []
|
||||
label_id_list = []
|
||||
camera_id_list = []
|
||||
all_feat = []
|
||||
all_label = []
|
||||
all_camera = []
|
||||
max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
|
||||
dataloader)
|
||||
has_camera = False
|
||||
for idx, batch in enumerate(dataloader): # load is very time-consuming
|
||||
if idx >= max_iter:
|
||||
break
|
||||
|
@ -154,8 +136,8 @@ def compute_feature(engine, name="gallery"):
|
|||
|
||||
batch = [paddle.to_tensor(x) for x in batch]
|
||||
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
|
||||
if len(batch) == 3:
|
||||
has_camera_id = True
|
||||
if len(batch) >= 3:
|
||||
has_camera = True
|
||||
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
||||
if engine.amp and engine.amp_eval:
|
||||
with paddle.amp.auto_cast(
|
||||
|
@ -163,62 +145,61 @@ def compute_feature(engine, name="gallery"):
|
|||
"flatten_contiguous_range", "greater_than"
|
||||
},
|
||||
level=engine.amp_level):
|
||||
out = engine.model(batch[0], batch[1])
|
||||
out = engine.model(batch[0])
|
||||
else:
|
||||
out = engine.model(batch[0], batch[1])
|
||||
out = engine.model(batch[0])
|
||||
if "Student" in out:
|
||||
out = out["Student"]
|
||||
|
||||
# get features
|
||||
if engine.config["Global"].get("retrieval_feature_from",
|
||||
"features") == "features":
|
||||
# use neck's output as features
|
||||
batch_feas = out["features"]
|
||||
# use output from neck as feature
|
||||
batch_feat = out["features"]
|
||||
else:
|
||||
# use backbone's output as features
|
||||
batch_feas = out["backbone"]
|
||||
# use output from backbone as feature
|
||||
batch_feat = out["backbone"]
|
||||
|
||||
# do norm(optinal)
|
||||
# do norm(optional)
|
||||
if engine.config["Global"].get("feature_normalize", True):
|
||||
batch_feas_norm = paddle.sqrt(
|
||||
paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
|
||||
batch_feas = paddle.divide(batch_feas, batch_feas_norm)
|
||||
batch_feat = paddle.nn.functional.normalize(batch_feat, p=2)
|
||||
|
||||
# do binarize(optinal)
|
||||
# do binarize(optional)
|
||||
if engine.config["Global"].get("feature_binarize") == "round":
|
||||
batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0
|
||||
batch_feat = paddle.round(batch_feat).astype("float32") * 2.0 - 1.0
|
||||
elif engine.config["Global"].get("feature_binarize") == "sign":
|
||||
batch_feas = paddle.sign(batch_feas).astype("float32")
|
||||
batch_feat = paddle.sign(batch_feat).astype("float32")
|
||||
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
batch_feas_list.append(all_gather(batch_feas))
|
||||
label_id_list.append(all_gather(batch[1]))
|
||||
if has_camera_id:
|
||||
camera_id_list.append(all_gather(batch[2]))
|
||||
all_feat.append(all_gather(batch_feat))
|
||||
all_label.append(all_gather(batch[1]))
|
||||
if has_camera:
|
||||
all_camera.append(all_gather(batch[2]))
|
||||
else:
|
||||
batch_feas_list.append(batch_feas)
|
||||
label_id_list.append(batch[1])
|
||||
if has_camera_id:
|
||||
camera_id_list.append(batch[2])
|
||||
all_feat.append(batch_feat)
|
||||
all_label.append(batch[1])
|
||||
if has_camera:
|
||||
all_camera.append(batch[2])
|
||||
|
||||
if engine.use_dali:
|
||||
dataloader.reset()
|
||||
|
||||
all_feas = paddle.concat(batch_feas_list)
|
||||
all_label_id = paddle.concat(label_id_list)
|
||||
if has_camera_id:
|
||||
all_camera_id = paddle.concat(camera_id_list)
|
||||
|
||||
all_feat = paddle.concat(all_feat)
|
||||
all_label = paddle.concat(all_label)
|
||||
if has_camera:
|
||||
all_camera = paddle.concat(all_camera)
|
||||
else:
|
||||
all_camera = None
|
||||
# discard redundant padding sample(s) at the end
|
||||
total_samples = len(
|
||||
dataloader.dataset) if not engine.use_dali else dataloader.size
|
||||
all_feas = all_feas[:total_samples]
|
||||
all_label_id = all_label_id[:total_samples]
|
||||
if has_camera_id:
|
||||
all_camera_id = all_camera_id[:total_samples]
|
||||
total_samples = dataloader.size if engine.use_dali else len(
|
||||
dataloader.dataset)
|
||||
all_feat = all_feat[:total_samples]
|
||||
all_label = all_label[:total_samples]
|
||||
if has_camera:
|
||||
all_camera = all_camera[:total_samples]
|
||||
|
||||
logger.info(f"Build {name} done, all feat shape: {all_feas.shape}")
|
||||
return all_feas, all_label_id, all_camera_id
|
||||
logger.info(f"Build {name} done, all feat shape: {all_feat.shape}")
|
||||
return all_feat, all_label, all_camera
|
||||
|
||||
|
||||
def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray:
|
||||
|
@ -239,8 +220,8 @@ def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray:
|
|||
return forward_k_neigh_index[candidate]
|
||||
|
||||
|
||||
def compute_re_ranking_dist(query_feas: paddle.Tensor,
|
||||
gallery_feas: paddle.Tensor,
|
||||
def compute_re_ranking_dist(query_feat: paddle.Tensor,
|
||||
gallery_feat: paddle.Tensor,
|
||||
feature_normed: bool=True,
|
||||
k1: int=20,
|
||||
k2: int=6,
|
||||
|
@ -251,8 +232,8 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor,
|
|||
Code refernence: https://github.com/michuanhaohao/reid-strong-baseline/blob/master/utils/re_ranking.py
|
||||
|
||||
Args:
|
||||
query_feas (paddle.Tensor): Query features with shape of [num_query, feature_dim].
|
||||
gallery_feas (paddle.Tensor): Gallery features with shape of [num_gallery, feature_dim].
|
||||
query_feat (paddle.Tensor): Query features with shape of [num_query, feature_dim].
|
||||
gallery_feat (paddle.Tensor): Gallery features with shape of [num_gallery, feature_dim].
|
||||
feature_normed (bool, optional): Whether input features are normalized.
|
||||
k1 (int, optional): Parameter for K-reciprocal nearest neighbors. Defaults to 20.
|
||||
k2 (int, optional): Parameter for K-nearest neighbors. Defaults to 6.
|
||||
|
@ -261,10 +242,10 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor,
|
|||
Returns:
|
||||
paddle.Tensor: (1 - lamb) x Dj + lamb x D, with shape of [num_query, num_gallery].
|
||||
"""
|
||||
num_query = query_feas.shape[0]
|
||||
num_gallery = gallery_feas.shape[0]
|
||||
num_query = query_feat.shape[0]
|
||||
num_gallery = gallery_feat.shape[0]
|
||||
num_all = num_query + num_gallery
|
||||
feat = paddle.concat([query_feas, gallery_feas], 0)
|
||||
feat = paddle.concat([query_feat, gallery_feat], 0)
|
||||
logger.info("Using GPU to compute original distance matrix")
|
||||
|
||||
# use L2 distance
|
||||
|
@ -273,8 +254,7 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor,
|
|||
else:
|
||||
original_dist = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]) + \
|
||||
paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]).t()
|
||||
original_dist = original_dist.addmm(
|
||||
x=feat, y=feat.t(), alpha=-2.0, beta=1.0)
|
||||
original_dist = original_dist.addmm(feat, feat.t(), -2.0, 1.0)
|
||||
original_dist = original_dist.numpy()
|
||||
del feat
|
||||
|
||||
|
@ -298,7 +278,6 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor,
|
|||
p_k_reciprocal_exp_ind = np.append(p_k_reciprocal_exp_ind,
|
||||
q_k_reciprocal_ind)
|
||||
p_k_reciprocal_exp_ind = np.unique(p_k_reciprocal_exp_ind)
|
||||
|
||||
# reweight distance using gaussian kernel
|
||||
weight = np.exp(-original_dist[p, p_k_reciprocal_exp_ind])
|
||||
V[p, p_k_reciprocal_exp_ind] = weight / np.sum(weight)
|
||||
|
@ -318,6 +297,7 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor,
|
|||
for gj in range(num_all):
|
||||
invIndex.append(np.nonzero(V[:, gj])[0])
|
||||
|
||||
# compute jaccard distance
|
||||
jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
|
||||
for p in range(num_query):
|
||||
sum_min = np.zeros(shape=[1, num_all], dtype=np.float16)
|
||||
|
@ -328,7 +308,8 @@ def compute_re_ranking_dist(query_feas: paddle.Tensor,
|
|||
sum_min[0, gi] += np.minimum(V[p, gj], V[gi, gj])
|
||||
jaccard_dist[p] = 1 - sum_min / (2 - sum_min)
|
||||
|
||||
final_dist = jaccard_dist * (1 - lamb) + original_dist * lamb
|
||||
# fuse jaccard distance with original distance
|
||||
final_dist = (1 - lamb) * jaccard_dist + lamb * original_dist
|
||||
del original_dist
|
||||
del V
|
||||
del jaccard_dist
|
||||
|
|
|
@ -287,10 +287,10 @@ class Recallk(nn.Layer):
|
|||
keep_mask):
|
||||
metric_dict = dict()
|
||||
|
||||
#get cmc
|
||||
# get cmc
|
||||
choosen_indices = paddle.argsort(
|
||||
similarities_matrix, axis=1, descending=self.descending)
|
||||
gallery_labels_transpose = paddle.transpose(gallery_img_id, [1, 0])
|
||||
gallery_labels_transpose = gallery_img_id.t()
|
||||
gallery_labels_transpose = paddle.broadcast_to(
|
||||
gallery_labels_transpose,
|
||||
shape=[
|
||||
|
@ -301,18 +301,14 @@ class Recallk(nn.Layer):
|
|||
equal_flag = paddle.equal(choosen_label, query_img_id)
|
||||
if keep_mask is not None:
|
||||
keep_mask = paddle.index_sample(
|
||||
keep_mask.astype('float32'), choosen_indices)
|
||||
equal_flag = paddle.logical_and(equal_flag,
|
||||
keep_mask.astype('bool'))
|
||||
equal_flag = paddle.cast(equal_flag, 'float32')
|
||||
keep_mask.astype("float32"), choosen_indices)
|
||||
equal_flag = equal_flag & keep_mask.astype("bool")
|
||||
equal_flag = paddle.cast(equal_flag, "float32")
|
||||
real_query_num = paddle.sum(equal_flag, axis=1)
|
||||
real_query_num = paddle.sum(
|
||||
paddle.greater_than(real_query_num, paddle.to_tensor(0.)).astype(
|
||||
"float32"))
|
||||
real_query_num = paddle.sum((real_query_num > 0.0).astype("float32"))
|
||||
|
||||
acc_sum = paddle.cumsum(equal_flag, axis=1)
|
||||
mask = paddle.greater_than(acc_sum,
|
||||
paddle.to_tensor(0.)).astype("float32")
|
||||
mask = (acc_sum > 0.0).astype("float32")
|
||||
all_cmc = (paddle.sum(mask, axis=0) / real_query_num).numpy()
|
||||
|
||||
for k in self.topk:
|
||||
|
|
Loading…
Reference in New Issue