update default config

pull/608/head
zuchen.wang 2021-10-27 15:24:25 +08:00
parent d4e2ac32d8
commit 6687df06e0
3 changed files with 15 additions and 13 deletions

View File

@ -82,10 +82,6 @@ _C.MODEL.HEADS.CLS_LAYER = "Linear" # ArcSoftmax" or "CircleSoftmax"
_C.MODEL.HEADS.MARGIN = 0. _C.MODEL.HEADS.MARGIN = 0.
_C.MODEL.HEADS.SCALE = 1 _C.MODEL.HEADS.SCALE = 1
# PCB HEAD options
_C.MODEL.HEADS.FULL_DIM = 512
_C.MODEL.HEADS.PART_DIM = 512
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# REID LOSSES options # REID LOSSES options
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
@ -156,6 +152,11 @@ _C.MODEL.PCB = CN()
_C.MODEL.PCB.PART_NUM = 3 _C.MODEL.PCB.PART_NUM = 3
_C.MODEL.PCB.EMBEDDING_DIM = 512 _C.MODEL.PCB.EMBEDDING_DIM = 512
_C.MODEL.PCB.HEAD = CN()
_C.MODEL.PCB.HEAD.FULL_DIM = 2048
_C.MODEL.PCB.HEAD.PART_DIM = 512
_C.MODEL.PCB.HEAD.EMBEDDING_DIM = 512
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# INPUT # INPUT
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@ -135,13 +135,11 @@ class PcbHead(nn.Module):
@classmethod @classmethod
def from_config(cls, cfg: CfgNode): def from_config(cls, cfg: CfgNode):
full_dim = cfg.MODEL.HEADS.FULL_DIM # fmt: off
part_dim = cfg.MODEL.HEADS.PART_DIM full_dim = cfg.MODEL.PCB.HEAD.FULL_DIM
embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM part_dim = cfg.MODEL.PCB.HEAD.PART_DIM
# num_classes = cfg.MODEL.HEADS.NUM_CLASSES embedding_dim = cfg.MODEL.PCB.HEAD.EMBEDDING_DIM
# cls_type = cfg.MODEL.HEADS.CLS_LAYER # fmt: on
# scale = cfg.MODEL.HEADS.SCALE
# margin = cfg.MODEL.HEADS.MARGIN
return { return {
'full_dim': full_dim, 'full_dim': full_dim,

View File

@ -27,6 +27,7 @@ class PCB(Baseline):
pixel_mean, pixel_mean,
pixel_std, pixel_std,
part_num, part_num,
part_dim,
embedding_dim, embedding_dim,
loss_kwargs=None loss_kwargs=None
): ):
@ -48,6 +49,7 @@ class PCB(Baseline):
loss_kwargs=loss_kwargs loss_kwargs=loss_kwargs
) )
self.part_num = part_num self.part_num = part_num
self.part_dim = part_dim
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.modify_backbone() self.modify_backbone()
self.random_init() self.random_init()
@ -77,12 +79,13 @@ class PCB(Baseline):
# embedding # embedding
for i in range(self.part_num): for i in range(self.part_num):
name = 'embedder' + str(i) name = 'embedder' + str(i)
setattr(self, name, nn.Linear(self.embedding_dim, 512)) setattr(self, name, nn.Linear(self.embedding_dim, self.part_dim))
@classmethod @classmethod
def from_config(cls, cfg): def from_config(cls, cfg):
config_dict = super(PCB, cls).from_config(cfg) config_dict = super(PCB, cls).from_config(cfg)
config_dict['part_num'] = cfg.MODEL.PCB.PART_NUM config_dict['part_num'] = cfg.MODEL.PCB.PART_NUM
config_dict['part_dim'] = cfg.MODEL.PCB.PART_DIM
config_dict['embedding_dim'] = cfg.MODEL.PCB.EMBEDDING_DIM config_dict['embedding_dim'] = cfg.MODEL.PCB.EMBEDDING_DIM
return config_dict return config_dict