Use parameter default value to control default behavior of init_cfg in (#319)
`LinearClsHead` and `MultiLabelLinearClsHead` And remove the verbose `_init_layers` method of `LinearClsHead` and `MultiLabelLinearClsHead`.pull/338/head
parent
db5502f525
commit
19cfb25e5e
|
@ -14,20 +14,16 @@ class LinearClsHead(ClsHead):
|
|||
num_classes (int): Number of categories excluding the background
|
||||
category.
|
||||
in_channels (int): Number of channels in the input feature map.
|
||||
init_cfg (dict | optional): The extra init config of layers.
|
||||
Defaults to use dict(type='Normal', layer='Linear', std=0.01).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
in_channels,
|
||||
init_cfg=None,
|
||||
init_cfg=dict(type='Normal', layer='Linear', std=0.01),
|
||||
*args,
|
||||
**kwargs):
|
||||
init_cfg = init_cfg or dict(
|
||||
type='Normal',
|
||||
mean=0.,
|
||||
std=0.01,
|
||||
bias=0.,
|
||||
override=dict(name='fc'))
|
||||
super(LinearClsHead, self).__init__(init_cfg=init_cfg, *args, **kwargs)
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
@ -37,9 +33,6 @@ class LinearClsHead(ClsHead):
|
|||
raise ValueError(
|
||||
f'num_classes={num_classes} must be a positive integer')
|
||||
|
||||
self._init_layers()
|
||||
|
||||
def _init_layers(self):
|
||||
self.fc = nn.Linear(self.in_channels, self.num_classes)
|
||||
|
||||
def simple_test(self, img):
|
||||
|
|
|
@ -14,6 +14,8 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
|
|||
num_classes (int): Number of categories.
|
||||
in_channels (int): Number of channels in the input feature map.
|
||||
loss (dict): Config of classification loss.
|
||||
init_cfg (dict | optional): The extra init config of layers.
|
||||
Defaults to use dict(type='Normal', layer='Linear', std=0.01).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -24,12 +26,7 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
|
|||
use_sigmoid=True,
|
||||
reduction='mean',
|
||||
loss_weight=1.0),
|
||||
init_cfg=dict(
|
||||
type='Normal',
|
||||
mean=0.,
|
||||
std=0.01,
|
||||
bias=0.,
|
||||
override=dict(name='fc'))):
|
||||
init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
|
||||
super(MultiLabelLinearClsHead, self).__init__(
|
||||
loss=loss, init_cfg=init_cfg)
|
||||
|
||||
|
@ -39,9 +36,7 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
|
|||
|
||||
self.in_channels = in_channels
|
||||
self.num_classes = num_classes
|
||||
self._init_layers()
|
||||
|
||||
def _init_layers(self):
|
||||
self.fc = nn.Linear(self.in_channels, self.num_classes)
|
||||
|
||||
def forward_train(self, x, gt_label):
|
||||
|
|
Loading…
Reference in New Issue