# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import platform
import numpy as np
import paddle
from ppcls.engine.train.utils import type_name
from ppcls.utils import logger
from ppcls.utils import all_gather


def retrieval_eval(engine, epoch_id=0):
    engine.model.eval()

    # step1. prepare query and gallery features
    if engine.gallery_query_dataloader is not None:
        gallery_feas, gallery_label_id, gallery_camera_id = compute_feature(
            engine, "gallery_query")
        query_feas, query_label_id, query_camera_id = gallery_feas, gallery_label_id, gallery_camera_id
    else:
        gallery_feas, gallery_label_id, gallery_camera_id = compute_feature(
            engine, "gallery")
        query_feas, query_label_id, query_camera_id = compute_feature(
            engine, "query")

    # step2. split features into feature blocks for saving memory
    block_size = engine.config["Global"].get("sim_block_size", 64)
    sections = [block_size] * (len(query_feas) // block_size)
    if len(query_feas) % block_size > 0:
        sections.append(len(query_feas) % block_size)
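    # Illustrative example (hypothetical numbers): with 150 query features and
    # block_size=64, sections becomes [64, 64, 22], so the paddle.split calls
    # below cover every query exactly once.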
    query_feas_blocks = paddle.split(query_feas, sections)
    query_camera_id_blocks = (paddle.split(query_camera_id, sections)
                              if query_camera_id is not None else None)
    query_label_id_blocks = paddle.split(query_label_id, sections)

    metric_key = None
    # step3. compute metric
    if engine.eval_loss_func is None:
        metric_dict = {metric_key: 0.}
    else:
        use_reranking = engine.config["Global"].get("re_ranking", False)
        logger.info(f"re_ranking={use_reranking}")
        metric_dict = {}
        if use_reranking:
            for _, metric_func in enumerate(
                    engine.eval_metric_func.metric_func_list):
                if hasattr(metric_func,
                           "descending") and metric_func.descending is True:
                    metric_func.descending = False
                    logger.warning(
                        f"re_ranking=True, set {type_name(metric_func)}.descending to False"
                    )

            # compute distance matrix
            distmat = compute_re_ranking_dist(
                query_feas, gallery_feas, engine.config["Global"].get(
                    "feature_normalize", True), 20, 6, 0.3)

            # exclude illegal distance
            camera_id_mask = query_camera_id != gallery_camera_id.t()
            image_id_mask = query_label_id != gallery_label_id.t()
            keep_mask = paddle.logical_or(image_id_mask, camera_id_mask)
            distmat = distmat * keep_mask.astype(query_feas.dtype)
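            # excluded query-gallery pairs (same label id and same camera id)
            # are pushed to a distance larger than any valid entry, so they
            # can never appear in the retrieved top-k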
            inf_mat = (
                paddle.logical_not(keep_mask).astype(query_feas.dtype)) * (
                    distmat.max() + 1)
            distmat = distmat + inf_mat

            metric_block = engine.eval_metric_func(distmat, query_label_id,
                                                   gallery_label_id, keep_mask)
            for key in metric_block:
                metric_dict[key] = metric_block[key]
        else:
            for block_idx, block_fea in enumerate(query_feas_blocks):
                distmat = paddle.matmul(
                    block_fea, gallery_feas, transpose_y=True)
                if query_camera_id is not None:
                    query_camera_id_block = query_camera_id_blocks[block_idx]
                    camera_id_mask = query_camera_id_block != gallery_camera_id.t()
                    query_label_id_block = query_label_id_blocks[block_idx]
                    image_id_mask = query_label_id_block != gallery_label_id.t()
                    keep_mask = paddle.logical_or(image_id_mask,
                                                  camera_id_mask)
                    distmat = distmat * keep_mask.astype("float32")
                else:
                    keep_mask = None
                metric_block = engine.eval_metric_func(
                    distmat, query_label_id_blocks[block_idx],
                    gallery_label_id, keep_mask)
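                # weight each block's metric by its share of queries
                # (block_fea.shape[0] / len(query_feas)); for metrics that are
                # per-query averages (e.g. Recall@K, mAP) the accumulated
                # value then equals the metric over the full query set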
                for key in metric_block:
                    if key not in metric_dict:
                        metric_dict[key] = metric_block[
                            key] * block_fea.shape[0] / len(query_feas)
                    else:
                        metric_dict[key] += metric_block[
                            key] * block_fea.shape[0] / len(query_feas)

    metric_info_list = []
    for key in metric_dict:
        metric_info_list.append(f"{key}: {metric_dict[key]:.5f}")
        if metric_key is None:
            metric_key = key
    metric_msg = ", ".join(metric_info_list)
    logger.info(f"[Eval][Epoch {epoch_id}][Avg]{metric_msg}")

    return metric_dict[metric_key]


def compute_feature(engine, name="gallery"):
    has_camera_id = False
    all_camera_id = None

    if name == "gallery":
        dataloader = engine.gallery_dataloader
    elif name == "query":
        dataloader = engine.query_dataloader
    elif name == "gallery_query":
        dataloader = engine.gallery_query_dataloader
    else:
        raise RuntimeError(
            "Only support gallery, query or gallery_query dataset")

    batch_feas_list = []
    label_id_list = []
    camera_id_list = []

    max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
        dataloader)
    for idx, batch in enumerate(dataloader):  # load is very time-consuming
        if idx >= max_iter:
            break
        if idx % engine.config["Global"]["print_batch_step"] == 0:
            logger.info(
                f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
            )

        batch = [paddle.to_tensor(x) for x in batch]
        batch[1] = batch[1].reshape([-1, 1]).astype("int64")
        if len(batch) == 3:
            has_camera_id = True
            batch[2] = batch[2].reshape([-1, 1]).astype("int64")

        if engine.amp and engine.amp_eval:
            with paddle.amp.auto_cast(
                    custom_black_list={
                        "flatten_contiguous_range", "greater_than"
                    },
                    level=engine.amp_level):
                out = engine.model(batch[0], batch[1])
        else:
            out = engine.model(batch[0], batch[1])
        if "Student" in out:
            out = out["Student"]

        # get features
        if engine.config["Global"].get("retrieval_feature_from",
                                       "features") == "features":
            # use neck's output as features
            batch_feas = out["features"]
        else:
            # use backbone's output as features
            batch_feas = out["backbone"]

        # do norm (optional)
        if engine.config["Global"].get("feature_normalize", True):
            batch_feas_norm = paddle.sqrt(
                paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
            batch_feas = paddle.divide(batch_feas, batch_feas_norm)
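        # with L2-normalized features, the dot products taken in
        # retrieval_eval (paddle.matmul(query, gallery, transpose_y=True)) are
        # cosine similarities, and 2 - 2 * <q, g> used in
        # compute_re_ranking_dist is the squared Euclidean distance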
        # do binarize (optional)
        if engine.config["Global"].get("feature_binarize") == "round":
            batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0
        elif engine.config["Global"].get("feature_binarize") == "sign":
            batch_feas = paddle.sign(batch_feas).astype("float32")

        if paddle.distributed.get_world_size() > 1:
            batch_feas_list.append(all_gather(batch_feas))
            label_id_list.append(all_gather(batch[1]))
            if has_camera_id:
                camera_id_list.append(all_gather(batch[2]))
        else:
            batch_feas_list.append(batch_feas)
            label_id_list.append(batch[1])
            if has_camera_id:
                camera_id_list.append(batch[2])

    if engine.use_dali:
        dataloader.reset()

    all_feas = paddle.concat(batch_feas_list)
    all_label_id = paddle.concat(label_id_list)
    if has_camera_id:
        all_camera_id = paddle.concat(camera_id_list)

    # discard redundant padding sample(s) at the end
    total_samples = len(
        dataloader.dataset) if not engine.use_dali else dataloader.size
    all_feas = all_feas[:total_samples]
    all_label_id = all_label_id[:total_samples]
    if has_camera_id:
        all_camera_id = all_camera_id[:total_samples]

    logger.info(f"Build {name} done, all feat shape: {all_feas.shape}")
    return all_feas, all_label_id, all_camera_id


def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray:
    """Implementation of k-reciprocal nearest neighbors, i.e. R(p, k)

    Args:
        rank (np.ndarray): Rank mat with shape of [N, N].
        p (int): Probe index.
        k (int): Parameter k for k-reciprocal nearest neighbors algorithm.

    Returns:
        np.ndarray: K-reciprocal nearest neighbors of probe p with shape of [M, ].
    """
    # take the top-(k+1) entries because the first entry is the probe itself
    forward_k_neigh_index = rank[p, :k + 1]
    backward_k_neigh_index = rank[forward_k_neigh_index, :k + 1]
    candidate = np.where(backward_k_neigh_index == p)[0]
    return forward_k_neigh_index[candidate]


def compute_re_ranking_dist(query_feas: paddle.Tensor,
                            gallery_feas: paddle.Tensor,
                            feature_normed: bool=True,
                            k1: int=20,
                            k2: int=6,
                            lamb: float=0.5) -> paddle.Tensor:
    """
    Re-ranking Person Re-identification with k-reciprocal Encoding
    Reference: https://arxiv.org/abs/1701.08398
    Code reference: https://github.com/michuanhaohao/reid-strong-baseline/blob/master/utils/re_ranking.py

    Args:
        query_feas (paddle.Tensor): Query features with shape of [num_query, feature_dim].
        gallery_feas (paddle.Tensor): Gallery features with shape of [num_gallery, feature_dim].
        feature_normed (bool, optional): Whether input features are normalized. Defaults to True.
        k1 (int, optional): Parameter for K-reciprocal nearest neighbors. Defaults to 20.
        k2 (int, optional): Parameter for K-nearest neighbors. Defaults to 6.
        lamb (float, optional): Penalty factor. Defaults to 0.5.

    Returns:
        paddle.Tensor: (1 - lamb) x Dj + lamb x D, with shape of [num_query, num_gallery].
    """
    num_query = query_feas.shape[0]
    num_gallery = gallery_feas.shape[0]
    num_all = num_query + num_gallery
    feat = paddle.concat([query_feas, gallery_feas], 0)
    logger.info("Using GPU to compute original distance matrix")

    # use L2 distance
    if feature_normed:
        original_dist = 2 - 2 * paddle.matmul(feat, feat, transpose_y=True)
    else:
        original_dist = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]) + \
            paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([num_all, num_all]).t()
        original_dist = original_dist.addmm(
            x=feat, y=feat.t(), alpha=-2.0, beta=1.0)
    original_dist = original_dist.numpy()
    del feat
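    # scale each column by its maximum and transpose, so that row p holds
    # probe p's scaled distances to every sample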
    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)
    initial_rank = np.argpartition(original_dist, range(1, k1 + 1))

    logger.info("Start re-ranking...")
    for p in range(num_all):
        # compute R(p,k1)
        p_k_reciprocal_ind = k_reciprocal_neighbor(initial_rank, p, k1)

        # compute R*(p,k1)=R(p,k1)∪R(q,k1/2)
        # s.t. |R(p,k1)∩R(q,k1/2)|>=2/3|R(q,k1/2)|, ∀q∈R(p,k1)
        p_k_reciprocal_exp_ind = p_k_reciprocal_ind
        for _, q in enumerate(p_k_reciprocal_ind):
            q_k_reciprocal_ind = k_reciprocal_neighbor(initial_rank, q,
                                                       int(np.around(k1 / 2)))
            if len(np.intersect1d(p_k_reciprocal_ind, q_k_reciprocal_ind)
                   ) > 2 / 3 * len(q_k_reciprocal_ind):
                p_k_reciprocal_exp_ind = np.append(p_k_reciprocal_exp_ind,
                                                   q_k_reciprocal_ind)
        p_k_reciprocal_exp_ind = np.unique(p_k_reciprocal_exp_ind)

        # reweight distance using gaussian kernel
        weight = np.exp(-original_dist[p, p_k_reciprocal_exp_ind])
        V[p, p_k_reciprocal_exp_ind] = weight / np.sum(weight)

    # local query expansion
    original_dist = original_dist[:num_query, ]
    if k2 > 1:
        V_qe = np.zeros_like(V, dtype=np.float16)
        for p in range(num_all):
            V_qe[p, :] = np.mean(V[initial_rank[p, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank

    # cache k-reciprocal sets which contain gj
    invIndex = []
    for gj in range(num_all):
        invIndex.append(np.nonzero(V[:, gj])[0])

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
    for p in range(num_query):
        sum_min = np.zeros(shape=[1, num_all], dtype=np.float16)
        gj_ind = np.nonzero(V[p, :])[0]
        gj_ind_inv = [invIndex[gj] for gj in gj_ind]
        for j, gj in enumerate(gj_ind):
            gi = gj_ind_inv[j]
            sum_min[0, gi] += np.minimum(V[p, gj], V[gi, gj])
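        # each row of V sums to 1, so sum(max(V[p], V[g])) = 2 - sum(min(...));
        # the line below is therefore the Jaccard distance
        # 1 - sum(min(V[p], V[g])) / sum(max(V[p], V[g])) from the paper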
        jaccard_dist[p] = 1 - sum_min / (2 - sum_min)

    final_dist = jaccard_dist * (1 - lamb) + original_dist * lamb
    del original_dist
    del V
    del jaccard_dist

    final_dist = final_dist[:num_query, num_query:]
    final_dist = paddle.to_tensor(final_dist)
    return final_dist
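

if __name__ == "__main__":
    # Minimal standalone sketch (not part of the evaluation flow above): build
    # small random L2-normalized query/gallery features and check that the
    # re-ranked distance matrix has shape [num_query, num_gallery]. All sizes
    # and the k1/k2/lamb values here are illustrative assumptions.
    paddle.seed(0)
    toy_query = paddle.nn.functional.normalize(paddle.rand([8, 16]), axis=1)
    toy_gallery = paddle.nn.functional.normalize(paddle.rand([32, 16]), axis=1)
    toy_dist = compute_re_ranking_dist(
        toy_query, toy_gallery, feature_normed=True, k1=5, k2=3, lamb=0.3)
    print(toy_dist.shape)  # expected: [8, 32]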