mirror of https://github.com/JDAI-CV/fast-reid.git
refactor evaluator
parent
6687df06e0
commit
3880b23ab4
|
@ -0,0 +1,24 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# @Time : 2021/10/27 15:56:52
|
||||||
|
# @Author : zuchen.wang@vipshop.com
|
||||||
|
# @File : pair_score_evaluator.py
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from fastreid.modeling.losses.utils import normalize
|
||||||
|
from .pair_evaluator import PairEvaluator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PairScoreEvaluator(PairEvaluator):
|
||||||
|
|
||||||
|
def process(self, inputs, outputs):
|
||||||
|
prediction = {
|
||||||
|
'distances': outputs['cls_outputs'].to(self._cpu_device).numpy(),
|
||||||
|
'labels': inputs["targets"].to(self._cpu_device).numpy()
|
||||||
|
}
|
||||||
|
self._predictions.append(prediction)
|
|
@ -0,0 +1,28 @@
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from fastreid.modeling.losses.utils import normalize
|
||||||
|
from .pair_evaluator import PairEvaluator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PairDistanceEvaluator(PairEvaluator):
|
||||||
|
|
||||||
|
def process(self, inputs, outputs):
|
||||||
|
embedding = outputs['features'].to(self._cpu_device)
|
||||||
|
embedding = embedding.view(embedding.size(0) * 2, -1)
|
||||||
|
embedding = normalize(embedding, axis=-1)
|
||||||
|
embed1 = embedding[0:len(embedding):2, :]
|
||||||
|
embed2 = embedding[1:len(embedding):2, :]
|
||||||
|
distances = torch.mul(embed1, embed2).sum(-1).numpy()
|
||||||
|
|
||||||
|
# print(distances)
|
||||||
|
prediction = {
|
||||||
|
'distances': distances,
|
||||||
|
'labels': inputs["targets"].to(self._cpu_device).numpy()
|
||||||
|
}
|
||||||
|
self._predictions.append(prediction)
|
|
@ -13,7 +13,7 @@ from fastreid.data.datasets import DATASET_REGISTRY
|
||||||
from fastreid.utils import comm
|
from fastreid.utils import comm
|
||||||
from fastreid.data.transforms import build_transforms
|
from fastreid.data.transforms import build_transforms
|
||||||
from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
|
from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
|
||||||
from fastreid.evaluation.pair_evaluator import PairEvaluator
|
from fastreid.evaluation.pair_score_evaluator import PairScoreEvaluator
|
||||||
from projects.FastShoe.fastshoe.data import PairDataset
|
from projects.FastShoe.fastshoe.data import PairDataset
|
||||||
|
|
||||||
|
|
||||||
|
@ -85,4 +85,4 @@ class PairTrainer(DefaultTrainer):
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_evaluator(cls, cfg, dataset_name, output_dir=None):
|
def build_evaluator(cls, cfg, dataset_name, output_dir=None):
|
||||||
data_loader = cls.build_test_loader(cfg, dataset_name)
|
data_loader = cls.build_test_loader(cfg, dataset_name)
|
||||||
return data_loader, PairEvaluator(cfg, output_dir)
|
return data_loader, PairScoreEvaluator(cfg, output_dir)
|
||||||
|
|
Loading…
Reference in New Issue