From 915cee240ab1d5dfb09d5f3322107e1b185239d3 Mon Sep 17 00:00:00 2001 From: "zuchen.wang" Date: Wed, 24 Nov 2021 20:11:58 +0800 Subject: [PATCH] make ClasHead able to return embedding feature --- fastreid/modeling/heads/clas_head.py | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/fastreid/modeling/heads/clas_head.py b/fastreid/modeling/heads/clas_head.py index e154bfa..37a6399 100644 --- a/fastreid/modeling/heads/clas_head.py +++ b/fastreid/modeling/heads/clas_head.py @@ -6,11 +6,37 @@ import torch.nn.functional as F +from fastreid.config import configurable from fastreid.modeling.heads import REID_HEADS_REGISTRY, EmbeddingHead @REID_HEADS_REGISTRY.register() class ClasHead(EmbeddingHead): + """ + Make ClasHead behave like EmbeddingHead during eval: return the embedding feature for cosine distance computation, such as + image retrieval + """ + + @configurable + def __init__( + self, + *, + return_embedding=False, + **kwargs + ): + """ + NOTE: this interface is experimental. + """ + + super(ClasHead, self).__init__(**kwargs) + self.return_embedding = return_embedding + + @classmethod + def from_config(cls, cfg): + config_dict = super(ClasHead, cls).from_config(cfg) + config_dict['return_embedding'] = cfg.MODEL.HEADS.RETURN_EMBEDDING + return config_dict + def forward(self, features, targets=None): """ See :class:`ClsHeads.forward`. @@ -19,6 +45,9 @@ class ClasHead(EmbeddingHead): neck_feat = self.bottleneck(pool_feat) neck_feat = neck_feat.view(neck_feat.size(0), -1) + if not self.training and self.return_embedding: + return neck_feat + if self.cls_layer.__class__.__name__ == 'Linear': logits = F.linear(neck_feat, self.weight) else: