fix triplet ddp training

Summary: fixup precision alignment when triplet loss with ddp
pull/365/head
liaoxingyu2 2020-11-06 11:01:10 +08:00
parent 64bf78afee
commit a00e50d37f
3 changed files with 10 additions and 10 deletions

View File

@ -5,4 +5,4 @@
"""
__version__ = "0.1.0"
__version__ = "0.2.0"

View File

@ -97,11 +97,11 @@ def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
all_embedding = embedding
all_targets = targets
dist_mat = euclidean_dist(embedding, all_embedding)
dist_mat = euclidean_dist(all_embedding, all_embedding)
N, M = dist_mat.size()
is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
N, N = dist_mat.size()
is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t())
is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
if hard_mining:
dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)

View File

@ -66,7 +66,7 @@ class Visualizer:
plt.clf()
ax = fig.add_subplot(1, max_rank + 1, 1)
ax.imshow(query_img)
ax.set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id))
ax.set_title('{:.4f}/cam{}'.format(self.all_ap[q_idx], cam_id))
ax.axis("off")
for i in range(max_rank):
if vis_label:
@ -112,7 +112,7 @@ class Visualizer:
# axes.flat[i].imshow(acts[i], alpha=0.3, cmap='jet')
if vis_label:
label_indice = np.where(cmc == 1)[0]
if label_sort == "ascending": label_indice = label_indice[::-1]
if label_sort == "ascending": label_indice = label_indice[::-1]
label_indice = label_indice[:max_rank]
for i in range(max_rank):
if i >= len(label_indice): break
@ -176,9 +176,9 @@ class Visualizer:
self.plot_roc_curve(fpr, tpr)
filepath = os.path.join(output, "roc.jpg")
plt.savefig(filepath)
self.plot_distribution(pos, neg)
filepath = os.path.join(output, "pos_neg_dist.jpg")
plt.savefig(filepath)
# self.plot_distribution(pos, neg)
# filepath = os.path.join(output, "pos_neg_dist.jpg")
# plt.savefig(filepath)
return fpr, tpr, pos, neg
@staticmethod