commit
9561ccaef2
|
@ -404,37 +404,40 @@ class Trainer(object):
|
||||||
if query_query_id is not None:
|
if query_query_id is not None:
|
||||||
query_id_blocks = paddle.split(
|
query_id_blocks = paddle.split(
|
||||||
query_query_id, num_or_sections=sections)
|
query_query_id, num_or_sections=sections)
|
||||||
image_id_blocks = paddle.split(
|
image_id_blocks = paddle.split(
|
||||||
query_img_id, num_or_sections=sections)
|
query_img_id, num_or_sections=sections)
|
||||||
metric_key = None
|
metric_key = None
|
||||||
|
|
||||||
for block_idx, block_fea in enumerate(fea_blocks):
|
if self.eval_metric_func is None:
|
||||||
similarity_matrix = paddle.matmul(
|
|
||||||
block_fea, gallery_feas, transpose_y=True)
|
|
||||||
if query_query_id is not None:
|
|
||||||
query_id_block = query_id_blocks[block_idx]
|
|
||||||
query_id_mask = (query_id_block != gallery_unique_id.t())
|
|
||||||
|
|
||||||
image_id_block = image_id_blocks[block_idx]
|
|
||||||
image_id_mask = (image_id_block != gallery_img_id.t())
|
|
||||||
|
|
||||||
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
|
|
||||||
similarity_matrix = similarity_matrix * keep_mask.astype(
|
|
||||||
"float32")
|
|
||||||
if cum_similarity_matrix is None:
|
|
||||||
cum_similarity_matrix = similarity_matrix
|
|
||||||
else:
|
|
||||||
cum_similarity_matrix = paddle.concat(
|
|
||||||
[cum_similarity_matrix, similarity_matrix], axis=0)
|
|
||||||
|
|
||||||
# calc metric
|
|
||||||
if self.eval_metric_func is not None:
|
|
||||||
metric_dict = self.eval_metric_func(cum_similarity_matrix,
|
|
||||||
query_img_id, gallery_img_id)
|
|
||||||
else:
|
|
||||||
metric_dict = {metric_key: 0.}
|
metric_dict = {metric_key: 0.}
|
||||||
metric_info_list = []
|
else:
|
||||||
|
metric_dict = dict()
|
||||||
|
for block_idx, block_fea in enumerate(fea_blocks):
|
||||||
|
similarity_matrix = paddle.matmul(
|
||||||
|
block_fea, gallery_feas, transpose_y=True)
|
||||||
|
if query_query_id is not None:
|
||||||
|
query_id_block = query_id_blocks[block_idx]
|
||||||
|
query_id_mask = (query_id_block != gallery_unique_id.t())
|
||||||
|
|
||||||
|
image_id_block = image_id_blocks[block_idx]
|
||||||
|
image_id_mask = (image_id_block != gallery_img_id.t())
|
||||||
|
|
||||||
|
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
|
||||||
|
similarity_matrix = similarity_matrix * keep_mask.astype("float32")
|
||||||
|
|
||||||
|
metric_tmp = self.eval_metric_func(similarity_matrix,image_id_blocks[block_idx], gallery_img_id)
|
||||||
|
|
||||||
|
for key in metric_tmp:
|
||||||
|
if key not in metric_dict:
|
||||||
|
metric_dict[key] = metric_tmp[key]
|
||||||
|
else:
|
||||||
|
metric_dict[key] += metric_tmp[key]
|
||||||
|
|
||||||
|
num_sections = len(fea_blocks)
|
||||||
|
for key in metric_dict:
|
||||||
|
metric_dict[key] = metric_dict[key]/num_sections
|
||||||
|
|
||||||
|
metric_info_list = []
|
||||||
for key in metric_dict:
|
for key in metric_dict:
|
||||||
if metric_key is None:
|
if metric_key is None:
|
||||||
metric_key = key
|
metric_key = key
|
||||||
|
@ -442,7 +445,8 @@ class Trainer(object):
|
||||||
metric_msg = ", ".join(metric_info_list)
|
metric_msg = ", ".join(metric_info_list)
|
||||||
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
|
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
|
||||||
|
|
||||||
return metric_dict[metric_key]
|
return metric_dict[metric_key]
|
||||||
|
|
||||||
|
|
||||||
def _cal_feature(self, name='gallery'):
|
def _cal_feature(self, name='gallery'):
|
||||||
all_feas = None
|
all_feas = None
|
||||||
|
|
Loading…
Reference in New Issue