pull/790/head
weishengyu 2021-06-05 16:12:35 +08:00
parent cd7f606a5f
commit 81c864b96d
1 changed files with 4 additions and 4 deletions

View File

@ -104,7 +104,7 @@ class Trainer(object):
metric_config = self.config.get("Metric", None)
if metric_config is not None:
metric_config = metric_config["Train"]
self.train_metric_func = build_metrics(metric_config)
self.train_metric_func = build_metrics(metric_config)
if self.train_dataloader is None:
self.train_dataloader = build_dataloader(self.config["DataLoader"],
@ -353,12 +353,12 @@ class Trainer(object):
block_fea, gallery_feas, transpose_y=True)
if query_camera_id is not None:
camera_id_block = camera_id_blocks[block_idx]
camera_id_same = (camera_id_block != gallery_camera_id.t())
camera_id_mask = (camera_id_block != gallery_camera_id.t())
image_id_block = image_id_blocks[block_idx]
image_id_same = (image_id_block != gallery_img_id.t())
image_id_mask = (image_id_block != gallery_img_id.t())
keep_mask = paddle.logical_or(camera_id_same, image_id_same)
keep_mask = paddle.logical_or(camera_id_mask, image_id_mask)
similarity_matrix = similarity_matrix * keep_mask.astype(
"float32")
if cum_similarity_matrix is None: