fix contrastive loss

pull/608/head
zuchen.wang 2021-11-04 16:24:18 +08:00
parent cc0b9e9612
commit 1d8451a23a
5 changed files with 33 additions and 33 deletions

View File

@ -1,6 +1,8 @@
from .evaluator import DatasetEvaluator, inference_context, inference_on_dataset
from .reid_evaluation import ReidEvaluator
from .clas_evaluator import ClasEvaluator
from .pair_distance_evaluator import PairDistanceEvaluator
from .pair_score_evaluator import PairScoreEvaluator
from .testing import print_csv_format, verify_results
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -0,0 +1,25 @@
# coding: utf-8
import logging
import torch
from .pair_evaluator import PairEvaluator
logger = logging.getLogger(__name__)
class PairDistanceEvaluator(PairEvaluator):
def process(self, inputs, outputs):
query_feat = outputs['query_feature']
gallery_feat = outputs['gallery_feature']
distances = torch.sum(query_feat * gallery_feat, -1)
# print(distances)
prediction = {
'distances': distances.to(self._cpu_device).numpy(),
'labels': inputs["targets"].to(self._cpu_device).numpy()
}
self._predictions.append(prediction)

View File

@ -1,28 +0,0 @@
# coding: utf-8
import logging
import torch
from fastreid.modeling.losses.utils import normalize
from .pair_evaluator import PairEvaluator
logger = logging.getLogger(__name__)
class PairDistanceEvaluator(PairEvaluator):
def process(self, inputs, outputs):
embedding = outputs['features'].to(self._cpu_device)
embedding = embedding.view(embedding.size(0) * 2, -1)
embedding = normalize(embedding, axis=-1)
embed1 = embedding[0:len(embedding):2, :]
embed2 = embedding[1:len(embedding):2, :]
distances = torch.mul(embed1, embed2).sum(-1).numpy()
# print(distances)
prediction = {
'distances': distances,
'labels': inputs["targets"].to(self._cpu_device).numpy()
}
self._predictions.append(prediction)

View File

@ -13,6 +13,7 @@ def contrastive_loss(
gallery_feat: torch.Tensor,
targets: torch.Tensor,
margin: float) -> torch.Tensor:
euclidean_distance = torch.sqrt(torch.sum(torch.pow(query_feat - gallery_feat, 2), -1))
return torch.mean(targets * torch.pow(euclidean_distance, 2) +
(1 - targets) * torch.pow(torch.clamp(margin - euclidean_distance, min=0), 2))
distance = torch.sqrt(torch.sum(torch.pow(query_feat - gallery_feat, 2), -1))
loss1 = 0.5 * targets * torch.pow(distance, 2)
loss2 = 0.5 * (1 - targets) * torch.pow(torch.clamp(margin - distance, min=0), 2)
return torch.mean(loss1 + loss2)

View File

@ -13,7 +13,7 @@ from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.utils import comm
from fastreid.data.transforms import build_transforms
from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
from fastreid.evaluation.pair_score_evaluator import PairScoreEvaluator
from fastreid.evaluation import PairScoreEvaluator, PairDistanceEvaluator
from projects.FastShoe.fastshoe.data import PairDataset
@ -84,4 +84,4 @@ class PairTrainer(DefaultTrainer):
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_dir=None):
data_loader = cls.build_test_loader(cfg, dataset_name)
return data_loader, PairScoreEvaluator(cfg, output_dir)
return data_loader, PairDistanceEvaluator(cfg, output_dir)