mirror of https://github.com/JDAI-CV/fast-reid.git
fix triplet ddp training
Summary: fixup precision alignment when triplet loss with ddppull/365/head
parent
64bf78afee
commit
a00e50d37f
|
@ -5,4 +5,4 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.2.0"
|
||||||
|
|
|
@ -97,11 +97,11 @@ def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
|
||||||
all_embedding = embedding
|
all_embedding = embedding
|
||||||
all_targets = targets
|
all_targets = targets
|
||||||
|
|
||||||
dist_mat = euclidean_dist(embedding, all_embedding)
|
dist_mat = euclidean_dist(all_embedding, all_embedding)
|
||||||
|
|
||||||
N, M = dist_mat.size()
|
N, N = dist_mat.size()
|
||||||
is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
|
is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t())
|
||||||
is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
|
is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
|
||||||
|
|
||||||
if hard_mining:
|
if hard_mining:
|
||||||
dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
|
dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
|
||||||
|
|
|
@ -66,7 +66,7 @@ class Visualizer:
|
||||||
plt.clf()
|
plt.clf()
|
||||||
ax = fig.add_subplot(1, max_rank + 1, 1)
|
ax = fig.add_subplot(1, max_rank + 1, 1)
|
||||||
ax.imshow(query_img)
|
ax.imshow(query_img)
|
||||||
ax.set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id))
|
ax.set_title('{:.4f}/cam{}'.format(self.all_ap[q_idx], cam_id))
|
||||||
ax.axis("off")
|
ax.axis("off")
|
||||||
for i in range(max_rank):
|
for i in range(max_rank):
|
||||||
if vis_label:
|
if vis_label:
|
||||||
|
@ -176,9 +176,9 @@ class Visualizer:
|
||||||
self.plot_roc_curve(fpr, tpr)
|
self.plot_roc_curve(fpr, tpr)
|
||||||
filepath = os.path.join(output, "roc.jpg")
|
filepath = os.path.join(output, "roc.jpg")
|
||||||
plt.savefig(filepath)
|
plt.savefig(filepath)
|
||||||
self.plot_distribution(pos, neg)
|
# self.plot_distribution(pos, neg)
|
||||||
filepath = os.path.join(output, "pos_neg_dist.jpg")
|
# filepath = os.path.join(output, "pos_neg_dist.jpg")
|
||||||
plt.savefig(filepath)
|
# plt.savefig(filepath)
|
||||||
return fpr, tpr, pos, neg
|
return fpr, tpr, pos, neg
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
Loading…
Reference in New Issue