mirror of https://github.com/JDAI-CV/fast-reid.git
add contrastive loss
parent
7201a82840
commit
8b309b0f4e
|
@ -37,3 +37,6 @@ model_ts*.txt
|
||||||
.vscode
|
.vscode
|
||||||
_darcs
|
_darcs
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
|
# dataset dir
|
||||||
|
datasets
|
||||||
|
|
|
@ -121,6 +121,11 @@ _C.MODEL.LOSSES.COSFACE.MARGIN = 0.25
|
||||||
_C.MODEL.LOSSES.COSFACE.GAMMA = 128
|
_C.MODEL.LOSSES.COSFACE.GAMMA = 128
|
||||||
_C.MODEL.LOSSES.COSFACE.SCALE = 1.0
|
_C.MODEL.LOSSES.COSFACE.SCALE = 1.0
|
||||||
|
|
||||||
|
# Contrastive Loss options
|
||||||
|
_C.MODEL.LOSSES.CONTRASTIVE = CN()
|
||||||
|
_C.MODEL.LOSSES.CONTRASTIVE.MARGIN = 2.0
|
||||||
|
_C.MODEL.LOSSES.CONTRASTIVE.SCALE = 1.0
|
||||||
|
|
||||||
# Path to a checkpoint file to be loaded to the model. You can find available models in the model zoo.
|
# Path to a checkpoint file to be loaded to the model. You can find available models in the model zoo.
|
||||||
_C.MODEL.WEIGHTS = ""
|
_C.MODEL.WEIGHTS = ""
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .circle_loss import *
|
from .circle_loss import *
|
||||||
|
from .contrastive_loss import contrastive_loss
|
||||||
from .cross_entroy_loss import cross_entropy_loss, log_accuracy
|
from .cross_entroy_loss import cross_entropy_loss, log_accuracy
|
||||||
from .focal_loss import focal_loss
|
from .focal_loss import focal_loss
|
||||||
from .triplet_loss import triplet_loss
|
from .triplet_loss import triplet_loss
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# @Time : 2021/10/11 15:46:33
|
||||||
|
# @Author : zuchen.wang@vipshop.com
|
||||||
|
# @File : contrastive_loss.py
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
__all__ = ['contrastive_loss']
|
||||||
|
|
||||||
|
|
||||||
|
def contrastive_loss(
|
||||||
|
embedding: torch.Tensor,
|
||||||
|
targets: torch.Tensor,
|
||||||
|
margin: float) -> torch.Tensor:
|
||||||
|
embedding = F.normalize(embedding, dim=1)
|
||||||
|
embed1 = embedding[0:len(embedding):2, :]
|
||||||
|
embed2 = embedding[1:len(embedding):2, :]
|
||||||
|
euclidean_distance = F.pairwise_distance(embed1, embed2)
|
||||||
|
return torch.mean(targets * torch.pow(euclidean_distance, 2) +
|
||||||
|
(1 - targets) * torch.pow(torch.clamp(margin - euclidean_distance, min=0), 2))
|
|
@ -88,6 +88,10 @@ class Baseline(nn.Module):
|
||||||
'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
|
'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
|
||||||
'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
|
'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
|
||||||
'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
|
'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
|
||||||
|
},
|
||||||
|
'contrastive': {
|
||||||
|
'margin': cfg.MODEL.LOSSES.CONTRASTIVE.MARGIN,
|
||||||
|
'scale': cfg.MODEL.LOSSES.CONTRASTIVE.SCALE
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -143,7 +147,7 @@ class Baseline(nn.Module):
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
# Log prediction accuracy
|
# Log prediction accuracy
|
||||||
log_accuracy(pred_class_logits, gt_labels)
|
# log_accuracy(pred_class_logits, gt_labels)
|
||||||
|
|
||||||
loss_dict = {}
|
loss_dict = {}
|
||||||
loss_names = self.loss_kwargs['loss_names']
|
loss_names = self.loss_kwargs['loss_names']
|
||||||
|
@ -185,4 +189,12 @@ class Baseline(nn.Module):
|
||||||
cosface_kwargs.get('gamma'),
|
cosface_kwargs.get('gamma'),
|
||||||
) * cosface_kwargs.get('scale')
|
) * cosface_kwargs.get('scale')
|
||||||
|
|
||||||
|
if 'ContrastiveLoss' in loss_names:
|
||||||
|
contrastive_kwargs = self.loss_kwargs.get('contrastive')
|
||||||
|
loss_dict['loss_contrastive'] = contrastive_loss(
|
||||||
|
pred_features,
|
||||||
|
gt_labels,
|
||||||
|
contrastive_kwargs.get('margin')
|
||||||
|
) * contrastive_kwargs.get('scale')
|
||||||
|
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
|
@ -18,10 +18,10 @@ MODEL:
|
||||||
NUM_CLASSES: 2
|
NUM_CLASSES: 2
|
||||||
|
|
||||||
LOSSES:
|
LOSSES:
|
||||||
NAME: ("CrossEntropyLoss",)
|
NAME: ("ContrastiveLoss",)
|
||||||
|
|
||||||
CE:
|
CONTRASTIVE:
|
||||||
EPSILON: 0.1
|
MARGIN: 2.0
|
||||||
SCALE: 1.
|
SCALE: 1.
|
||||||
|
|
||||||
INPUT:
|
INPUT:
|
||||||
|
|
Loading…
Reference in New Issue