refactor(heads): move num_classes out from heads

set parameter num_classes in meta_arch to easily modify different heads fc layer
pull/46/head
liaoxingyu 2020-04-29 21:29:48 +08:00
parent 907798c8c9
commit 329764bb60
5 changed files with 13 additions and 16 deletions

View File

@ -11,9 +11,8 @@ from ...layers import *
@REID_HEADS_REGISTRY.register()
class BNneckHead(nn.Module):
def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
super().__init__()
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.pool_layer = nn.Sequential(
pool_layer,
@ -24,13 +23,13 @@ class BNneckHead(nn.Module):
# identity classification layer
if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
self.classifier = Arcface(cfg, in_feat)
elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
self.classifier = Circle(cfg, in_feat)
else:
self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
def forward(self, features, targets=None):
"""

View File

@ -16,9 +16,9 @@ The call is expected to return an :class:`ROIHeads`.
"""
def build_reid_heads(cfg, in_feat, pool_layer):
def build_reid_heads(cfg, in_feat, num_classes, pool_layer):
"""
Build REIDHeads defined by `cfg.MODEL.REID_HEADS.NAME`.
"""
head = cfg.MODEL.HEADS.NAME
return REID_HEADS_REGISTRY.get(head)(cfg, in_feat, pool_layer)
return REID_HEADS_REGISTRY.get(head)(cfg, in_feat, num_classes, pool_layer)

View File

@ -10,10 +10,8 @@ from ...layers import *
@REID_HEADS_REGISTRY.register()
class LinearHead(nn.Module):
def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
super().__init__()
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.pool_layer = nn.Sequential(
pool_layer,
@ -22,13 +20,13 @@ class LinearHead(nn.Module):
# identity classification layer
if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
self.classifier = Arcface(cfg, in_feat)
elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
self.classifier = Circle(cfg, in_feat)
else:
self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
def forward(self, features, targets=None):
"""

View File

@ -11,9 +11,8 @@ from ...layers import *
@REID_HEADS_REGISTRY.register()
class ReductionHead(nn.Module):
def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
super().__init__()
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM
self.pool_layer = nn.Sequential(
@ -34,13 +33,13 @@ class ReductionHead(nn.Module):
# identity classification layer
if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
self.classifier = nn.Linear(reduction_dim, self._num_classes, bias=False)
self.classifier = nn.Linear(reduction_dim, num_classes, bias=False)
elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
self.classifier = Arcface(cfg, reduction_dim)
elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
self.classifier = Circle(cfg, reduction_dim)
else:
self.classifier = nn.Linear(reduction_dim, self._num_classes, bias=False)
self.classifier = nn.Linear(reduction_dim, num_classes, bias=False)
def forward(self, features, targets=None):
"""

View File

@ -32,7 +32,8 @@ class Baseline(nn.Module):
pool_layer = nn.Identity()
in_feat = cfg.MODEL.HEADS.IN_FEAT
self.heads = build_reid_heads(cfg, in_feat, pool_layer)
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer)
def forward(self, inputs):
images = inputs["images"]