mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
fix: add linear initial method
This commit is contained in:
parent
95e8a02b2a
commit
c21de64166
@ -5,7 +5,7 @@
|
||||
"""
|
||||
|
||||
from fastreid.layers import *
|
||||
from fastreid.utils.weight_init import weights_init_kaiming
|
||||
from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier
|
||||
from .build import REID_HEADS_REGISTRY
|
||||
|
||||
|
||||
@ -22,12 +22,14 @@ class BNneckHead(nn.Module):
|
||||
# identity classification layer
|
||||
if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
|
||||
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
|
||||
self.classifier = Arcface(cfg, in_feat)
|
||||
elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
|
||||
self.classifier = Circle(cfg, in_feat)
|
||||
else:
|
||||
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user