commit
9561ccaef2
|
@ -404,37 +404,40 @@ class Trainer(object):
|
|||
if query_query_id is not None:
|
||||
query_id_blocks = paddle.split(
|
||||
query_query_id, num_or_sections=sections)
|
||||
image_id_blocks = paddle.split(
|
||||
query_img_id, num_or_sections=sections)
|
||||
image_id_blocks = paddle.split(
|
||||
query_img_id, num_or_sections=sections)
|
||||
metric_key = None
|
||||
|
||||
for block_idx, block_fea in enumerate(fea_blocks):
|
||||
similarity_matrix = paddle.matmul(
|
||||
block_fea, gallery_feas, transpose_y=True)
|
||||
if query_query_id is not None:
|
||||
query_id_block = query_id_blocks[block_idx]
|
||||
query_id_mask = (query_id_block != gallery_unique_id.t())
|
||||
|
||||
image_id_block = image_id_blocks[block_idx]
|
||||
image_id_mask = (image_id_block != gallery_img_id.t())
|
||||
|
||||
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
|
||||
similarity_matrix = similarity_matrix * keep_mask.astype(
|
||||
"float32")
|
||||
if cum_similarity_matrix is None:
|
||||
cum_similarity_matrix = similarity_matrix
|
||||
else:
|
||||
cum_similarity_matrix = paddle.concat(
|
||||
[cum_similarity_matrix, similarity_matrix], axis=0)
|
||||
|
||||
# calc metric
|
||||
if self.eval_metric_func is not None:
|
||||
metric_dict = self.eval_metric_func(cum_similarity_matrix,
|
||||
query_img_id, gallery_img_id)
|
||||
else:
|
||||
if self.eval_metric_func is None:
|
||||
metric_dict = {metric_key: 0.}
|
||||
metric_info_list = []
|
||||
else:
|
||||
metric_dict = dict()
|
||||
for block_idx, block_fea in enumerate(fea_blocks):
|
||||
similarity_matrix = paddle.matmul(
|
||||
block_fea, gallery_feas, transpose_y=True)
|
||||
if query_query_id is not None:
|
||||
query_id_block = query_id_blocks[block_idx]
|
||||
query_id_mask = (query_id_block != gallery_unique_id.t())
|
||||
|
||||
image_id_block = image_id_blocks[block_idx]
|
||||
image_id_mask = (image_id_block != gallery_img_id.t())
|
||||
|
||||
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
|
||||
similarity_matrix = similarity_matrix * keep_mask.astype("float32")
|
||||
|
||||
metric_tmp = self.eval_metric_func(similarity_matrix,image_id_blocks[block_idx], gallery_img_id)
|
||||
|
||||
for key in metric_tmp:
|
||||
if key not in metric_dict:
|
||||
metric_dict[key] = metric_tmp[key]
|
||||
else:
|
||||
metric_dict[key] += metric_tmp[key]
|
||||
|
||||
num_sections = len(fea_blocks)
|
||||
for key in metric_dict:
|
||||
metric_dict[key] = metric_dict[key]/num_sections
|
||||
|
||||
metric_info_list = []
|
||||
for key in metric_dict:
|
||||
if metric_key is None:
|
||||
metric_key = key
|
||||
|
@ -442,7 +445,8 @@ class Trainer(object):
|
|||
metric_msg = ", ".join(metric_info_list)
|
||||
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
|
||||
|
||||
return metric_dict[metric_key]
|
||||
return metric_dict[metric_key]
|
||||
|
||||
|
||||
def _cal_feature(self, name='gallery'):
|
||||
all_feas = None
|
||||
|
|
Loading…
Reference in New Issue