Implement Sub-Center ArcFace

pull/608/head
zuchen.wang 2021-11-25 16:00:56 +08:00
parent 5ac3be3b27
commit 46255a4641
3 changed files with 31 additions and 2 deletions

@@ -83,6 +83,7 @@ _C.MODEL.HEADS.CLS_LAYER = "Linear"  # "ArcSoftmax" or "CircleSoftmax"
# Margin and Scale for margin-based classification layer
_C.MODEL.HEADS.MARGIN = 0.
_C.MODEL.HEADS.SCALE = 1.
_C.MODEL.HEADS.ARC_K = 1  # Sub-Center ArcFace: number of sub-centers per class (1 = plain single-center behaviour)
# ---------------------------------------------------------------------------- #
# REID LOSSES options
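
The new ARC_K option defaults to 1, which leaves the margin-based heads behaving exactly as before; values above 1 switch on sub-centers. A minimal sketch of setting it programmatically, assuming fastreid's usual get_cfg() entry point with the patched defaults above (the margin and scale values are illustrative only):

from fastreid.config import get_cfg

cfg = get_cfg()                            # defaults, now including ARC_K = 1
cfg.MODEL.HEADS.CLS_LAYER = "ArcSoftmax"
cfg.MODEL.HEADS.MARGIN = 0.5               # illustrative ArcFace margin
cfg.MODEL.HEADS.SCALE = 64.                # illustrative ArcFace scale
cfg.MODEL.HEADS.ARC_K = 3                  # three sub-centers per class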

@@ -4,6 +4,7 @@
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
from fastreid.config import configurable
@@ -22,19 +23,25 @@ class ClasHead(EmbeddingHead):
            self,
            *,
            return_embedding=False,
            arc_k=1,
            **kwargs
    ):
        """
        NOTE: this interface is experimental.
        """
        super(ClasHead, self).__init__(**kwargs)
        self.num_classes = kwargs['num_classes']
        self.return_embedding = return_embedding
        self.arc_k = arc_k
        if arc_k > 1:
            # Sub-Center ArcFace: ask the parent head for arc_k weight rows per class
            kwargs['num_classes'] = kwargs['num_classes'] * arc_k
        super(ClasHead, self).__init__(**kwargs)

    @classmethod
    def from_config(cls, cfg):
        config_dict = super(ClasHead, cls).from_config(cfg)
        config_dict['return_embedding'] = cfg.MODEL.HEADS.RETURN_EMBEDDING
        config_dict['arc_k'] = cfg.MODEL.HEADS.ARC_K
        return config_dict

    def forward(self, features, targets=None):
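
The constructor keeps the true class count in self.num_classes but hands num_classes * arc_k to the parent EmbeddingHead, so the classification weight is built with arc_k rows per class. A tiny sketch of that bookkeeping with hypothetical numbers (not taken from the commit):

# Hypothetical constructor kwargs, mirroring the logic above.
kwargs = {'num_classes': 1000, 'embedding_dim': 256}
arc_k = 3

true_num_classes = kwargs['num_classes']            # stored as self.num_classes
if arc_k > 1:
    kwargs['num_classes'] = kwargs['num_classes'] * arc_k

print(true_num_classes, kwargs['num_classes'])      # 1000 3000 -> parent allocates 3000 weight rows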
@@ -52,6 +59,9 @@ class ClasHead(EmbeddingHead):
            logits = F.linear(neck_feat, self.weight)
        else:
            logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
        if self.arc_k > 1:
            # Sub-Center ArcFace: fold each class's arc_k sub-center logits back to
            # a single logit by keeping the most similar sub-center
            logits = torch.reshape(logits, (-1, self.num_classes, self.arc_k))
            logits = torch.max(logits, dim=2)[0]

        # Evaluation
        if not self.training: return logits.mul_(self.cls_layer.s)
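
In the forward pass the (batch, num_classes * arc_k) cosine matrix is regrouped so that each class keeps only its best-matching sub-center, which is the max-pooling used by Sub-Center ArcFace. A standalone sketch of that pooling step with made-up shapes, independent of ClasHead:

import torch
import torch.nn.functional as F

N, C, K, D = 8, 10, 3, 256                       # batch, classes, sub-centers per class, feature dim

feat = F.normalize(torch.randn(N, D))            # L2-normalised embeddings
weight = F.normalize(torch.randn(C * K, D))      # K sub-center weight rows per class

logits = F.linear(feat, weight)                  # (N, C*K) cosine similarities
logits = torch.reshape(logits, (-1, C, K))       # group the K sub-centers of each class
logits = torch.max(logits, dim=2)[0]             # keep the closest sub-center per class
assert logits.shape == (N, C)

The pooling happens on the raw cosine similarities; margin and scale are applied afterwards by the head's cls_layer, as the evaluation branch above shows for the scale.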

@@ -0,0 +1,18 @@
_BASE_: base-clas.yaml

MODEL:
  BACKBONE:
    PRETRAIN: True
  HEADS:
    ARC_K: 3
    EMBEDDING_DIM: 256
    SCALE: 65.
SOLVER:
  MAX_EPOCH: 50
  SCHED: MultiStepLR
  STEPS: [20, 35, 45]
OUTPUT_DIR: projects/Shoe/logs/subcenter-arc-k3
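
With ARC_K: 3 and EMBEDDING_DIM: 256, the head trained by this config carries num_classes * 3 classification weight rows of dimension 256 and max-pools them back to num_classes logits. A quick check of the merged config, assuming a hypothetical filename and a location next to base-clas.yaml (the commit page does not show this file's actual path or name):

from fastreid.config import get_cfg

cfg = get_cfg()
# Hypothetical path; adjust to wherever this yaml actually lives in the Shoe project.
cfg.merge_from_file("projects/Shoe/configs/subcenter-arc-k3.yml")

assert cfg.MODEL.HEADS.ARC_K == 3
assert cfg.MODEL.HEADS.EMBEDDING_DIM == 256
assert cfg.MODEL.HEADS.SCALE == 65.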