mirror of https://github.com/JDAI-CV/fast-reid.git
add cls target in pair dataset
parent
8b309b0f4e
commit
1bd41a79fc
|
@ -194,12 +194,15 @@ def pair_batch_collator(batched_inputs):
|
|||
|
||||
images = []
|
||||
targets = []
|
||||
cls_targets = []
|
||||
for elem in batched_inputs:
|
||||
images.append(elem['img1'])
|
||||
images.append(elem['img2'])
|
||||
targets.append(elem['target'])
|
||||
cls_targets.append(elem['cls_target'])
|
||||
|
||||
images = torch.stack(images, dim=0)
|
||||
targets = torch.tensor(targets)
|
||||
return {'images': images, 'targets': targets}
|
||||
cls_targets = torch.tensor(cls_targets)
|
||||
return {'images': images, 'targets': targets, 'cls_targets': cls_targets}
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
# @Author : zuchen.wang@vipshop.com
|
||||
# @File : contrastive_loss.py
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from .utils import normalize, euclidean_dist
|
||||
|
||||
__all__ = ['contrastive_loss']
|
||||
|
||||
|
@ -12,9 +12,9 @@ def contrastive_loss(
|
|||
embedding: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
margin: float) -> torch.Tensor:
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = normalize(embedding, dim=1)
|
||||
embed1 = embedding[0:len(embedding):2, :]
|
||||
embed2 = embedding[1:len(embedding):2, :]
|
||||
euclidean_distance = F.pairwise_distance(embed1, embed2)
|
||||
euclidean_distance = euclidean_dist(embed1, embed2)
|
||||
return torch.mean(targets * torch.pow(euclidean_distance, 2) +
|
||||
(1 - targets) * torch.pow(torch.clamp(margin - euclidean_distance, min=0), 2))
|
||||
|
|
|
@ -108,13 +108,19 @@ class Baseline(nn.Module):
|
|||
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
|
||||
targets = batched_inputs["targets"]
|
||||
|
||||
if "cls_targets" in batched_inputs:
|
||||
cls_targets = batched_inputs['cls_targets']
|
||||
|
||||
# PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
|
||||
# may be larger than that in the original dataset, so the circle/arcface will
|
||||
# throw an error. We just set all the targets to 0 to avoid this problem.
|
||||
if targets.sum() < 0: targets.zero_()
|
||||
|
||||
outputs = self.heads(features, targets)
|
||||
losses = self.losses(outputs, targets)
|
||||
if "cls_targets" in batched_inputs:
|
||||
outputs = self.heads(features, cls_targets)
|
||||
else:
|
||||
outputs = self.heads(features, targets)
|
||||
losses = self.losses(outputs, cls_targets, targets)
|
||||
return losses
|
||||
else:
|
||||
outputs = self.heads(features)
|
||||
|
@ -134,7 +140,7 @@ class Baseline(nn.Module):
|
|||
images.sub_(self.pixel_mean).div_(self.pixel_std)
|
||||
return images
|
||||
|
||||
def losses(self, outputs, gt_labels):
|
||||
def losses(self, outputs, gt_cls_labels, gt_metric_labels):
|
||||
"""
|
||||
Compute loss from modeling's outputs, the loss function input arguments
|
||||
must be the same as the outputs of the model forwarding.
|
||||
|
@ -147,7 +153,7 @@ class Baseline(nn.Module):
|
|||
# fmt: on
|
||||
|
||||
# Log prediction accuracy
|
||||
# log_accuracy(pred_class_logits, gt_labels)
|
||||
log_accuracy(pred_class_logits, gt_cls_labels)
|
||||
|
||||
loss_dict = {}
|
||||
loss_names = self.loss_kwargs['loss_names']
|
||||
|
|
|
@ -45,10 +45,10 @@ class PairDataset(Dataset):
|
|||
return {
|
||||
'img1': img1,
|
||||
'img2': img2,
|
||||
'target': label
|
||||
'target': label,
|
||||
'cls_target': idx
|
||||
}
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return 2
|
||||
|
||||
return len(self.pos_folders)
|
||||
|
|
Loading…
Reference in New Issue