diff --git a/fastreid/modeling/heads/bnneck_head.py b/fastreid/modeling/heads/bnneck_head.py index ead208b..35fcaae 100644 --- a/fastreid/modeling/heads/bnneck_head.py +++ b/fastreid/modeling/heads/bnneck_head.py @@ -5,7 +5,7 @@ """ from fastreid.layers import * -from fastreid.utils.weight_init import weights_init_kaiming +from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier from .build import REID_HEADS_REGISTRY @@ -22,12 +22,14 @@ class BNneckHead(nn.Module): # identity classification layer if cfg.MODEL.HEADS.CLS_LAYER == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False) + self.classifier.apply(weights_init_classifier) elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface': self.classifier = Arcface(cfg, in_feat) elif cfg.MODEL.HEADS.CLS_LAYER == 'circle': self.classifier = Circle(cfg, in_feat) else: self.classifier = nn.Linear(in_feat, num_classes, bias=False) + self.classifier.apply(weights_init_classifier) def forward(self, features, targets=None): """