diff --git a/fastreid/evaluation/pair_evaluator.py b/fastreid/evaluation/pair_evaluator.py index 989a65d..4837581 100644 --- a/fastreid/evaluation/pair_evaluator.py +++ b/fastreid/evaluation/pair_evaluator.py @@ -22,11 +22,11 @@ class PairEvaluator(DatasetEvaluator): self._output_dir = output_dir self._cpu_device = torch.device('cpu') self._predictions = [] - self._threshold_list = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98] + self._threshold_list = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98] def reset(self): self._predictions = [] - + def process(self, inputs, outputs): embedding = outputs['features'].to(self._cpu_device) embedding = embedding.view(embedding.size(0) * 2, -1) @@ -41,31 +41,31 @@ class PairEvaluator(DatasetEvaluator): 'labels': inputs["targets"].to(self._cpu_device).numpy() } self._predictions.append(prediction) - + def evaluate(self): if comm.get_world_size() > 1: comm.synchronize() predictions = comm.gather(self._predictions, dst=0) predictions = list(itertools.chain(*predictions)) - if not comm.is_main_process(): + if not comm.is_main_process(): return {} else: predictions = self._predictions - + all_distances = [] all_labels = [] for prediction in predictions: all_distances.append(prediction['distances']) all_labels.append(prediction['labels']) - + all_distances = np.concatenate(all_distances) all_labels = np.concatenate(all_labels) # 计算这3个总体值,还有给定阈值下的precision, recall, f1 acc = skmetrics.accuracy_score(all_labels, all_distances > 0.5) ap = skmetrics.average_precision_score(all_labels, all_distances) - auc = skmetrics.roc_auc_score(all_labels, all_distances) # auc under roc + auc = skmetrics.roc_auc_score(all_labels, all_distances) # auc under roc precisions = [] recalls = [] @@ -74,11 +74,11 @@ class PairEvaluator(DatasetEvaluator): precision = skmetrics.precision_score(all_labels, all_distances >= thresh, zero_division=0) recall = skmetrics.recall_score(all_labels, all_distances >= thresh, zero_division=0) f1 = 2 * precision * recall / (precision + recall + 1e-12) - + precisions.append(precision) recalls.append(recall) f1s.append(f1) - + self._results = OrderedDict() self._results['Acc'] = acc self._results['Ap'] = ap diff --git a/fastreid/evaluation/pair_score_evaluator.py b/fastreid/evaluation/pair_score_evaluator.py new file mode 100644 index 0000000..ddf032e --- /dev/null +++ b/fastreid/evaluation/pair_score_evaluator.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# @Time : 2021/10/27 15:56:52 +# @Author : zuchen.wang@vipshop.com +# @File : pair_score_evaluator.py +# coding: utf-8 + +import logging + +import torch + +from fastreid.modeling.losses.utils import normalize +from .pair_evaluator import PairEvaluator + +logger = logging.getLogger(__name__) + + +class PairScoreEvaluator(PairEvaluator): + + def process(self, inputs, outputs): + prediction = { + 'distances': outputs['cls_outputs'].to(self._cpu_device).numpy(), + 'labels': inputs["targets"].to(self._cpu_device).numpy() + } + self._predictions.append(prediction) diff --git a/fastreid/evaluation/pari_distance_evaluator.py b/fastreid/evaluation/pari_distance_evaluator.py new file mode 100644 index 0000000..a1c9e05 --- /dev/null +++ b/fastreid/evaluation/pari_distance_evaluator.py @@ -0,0 +1,28 @@ +# coding: utf-8 + +import logging + +import torch + +from fastreid.modeling.losses.utils import normalize +from .pair_evaluator import PairEvaluator + +logger = logging.getLogger(__name__) + + +class PairDistanceEvaluator(PairEvaluator): + + def process(self, inputs, outputs): + embedding = outputs['features'].to(self._cpu_device) + embedding = embedding.view(embedding.size(0) * 2, -1) + embedding = normalize(embedding, axis=-1) + embed1 = embedding[0:len(embedding):2, :] + embed2 = embedding[1:len(embedding):2, :] + distances = torch.mul(embed1, embed2).sum(-1).numpy() + + # print(distances) + prediction = { + 'distances': distances, + 'labels': inputs["targets"].to(self._cpu_device).numpy() + } + self._predictions.append(prediction) diff --git a/projects/FastShoe/fastshoe/trainer.py b/projects/FastShoe/fastshoe/trainer.py index 82eb400..97636c5 100644 --- a/projects/FastShoe/fastshoe/trainer.py +++ b/projects/FastShoe/fastshoe/trainer.py @@ -13,7 +13,7 @@ from fastreid.data.datasets import DATASET_REGISTRY from fastreid.utils import comm from fastreid.data.transforms import build_transforms from fastreid.data.build import build_reid_train_loader, build_reid_test_loader -from fastreid.evaluation.pair_evaluator import PairEvaluator +from fastreid.evaluation.pair_score_evaluator import PairScoreEvaluator from projects.FastShoe.fastshoe.data import PairDataset @@ -85,4 +85,4 @@ class PairTrainer(DefaultTrainer): @classmethod def build_evaluator(cls, cfg, dataset_name, output_dir=None): data_loader = cls.build_test_loader(cfg, dataset_name) - return data_loader, PairEvaluator(cfg, output_dir) + return data_loader, PairScoreEvaluator(cfg, output_dir)