diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py index d724ce0..538941b 100644 --- a/fastreid/config/defaults.py +++ b/fastreid/config/defaults.py @@ -83,6 +83,7 @@ _C.MODEL.HEADS.CLS_LAYER = "Linear" # ArcSoftmax" or "CircleSoftmax" # Margin and Scale for margin-based classification layer _C.MODEL.HEADS.MARGIN = 0. _C.MODEL.HEADS.SCALE = 1. +_C.MODEL.HEADS.ARC_K = 1 # ---------------------------------------------------------------------------- # # REID LOSSES options diff --git a/fastreid/modeling/heads/clas_head.py b/fastreid/modeling/heads/clas_head.py index 37a6399..c11b697 100644 --- a/fastreid/modeling/heads/clas_head.py +++ b/fastreid/modeling/heads/clas_head.py @@ -4,6 +4,7 @@ @contact: sherlockliao01@gmail.com """ +import torch import torch.nn.functional as F from fastreid.config import configurable @@ -22,19 +23,25 @@ class ClasHead(EmbeddingHead): self, *, return_embedding=False, + arc_k=1, **kwargs ): """ NOTE: this interface is experimental. """ - - super(ClasHead, self).__init__(**kwargs) + self.num_classes = kwargs['num_classes'] self.return_embedding = return_embedding + self.arc_k = arc_k + + if arc_k > 1: + kwargs['num_classes'] = kwargs['num_classes'] * arc_k + super(ClasHead, self).__init__(**kwargs) @classmethod def from_config(cls, cfg): config_dict = super(ClasHead, cls).from_config(cfg) config_dict['return_embedding'] = cfg.MODEL.HEADS.RETURN_EMBEDDING + config_dict['arc_k'] = cfg.MODEL.HEADS.ARC_K return config_dict def forward(self, features, targets=None): @@ -52,6 +59,9 @@ class ClasHead(EmbeddingHead): logits = F.linear(neck_feat, self.weight) else: logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight)) + if self.arc_k > 1: + logits = torch.reshape(logits, (-1, self.num_classes, self.arc_k)) + logits = torch.max(logits, dim=2)[0] # Evaluation if not self.training: return logits.mul_(self.cls_layer.s) diff --git a/projects/Shoe/configs/subceter_arc_k3.yaml b/projects/Shoe/configs/subceter_arc_k3.yaml new file mode 100644 index 0000000..9e52730 --- /dev/null +++ b/projects/Shoe/configs/subceter_arc_k3.yaml @@ -0,0 +1,18 @@ +_BASE_: base-clas.yaml + +MODEL: + BACKBONE: + PRETRAIN: True + + HEADS: + ARC_K: 3 + EMBEDDING_DIM: 256 + SCALE: 65. + + +SOLVER: + MAX_EPOCH: 50 + SCHED: MultiStepLR + STEPS: [20, 35, 45] + +OUTPUT_DIR: projects/Shoe/logs/subcenter-arc-k3