mirror of https://github.com/JDAI-CV/fast-reid.git
fix triplet ddp training
Summary: fixup precision alignment when triplet loss with ddppull/365/head
parent
64bf78afee
commit
a00e50d37f
|
@ -5,4 +5,4 @@
|
|||
"""
|
||||
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__version__ = "0.2.0"
|
||||
|
|
|
@ -97,11 +97,11 @@ def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
|
|||
all_embedding = embedding
|
||||
all_targets = targets
|
||||
|
||||
dist_mat = euclidean_dist(embedding, all_embedding)
|
||||
dist_mat = euclidean_dist(all_embedding, all_embedding)
|
||||
|
||||
N, M = dist_mat.size()
|
||||
is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
|
||||
is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
|
||||
N, N = dist_mat.size()
|
||||
is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t())
|
||||
is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
|
||||
|
||||
if hard_mining:
|
||||
dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
|
||||
|
|
|
@ -66,7 +66,7 @@ class Visualizer:
|
|||
plt.clf()
|
||||
ax = fig.add_subplot(1, max_rank + 1, 1)
|
||||
ax.imshow(query_img)
|
||||
ax.set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id))
|
||||
ax.set_title('{:.4f}/cam{}'.format(self.all_ap[q_idx], cam_id))
|
||||
ax.axis("off")
|
||||
for i in range(max_rank):
|
||||
if vis_label:
|
||||
|
@ -112,7 +112,7 @@ class Visualizer:
|
|||
# axes.flat[i].imshow(acts[i], alpha=0.3, cmap='jet')
|
||||
if vis_label:
|
||||
label_indice = np.where(cmc == 1)[0]
|
||||
if label_sort == "ascending": label_indice = label_indice[::-1]
|
||||
if label_sort == "ascending": label_indice = label_indice[::-1]
|
||||
label_indice = label_indice[:max_rank]
|
||||
for i in range(max_rank):
|
||||
if i >= len(label_indice): break
|
||||
|
@ -176,9 +176,9 @@ class Visualizer:
|
|||
self.plot_roc_curve(fpr, tpr)
|
||||
filepath = os.path.join(output, "roc.jpg")
|
||||
plt.savefig(filepath)
|
||||
self.plot_distribution(pos, neg)
|
||||
filepath = os.path.join(output, "pos_neg_dist.jpg")
|
||||
plt.savefig(filepath)
|
||||
# self.plot_distribution(pos, neg)
|
||||
# filepath = os.path.join(output, "pos_neg_dist.jpg")
|
||||
# plt.savefig(filepath)
|
||||
return fpr, tpr, pos, neg
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue