From a00e50d37f89f832136e6d96c87075b181360c9d Mon Sep 17 00:00:00 2001 From: liaoxingyu2 Date: Fri, 6 Nov 2020 11:01:10 +0800 Subject: [PATCH] fix triplet ddp training Summary: fixup precision alignment when triplet loss with ddp --- fastreid/__init__.py | 2 +- fastreid/modeling/losses/triplet_loss.py | 8 ++++---- fastreid/utils/visualizer.py | 10 +++++----- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/fastreid/__init__.py b/fastreid/__init__.py index eaf1de3..94eab6e 100644 --- a/fastreid/__init__.py +++ b/fastreid/__init__.py @@ -5,4 +5,4 @@ """ -__version__ = "0.1.0" \ No newline at end of file +__version__ = "0.2.0" diff --git a/fastreid/modeling/losses/triplet_loss.py b/fastreid/modeling/losses/triplet_loss.py index f7d8bba..cc3f624 100644 --- a/fastreid/modeling/losses/triplet_loss.py +++ b/fastreid/modeling/losses/triplet_loss.py @@ -97,11 +97,11 @@ def triplet_loss(embedding, targets, margin, norm_feat, hard_mining): all_embedding = embedding all_targets = targets - dist_mat = euclidean_dist(embedding, all_embedding) + dist_mat = euclidean_dist(all_embedding, all_embedding) - N, M = dist_mat.size() - is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t()) - is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t()) + N, N = dist_mat.size() + is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()) + is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t()) if hard_mining: dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg) diff --git a/fastreid/utils/visualizer.py b/fastreid/utils/visualizer.py index 37ae1dd..5a06abd 100644 --- a/fastreid/utils/visualizer.py +++ b/fastreid/utils/visualizer.py @@ -66,7 +66,7 @@ class Visualizer: plt.clf() ax = fig.add_subplot(1, max_rank + 1, 1) ax.imshow(query_img) - ax.set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id)) + ax.set_title('{:.4f}/cam{}'.format(self.all_ap[q_idx], cam_id)) ax.axis("off") for i in range(max_rank): if vis_label: @@ -112,7 +112,7 @@ class Visualizer: # axes.flat[i].imshow(acts[i], alpha=0.3, cmap='jet') if vis_label: label_indice = np.where(cmc == 1)[0] - if label_sort == "ascending": label_indice = label_indice[::-1] + if label_sort == "ascending": label_indice = label_indice[::-1] label_indice = label_indice[:max_rank] for i in range(max_rank): if i >= len(label_indice): break @@ -176,9 +176,9 @@ class Visualizer: self.plot_roc_curve(fpr, tpr) filepath = os.path.join(output, "roc.jpg") plt.savefig(filepath) - self.plot_distribution(pos, neg) - filepath = os.path.join(output, "pos_neg_dist.jpg") - plt.savefig(filepath) + # self.plot_distribution(pos, neg) + # filepath = os.path.join(output, "pos_neg_dist.jpg") + # plt.savefig(filepath) return fpr, tpr, pos, neg @staticmethod