mirror of https://github.com/JDAI-CV/fast-reid.git
add contrastive loss
parent
7201a82840
commit
8b309b0f4e
|
@ -37,3 +37,6 @@ model_ts*.txt
|
|||
.vscode
|
||||
_darcs
|
||||
.DS_Store
|
||||
|
||||
# dataset dir
|
||||
datasets
|
||||
|
|
|
@ -121,6 +121,11 @@ _C.MODEL.LOSSES.COSFACE.MARGIN = 0.25
|
|||
_C.MODEL.LOSSES.COSFACE.GAMMA = 128
|
||||
_C.MODEL.LOSSES.COSFACE.SCALE = 1.0
|
||||
|
||||
# Contrastive Loss options
|
||||
_C.MODEL.LOSSES.CONTRASTIVE = CN()
|
||||
_C.MODEL.LOSSES.CONTRASTIVE.MARGIN = 2.0
|
||||
_C.MODEL.LOSSES.CONTRASTIVE.SCALE = 1.0
|
||||
|
||||
# Path to a checkpoint file to be loaded to the model. You can find available models in the model zoo.
|
||||
_C.MODEL.WEIGHTS = ""
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
"""
|
||||
|
||||
from .circle_loss import *
|
||||
from .contrastive_loss import contrastive_loss
|
||||
from .cross_entroy_loss import cross_entropy_loss, log_accuracy
|
||||
from .focal_loss import focal_loss
|
||||
from .triplet_loss import triplet_loss
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2021/10/11 15:46:33
|
||||
# @Author : zuchen.wang@vipshop.com
|
||||
# @File : contrastive_loss.py
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['contrastive_loss']
|
||||
|
||||
|
||||
def contrastive_loss(
|
||||
embedding: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
margin: float) -> torch.Tensor:
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embed1 = embedding[0:len(embedding):2, :]
|
||||
embed2 = embedding[1:len(embedding):2, :]
|
||||
euclidean_distance = F.pairwise_distance(embed1, embed2)
|
||||
return torch.mean(targets * torch.pow(euclidean_distance, 2) +
|
||||
(1 - targets) * torch.pow(torch.clamp(margin - euclidean_distance, min=0), 2))
|
|
@ -88,6 +88,10 @@ class Baseline(nn.Module):
|
|||
'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
|
||||
'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
|
||||
'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
|
||||
},
|
||||
'contrastive': {
|
||||
'margin': cfg.MODEL.LOSSES.CONTRASTIVE.MARGIN,
|
||||
'scale': cfg.MODEL.LOSSES.CONTRASTIVE.SCALE
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -143,7 +147,7 @@ class Baseline(nn.Module):
|
|||
# fmt: on
|
||||
|
||||
# Log prediction accuracy
|
||||
log_accuracy(pred_class_logits, gt_labels)
|
||||
# log_accuracy(pred_class_logits, gt_labels)
|
||||
|
||||
loss_dict = {}
|
||||
loss_names = self.loss_kwargs['loss_names']
|
||||
|
@ -185,4 +189,12 @@ class Baseline(nn.Module):
|
|||
cosface_kwargs.get('gamma'),
|
||||
) * cosface_kwargs.get('scale')
|
||||
|
||||
if 'ContrastiveLoss' in loss_names:
|
||||
contrastive_kwargs = self.loss_kwargs.get('contrastive')
|
||||
loss_dict['loss_contrastive'] = contrastive_loss(
|
||||
pred_features,
|
||||
gt_labels,
|
||||
contrastive_kwargs.get('margin')
|
||||
) * contrastive_kwargs.get('scale')
|
||||
|
||||
return loss_dict
|
||||
|
|
|
@ -18,10 +18,10 @@ MODEL:
|
|||
NUM_CLASSES: 2
|
||||
|
||||
LOSSES:
|
||||
NAME: ("CrossEntropyLoss",)
|
||||
NAME: ("ContrastiveLoss",)
|
||||
|
||||
CE:
|
||||
EPSILON: 0.1
|
||||
CONTRASTIVE:
|
||||
MARGIN: 2.0
|
||||
SCALE: 1.
|
||||
|
||||
INPUT:
|
||||
|
|
Loading…
Reference in New Issue