mirror of https://github.com/JDAI-CV/fast-reid.git
refactor($modeling/meta): refactor heads output
without intermediate variables generated by reid heads, make it more flexiblepull/43/head
parent
e3ae03cc58
commit
3984f0c91d
|
@ -48,8 +48,11 @@ _C.MODEL.BACKBONE.PRETRAIN_PATH = ''
|
|||
_C.MODEL.HEADS = CN()
|
||||
_C.MODEL.HEADS.NAME = "BNneckHead"
|
||||
|
||||
# Number of identity
|
||||
_C.MODEL.HEADS.NUM_CLASSES = 751
|
||||
# Reduction dimension
|
||||
# Input feature dimension
|
||||
_C.MODEL.HEADS.IN_FEAT = 2048
|
||||
# Reduction dimension in head
|
||||
_C.MODEL.HEADS.REDUCTION_DIM = 512
|
||||
# Pooling layer type
|
||||
_C.MODEL.HEADS.POOL_LAYER = 'avgpool'
|
||||
|
|
|
@ -31,7 +31,9 @@ class Baseline(nn.Module):
|
|||
pool_layer = GeneralizedMeanPoolingP()
|
||||
else:
|
||||
pool_layer = nn.Identity()
|
||||
self.heads = build_reid_heads(cfg, 2048, pool_layer)
|
||||
|
||||
in_feat = cfg.MODEL.HEADS.IN_FEAT
|
||||
self.heads = build_reid_heads(cfg, in_feat, pool_layer)
|
||||
|
||||
def forward(self, inputs):
|
||||
images = inputs["images"]
|
||||
|
@ -43,8 +45,7 @@ class Baseline(nn.Module):
|
|||
|
||||
# training
|
||||
features = self.backbone(images) # (bs, 2048, 16, 8)
|
||||
logits, global_feat = self.heads(features, targets)
|
||||
return logits, global_feat, targets
|
||||
return self.heads(features, targets)
|
||||
|
||||
def inference(self, images):
|
||||
assert not self.training
|
||||
|
|
Loading…
Reference in New Issue