mirror of https://github.com/JDAI-CV/fast-reid.git
Implement Sub-Center ArcFace
parent 5ac3be3b27
commit 46255a4641
Changed paths:
  fastreid/config
  fastreid/modeling/heads
  projects/Shoe/configs

fastreid/config
@@ -83,6 +83,7 @@ _C.MODEL.HEADS.CLS_LAYER = "Linear" # ArcSoftmax" or "CircleSoftmax"
 # Margin and Scale for margin-based classification layer
 _C.MODEL.HEADS.MARGIN = 0.
 _C.MODEL.HEADS.SCALE = 1.
+_C.MODEL.HEADS.ARC_K = 1
 
 # ---------------------------------------------------------------------------- #
 # REID LOSSES options
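
The new ARC_K option defaults to 1, so existing configs keep the current single-center behaviour of the margin-based heads; setting it above 1 enables Sub-Center ArcFace with that many sub-centers per class. A minimal sketch of reading the knob, assuming fast-reid's usual get_cfg() entry point:

    # Sketch only: get_cfg() returns a CfgNode populated with the defaults above.
    from fastreid.config import get_cfg

    cfg = get_cfg()
    print(cfg.MODEL.HEADS.ARC_K)   # 1 -> one center per class, behaviour unchanged
    cfg.MODEL.HEADS.ARC_K = 3      # 3 sub-centers per class (Sub-Center ArcFace)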

fastreid/modeling/heads
@@ -4,6 +4,7 @@
 @contact: sherlockliao01@gmail.com
 """
 
+import torch
 import torch.nn.functional as F
 
 from fastreid.config import configurable
@@ -22,19 +23,25 @@ class ClasHead(EmbeddingHead):
             self,
             *,
             return_embedding=False,
+            arc_k=1,
             **kwargs
     ):
         """
         NOTE: this interface is experimental.
         """
-        super(ClasHead, self).__init__(**kwargs)
+        self.num_classes = kwargs['num_classes']
         self.return_embedding = return_embedding
+        self.arc_k = arc_k
 
+        if arc_k > 1:
+            kwargs['num_classes'] = kwargs['num_classes'] * arc_k
+        super(ClasHead, self).__init__(**kwargs)
 
     @classmethod
     def from_config(cls, cfg):
         config_dict = super(ClasHead, cls).from_config(cfg)
         config_dict['return_embedding'] = cfg.MODEL.HEADS.RETURN_EMBEDDING
+        config_dict['arc_k'] = cfg.MODEL.HEADS.ARC_K
         return config_dict
 
     def forward(self, features, targets=None):
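
When arc_k > 1, the constructor records the true number of classes, then multiplies kwargs['num_classes'] by arc_k before building the parent EmbeddingHead, so the classification weight ends up with arc_k prototype rows per class. A rough shape check under assumed sizes (1000 classes, 256-dim embedding, K = 3):

    # Hypothetical sizes, only to illustrate the weight expansion done in __init__.
    num_classes, arc_k, embedding_dim = 1000, 3, 256
    weight_rows = num_classes * arc_k              # parent head is built with 3000 outputs
    weight_shape = (weight_rows, embedding_dim)    # (3000, 256): 3 sub-center prototypes per class
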
@@ -52,6 +59,9 @@ class ClasHead(EmbeddingHead):
             logits = F.linear(neck_feat, self.weight)
         else:
             logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
+        if self.arc_k > 1:
+            logits = torch.reshape(logits, (-1, self.num_classes, self.arc_k))
+            logits = torch.max(logits, dim=2)[0]
 
         # Evaluation
         if not self.training: return logits.mul_(self.cls_layer.s)
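
In forward, the cosine logits against all num_classes * arc_k prototypes are reshaped to (batch, num_classes, arc_k) and max-pooled over the last dimension, so each sample is scored by its nearest sub-center within every class, which is the core idea of Sub-Center ArcFace. A self-contained sketch of that pooling step with made-up sizes (batch 4, 10 classes, K = 3, 256-dim features):

    # Sketch under assumed sizes; mirrors the reshape + max added to ClasHead.forward.
    import torch
    import torch.nn.functional as F

    B, C, K, D = 4, 10, 3, 256
    feat = torch.randn(B, D)
    weight = torch.randn(C * K, D)                             # arc_k prototype rows per class

    logits = F.linear(F.normalize(feat), F.normalize(weight))  # (B, C*K) cosine similarities
    logits = torch.reshape(logits, (-1, C, K))                 # group the K sub-centers of each class
    logits = torch.max(logits, dim=2)[0]                       # (B, C): best sub-center per class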

projects/Shoe/configs (new file)
@@ -0,0 +1,18 @@
+_BASE_: base-clas.yaml
+
+MODEL:
+  BACKBONE:
+    PRETRAIN: True
+
+  HEADS:
+    ARC_K: 3
+    EMBEDDING_DIM: 256
+    SCALE: 65.
+
+
+SOLVER:
+  MAX_EPOCH: 50
+  SCHED: MultiStepLR
+  STEPS: [20, 35, 45]
+
+OUTPUT_DIR: projects/Shoe/logs/subcenter-arc-k3
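
To pick up this config, the defaults above are merged with the YAML; a sketch of loading it, assuming fast-reid's get_cfg() and merge_from_file() resolve the _BASE_: base-clas.yaml reference as usual (the new file's name is not shown in this view, so the path below is a placeholder):

    # Sketch only: substitute the actual filename of the YAML added above.
    from fastreid.config import get_cfg

    cfg = get_cfg()
    cfg.merge_from_file("projects/Shoe/configs/<new-config>.yaml")
    # cfg.MODEL.HEADS.ARC_K is now 3, cfg.MODEL.HEADS.SCALE is 65.0, and
    # cfg.OUTPUT_DIR points at projects/Shoe/logs/subcenter-arc-k3.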