From 8b309b0f4e7e7fafea5d601397c1a86378071bd5 Mon Sep 17 00:00:00 2001 From: "zuchen.wang" Date: Mon, 11 Oct 2021 16:12:00 +0800 Subject: [PATCH] add contrastive loss --- .gitignore | 3 +++ fastreid/config/defaults.py | 5 +++++ fastreid/modeling/losses/__init__.py | 1 + fastreid/modeling/losses/contrastive_loss.py | 20 ++++++++++++++++++++ fastreid/modeling/meta_arch/baseline.py | 14 +++++++++++++- projects/FastShoe/configs/base-pair.yaml | 6 +++--- 6 files changed, 45 insertions(+), 4 deletions(-) create mode 100644 fastreid/modeling/losses/contrastive_loss.py diff --git a/.gitignore b/.gitignore index cc5a199..6f5ed37 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,6 @@ model_ts*.txt .vscode _darcs .DS_Store + +# dataset dir +datasets diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py index 2eefca2..520abbb 100644 --- a/fastreid/config/defaults.py +++ b/fastreid/config/defaults.py @@ -121,6 +121,11 @@ _C.MODEL.LOSSES.COSFACE.MARGIN = 0.25 _C.MODEL.LOSSES.COSFACE.GAMMA = 128 _C.MODEL.LOSSES.COSFACE.SCALE = 1.0 +# Contrastive Loss options +_C.MODEL.LOSSES.CONTRASTIVE = CN() +_C.MODEL.LOSSES.CONTRASTIVE.MARGIN = 2.0 +_C.MODEL.LOSSES.CONTRASTIVE.SCALE = 1.0 + # Path to a checkpoint file to be loaded to the model. You can find available models in the model zoo. _C.MODEL.WEIGHTS = "" diff --git a/fastreid/modeling/losses/__init__.py b/fastreid/modeling/losses/__init__.py index 4ce007b..92e7b10 100644 --- a/fastreid/modeling/losses/__init__.py +++ b/fastreid/modeling/losses/__init__.py @@ -5,6 +5,7 @@ """ from .circle_loss import * +from .contrastive_loss import contrastive_loss from .cross_entroy_loss import cross_entropy_loss, log_accuracy from .focal_loss import focal_loss from .triplet_loss import triplet_loss diff --git a/fastreid/modeling/losses/contrastive_loss.py b/fastreid/modeling/losses/contrastive_loss.py new file mode 100644 index 0000000..bfeaea3 --- /dev/null +++ b/fastreid/modeling/losses/contrastive_loss.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# @Time : 2021/10/11 15:46:33 +# @Author : zuchen.wang@vipshop.com +# @File : contrastive_loss.py +import torch +import torch.nn.functional as F + +__all__ = ['contrastive_loss'] + + +def contrastive_loss( + embedding: torch.Tensor, + targets: torch.Tensor, + margin: float) -> torch.Tensor: + embedding = F.normalize(embedding, dim=1) + embed1 = embedding[0:len(embedding):2, :] + embed2 = embedding[1:len(embedding):2, :] + euclidean_distance = F.pairwise_distance(embed1, embed2) + return torch.mean(targets * torch.pow(euclidean_distance, 2) + + (1 - targets) * torch.pow(torch.clamp(margin - euclidean_distance, min=0), 2)) diff --git a/fastreid/modeling/meta_arch/baseline.py b/fastreid/modeling/meta_arch/baseline.py index 2d9f7b3..631ba52 100644 --- a/fastreid/modeling/meta_arch/baseline.py +++ b/fastreid/modeling/meta_arch/baseline.py @@ -88,6 +88,10 @@ class Baseline(nn.Module): 'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN, 'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA, 'scale': cfg.MODEL.LOSSES.COSFACE.SCALE + }, + 'contrastive': { + 'margin': cfg.MODEL.LOSSES.CONTRASTIVE.MARGIN, + 'scale': cfg.MODEL.LOSSES.CONTRASTIVE.SCALE } } } @@ -143,7 +147,7 @@ class Baseline(nn.Module): # fmt: on # Log prediction accuracy - log_accuracy(pred_class_logits, gt_labels) + # log_accuracy(pred_class_logits, gt_labels) loss_dict = {} loss_names = self.loss_kwargs['loss_names'] @@ -185,4 +189,12 @@ class Baseline(nn.Module): cosface_kwargs.get('gamma'), ) * cosface_kwargs.get('scale') + if 'ContrastiveLoss' in loss_names: + contrastive_kwargs = self.loss_kwargs.get('contrastive') + loss_dict['loss_contrastive'] = contrastive_loss( + pred_features, + gt_labels, + contrastive_kwargs.get('margin') + ) * contrastive_kwargs.get('scale') + return loss_dict diff --git a/projects/FastShoe/configs/base-pair.yaml b/projects/FastShoe/configs/base-pair.yaml index 3b7db85..5edc725 100644 --- a/projects/FastShoe/configs/base-pair.yaml +++ b/projects/FastShoe/configs/base-pair.yaml @@ -18,10 +18,10 @@ MODEL: NUM_CLASSES: 2 LOSSES: - NAME: ("CrossEntropyLoss",) + NAME: ("ContrastiveLoss",) - CE: - EPSILON: 0.1 + CONTRASTIVE: + MARGIN: 2.0 SCALE: 1. INPUT: