mirror of https://github.com/JDAI-CV/fast-reid.git
Implement Sub-Center ArcFace
parent
5ac3be3b27
commit
46255a4641
|
@ -83,6 +83,7 @@ _C.MODEL.HEADS.CLS_LAYER = "Linear" # ArcSoftmax" or "CircleSoftmax"
|
|||
# Margin and Scale for margin-based classification layer
|
||||
_C.MODEL.HEADS.MARGIN = 0.
|
||||
_C.MODEL.HEADS.SCALE = 1.
|
||||
_C.MODEL.HEADS.ARC_K = 1
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# REID LOSSES options
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fastreid.config import configurable
|
||||
|
@ -22,19 +23,25 @@ class ClasHead(EmbeddingHead):
|
|||
self,
|
||||
*,
|
||||
return_embedding=False,
|
||||
arc_k=1,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
NOTE: this interface is experimental.
|
||||
"""
|
||||
|
||||
super(ClasHead, self).__init__(**kwargs)
|
||||
self.num_classes = kwargs['num_classes']
|
||||
self.return_embedding = return_embedding
|
||||
self.arc_k = arc_k
|
||||
|
||||
if arc_k > 1:
|
||||
kwargs['num_classes'] = kwargs['num_classes'] * arc_k
|
||||
super(ClasHead, self).__init__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg):
|
||||
config_dict = super(ClasHead, cls).from_config(cfg)
|
||||
config_dict['return_embedding'] = cfg.MODEL.HEADS.RETURN_EMBEDDING
|
||||
config_dict['arc_k'] = cfg.MODEL.HEADS.ARC_K
|
||||
return config_dict
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
|
@ -52,6 +59,9 @@ class ClasHead(EmbeddingHead):
|
|||
logits = F.linear(neck_feat, self.weight)
|
||||
else:
|
||||
logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
|
||||
if self.arc_k > 1:
|
||||
logits = torch.reshape(logits, (-1, self.num_classes, self.arc_k))
|
||||
logits = torch.max(logits, dim=2)[0]
|
||||
|
||||
# Evaluation
|
||||
if not self.training: return logits.mul_(self.cls_layer.s)
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
_BASE_: base-clas.yaml
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
PRETRAIN: True
|
||||
|
||||
HEADS:
|
||||
ARC_K: 3
|
||||
EMBEDDING_DIM: 256
|
||||
SCALE: 65.
|
||||
|
||||
|
||||
SOLVER:
|
||||
MAX_EPOCH: 50
|
||||
SCHED: MultiStepLR
|
||||
STEPS: [20, 35, 45]
|
||||
|
||||
OUTPUT_DIR: projects/Shoe/logs/subcenter-arc-k3
|
Loading…
Reference in New Issue