fast-reid/fastreid/modeling/losses/contrastive_loss.py

20 lines
638 B
Python

# -*- coding: utf-8 -*-
# @Time : 2021/10/11 15:46:33
# @Author : zuchen.wang@vipshop.com
# @File : contrastive_loss.py
import torch
from .utils import normalize, euclidean_dist
__all__ = ['contrastive_loss']
def contrastive_loss(
        query_feat: torch.Tensor,
        gallery_feat: torch.Tensor,
        targets: torch.Tensor,
        margin: float) -> torch.Tensor:
    """Compute the mean contrastive loss over element-wise paired embeddings.

    For each pair, ``d`` is the Euclidean distance between ``query_feat`` and
    ``gallery_feat`` taken over the last dimension, and with ``y`` the
    corresponding entry of ``targets``:

        loss = 0.5 * y * d**2 + 0.5 * (1 - y) * max(margin - d, 0)**2

    So ``y == 1`` pulls a pair together and ``y == 0`` pushes it apart until
    it is at least ``margin`` away.

    Args:
        query_feat: Query embeddings, paired element-wise with
            ``gallery_feat`` (the two must broadcast); the distance is
            reduced over the last dimension.
        gallery_feat: Gallery embeddings matched to ``query_feat``.
        targets: Pair labels, 1 for a positive pair and 0 for a negative
            pair. Float, integer or bool tensors are accepted (cast to the
            feature dtype). Must broadcast against the per-pair distance.
        margin: Distance beyond which negative pairs contribute no loss.

    Returns:
        Scalar tensor: the loss averaged over all pairs.
    """
    squared_dist = torch.sum(torch.pow(query_feat - gallery_feat, 2), -1)
    # Cast so that bool/integer labels work with ``1 - targets``.
    targets = targets.to(squared_dist.dtype)

    # The positive term only needs d**2, so use it directly instead of
    # squaring a square root (same value, well-defined gradient at d == 0).
    loss_pos = 0.5 * targets * squared_dist

    # d(sqrt(x))/dx is infinite at x == 0; without the clamp, identical
    # features would yield NaN gradients (0 * inf) even for negative pairs.
    distance = squared_dist.clamp(min=1e-12).sqrt()
    loss_neg = 0.5 * (1 - targets) * torch.pow(torch.clamp(margin - distance, min=0), 2)

    return torch.mean(loss_pos + loss_neg)