make ClasHead can return embedding feature

pull/608/head
zuchen.wang 2021-11-24 20:11:58 +08:00
parent fc29be547c
commit 915cee240a
1 changed files with 29 additions and 0 deletions

View File

@ -6,11 +6,37 @@
import torch.nn.functional as F
from fastreid.config import configurable
from fastreid.modeling.heads import REID_HEADS_REGISTRY, EmbeddingHead
@REID_HEADS_REGISTRY.register()
class ClasHead(EmbeddingHead):
"""
Make ClasHead behavior like EmbeddingHead when eval, return embedding feat for cosine distance computation, such as
image retrieval
"""
@configurable
def __init__(
self,
*,
return_embedding=False,
**kwargs
):
"""
NOTE: this interface is experimental.
"""
super(ClasHead, self).__init__(**kwargs)
self.return_embedding = return_embedding
@classmethod
def from_config(cls, cfg):
config_dict = super(ClasHead, cls).from_config(cfg)
config_dict['return_embedding'] = cfg.MODEL.HEADS.RETURN_EMBEDDING
return config_dict
def forward(self, features, targets=None):
"""
See :class:`ClsHeads.forward`.
@ -19,6 +45,9 @@ class ClasHead(EmbeddingHead):
neck_feat = self.bottleneck(pool_feat)
neck_feat = neck_feat.view(neck_feat.size(0), -1)
if not self.training and self.return_embedding:
return neck_feat
if self.cls_layer.__class__.__name__ == 'Linear':
logits = F.linear(neck_feat, self.weight)
else: