add contrastive loss

pull/608/head
zuchen.wang 2021-10-11 16:12:00 +08:00
parent 7201a82840
commit 8b309b0f4e
6 changed files with 45 additions and 4 deletions
fastreid
projects/FastShoe/configs

3
.gitignore vendored
View File

@ -37,3 +37,6 @@ model_ts*.txt
.vscode
_darcs
.DS_Store
# dataset dir
datasets

View File

@ -121,6 +121,11 @@ _C.MODEL.LOSSES.COSFACE.MARGIN = 0.25
_C.MODEL.LOSSES.COSFACE.GAMMA = 128
_C.MODEL.LOSSES.COSFACE.SCALE = 1.0
# Contrastive Loss options
_C.MODEL.LOSSES.CONTRASTIVE = CN()
_C.MODEL.LOSSES.CONTRASTIVE.MARGIN = 2.0
_C.MODEL.LOSSES.CONTRASTIVE.SCALE = 1.0
# Path to a checkpoint file to be loaded to the model. You can find available models in the model zoo.
_C.MODEL.WEIGHTS = ""

View File

@ -5,6 +5,7 @@
"""
from .circle_loss import *
from .contrastive_loss import contrastive_loss
from .cross_entroy_loss import cross_entropy_loss, log_accuracy
from .focal_loss import focal_loss
from .triplet_loss import triplet_loss

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/11 15:46:33
# @Author : zuchen.wang@vipshop.com
# @File : contrastive_loss.py
import torch
import torch.nn.functional as F
__all__ = ['contrastive_loss']
def contrastive_loss(
embedding: torch.Tensor,
targets: torch.Tensor,
margin: float) -> torch.Tensor:
embedding = F.normalize(embedding, dim=1)
embed1 = embedding[0:len(embedding):2, :]
embed2 = embedding[1:len(embedding):2, :]
euclidean_distance = F.pairwise_distance(embed1, embed2)
return torch.mean(targets * torch.pow(euclidean_distance, 2) +
(1 - targets) * torch.pow(torch.clamp(margin - euclidean_distance, min=0), 2))

View File

@ -88,6 +88,10 @@ class Baseline(nn.Module):
'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
},
'contrastive': {
'margin': cfg.MODEL.LOSSES.CONTRASTIVE.MARGIN,
'scale': cfg.MODEL.LOSSES.CONTRASTIVE.SCALE
}
}
}
@ -143,7 +147,7 @@ class Baseline(nn.Module):
# fmt: on
# Log prediction accuracy
log_accuracy(pred_class_logits, gt_labels)
# log_accuracy(pred_class_logits, gt_labels)
loss_dict = {}
loss_names = self.loss_kwargs['loss_names']
@ -185,4 +189,12 @@ class Baseline(nn.Module):
cosface_kwargs.get('gamma'),
) * cosface_kwargs.get('scale')
if 'ContrastiveLoss' in loss_names:
contrastive_kwargs = self.loss_kwargs.get('contrastive')
loss_dict['loss_contrastive'] = contrastive_loss(
pred_features,
gt_labels,
contrastive_kwargs.get('margin')
) * contrastive_kwargs.get('scale')
return loss_dict

View File

@ -18,10 +18,10 @@ MODEL:
NUM_CLASSES: 2
LOSSES:
NAME: ("CrossEntropyLoss",)
NAME: ("ContrastiveLoss",)
CE:
EPSILON: 0.1
CONTRASTIVE:
MARGIN: 2.0
SCALE: 1.
INPUT: