fast-reid/projects/FastCls/fastcls/cls_head.py

36 lines
874 B
Python

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from fastreid.modeling.heads import REID_HEADS_REGISTRY, EmbeddingHead
@REID_HEADS_REGISTRY.register()
class ClsHead(EmbeddingHead):
    """Classification head registered in ``REID_HEADS_REGISTRY``.

    Subclasses :class:`EmbeddingHead` and overrides only ``forward``. It uses
    ``self.pool_layer``, ``self.bottleneck`` and ``self.classifier``, which are
    not defined in this class (presumably built by ``EmbeddingHead.__init__``
    — confirm against the parent).
    """

    def forward(self, features, targets=None):
        """Pool, bottleneck and classify backbone features.

        Args:
            features: Backbone output fed to ``self.pool_layer``. After pooling
                and the bottleneck, element ``[..., 0, 0]`` is taken, so the
                pooled tensor is assumed to end in two singleton spatial dims
                (e.g. ``(N, C, 1, 1)``) — TODO confirm for every pool layer.
            targets: Accepted for interface compatibility with other heads;
                not used by this head.

        Returns:
            In eval mode (``self.training`` is False): the output of
            ``self.classifier``.
            In training mode: a dict with
                ``"cls_outputs"``: output of ``self.classifier``;
                ``"pred_class_logits"``: plain linear logits computed from the
                classifier's ``weight`` (and ``bias``, when it has one);
                ``"features"``: the bottleneck features fed to the classifier.
        """
        global_feat = self.pool_layer(features)
        bn_feat = self.bottleneck(global_feat)
        # Collapse the trailing spatial dims left by pooling (presumably 1x1).
        bn_feat = bn_feat[..., 0, 0]

        cls_outputs = self.classifier(bn_feat)

        # Evaluation: only the classifier output is needed.
        if not self.training:
            return cls_outputs

        # Recompute plain linear logits. Forward the classifier's bias when it
        # exists (None otherwise) so a biased classifier does not produce
        # logits that silently disagree with ``cls_outputs``.
        pred_class_logits = F.linear(
            bn_feat,
            self.classifier.weight,
            getattr(self.classifier, "bias", None),
        )

        return {
            "cls_outputs": cls_outputs,
            "pred_class_logits": pred_class_logits,
            "features": bn_feat,
        }