mirror of https://github.com/JDAI-CV/fast-reid.git
update default config
parent
d4e2ac32d8
commit
6687df06e0
|
@ -82,10 +82,6 @@ _C.MODEL.HEADS.CLS_LAYER = "Linear" # ArcSoftmax" or "CircleSoftmax"
|
|||
_C.MODEL.HEADS.MARGIN = 0.
|
||||
_C.MODEL.HEADS.SCALE = 1
|
||||
|
||||
# PCB HEAD options
|
||||
_C.MODEL.HEADS.FULL_DIM = 512
|
||||
_C.MODEL.HEADS.PART_DIM = 512
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# REID LOSSES options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
|
@ -156,6 +152,11 @@ _C.MODEL.PCB = CN()
|
|||
_C.MODEL.PCB.PART_NUM = 3
|
||||
_C.MODEL.PCB.EMBEDDING_DIM = 512
|
||||
|
||||
_C.MODEL.PCB.HEAD = CN()
|
||||
_C.MODEL.PCB.HEAD.FULL_DIM = 2048
|
||||
_C.MODEL.PCB.HEAD.PART_DIM = 512
|
||||
_C.MODEL.PCB.HEAD.EMBEDDING_DIM = 512
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# INPUT
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
|
@ -135,13 +135,11 @@ class PcbHead(nn.Module):
|
|||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: CfgNode):
|
||||
full_dim = cfg.MODEL.HEADS.FULL_DIM
|
||||
part_dim = cfg.MODEL.HEADS.PART_DIM
|
||||
embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
|
||||
# num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||
# cls_type = cfg.MODEL.HEADS.CLS_LAYER
|
||||
# scale = cfg.MODEL.HEADS.SCALE
|
||||
# margin = cfg.MODEL.HEADS.MARGIN
|
||||
# fmt: off
|
||||
full_dim = cfg.MODEL.PCB.HEAD.FULL_DIM
|
||||
part_dim = cfg.MODEL.PCB.HEAD.PART_DIM
|
||||
embedding_dim = cfg.MODEL.PCB.HEAD.EMBEDDING_DIM
|
||||
# fmt: on
|
||||
|
||||
return {
|
||||
'full_dim': full_dim,
|
||||
|
|
|
@ -27,6 +27,7 @@ class PCB(Baseline):
|
|||
pixel_mean,
|
||||
pixel_std,
|
||||
part_num,
|
||||
part_dim,
|
||||
embedding_dim,
|
||||
loss_kwargs=None
|
||||
):
|
||||
|
@ -48,6 +49,7 @@ class PCB(Baseline):
|
|||
loss_kwargs=loss_kwargs
|
||||
)
|
||||
self.part_num = part_num
|
||||
self.part_dim = part_dim
|
||||
self.embedding_dim = embedding_dim
|
||||
self.modify_backbone()
|
||||
self.random_init()
|
||||
|
@ -77,12 +79,13 @@ class PCB(Baseline):
|
|||
# embedding
|
||||
for i in range(self.part_num):
|
||||
name = 'embedder' + str(i)
|
||||
setattr(self, name, nn.Linear(self.embedding_dim, 512))
|
||||
setattr(self, name, nn.Linear(self.embedding_dim, self.part_dim))
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg):
|
||||
config_dict = super(PCB, cls).from_config(cfg)
|
||||
config_dict['part_num'] = cfg.MODEL.PCB.PART_NUM
|
||||
config_dict['part_dim'] = cfg.MODEL.PCB.PART_DIM
|
||||
config_dict['embedding_dim'] = cfg.MODEL.PCB.EMBEDDING_DIM
|
||||
return config_dict
|
||||
|
||||
|
|
Loading…
Reference in New Issue