Use parameter default value to control default behavior of init_cfg in (#319)

`LinearClsHead` and `MultiLabelLinearClsHead`

And remove the verbose `_init_layers` method of `LinearClsHead` and
`MultiLabelLinearClsHead`.
pull/338/head
Ma Zerun 2021-06-30 19:13:27 +08:00 committed by GitHub
parent db5502f525
commit 19cfb25e5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 18 deletions

View File

@ -14,20 +14,16 @@ class LinearClsHead(ClsHead):
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
init_cfg (dict | optional): The extra init config of layers.
Defaults to use dict(type='Normal', layer='Linear', std=0.01).
"""
def __init__(self,
num_classes,
in_channels,
init_cfg=None,
init_cfg=dict(type='Normal', layer='Linear', std=0.01),
*args,
**kwargs):
init_cfg = init_cfg or dict(
type='Normal',
mean=0.,
std=0.01,
bias=0.,
override=dict(name='fc'))
super(LinearClsHead, self).__init__(init_cfg=init_cfg, *args, **kwargs)
self.in_channels = in_channels
@ -37,9 +33,6 @@ class LinearClsHead(ClsHead):
raise ValueError(
f'num_classes={num_classes} must be a positive integer')
self._init_layers()
def _init_layers(self):
self.fc = nn.Linear(self.in_channels, self.num_classes)
def simple_test(self, img):

View File

@ -14,6 +14,8 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
num_classes (int): Number of categories.
in_channels (int): Number of channels in the input feature map.
loss (dict): Config of classification loss.
init_cfg (dict | optional): The extra init config of layers.
Defaults to use dict(type='Normal', layer='Linear', std=0.01).
"""
def __init__(self,
@ -24,12 +26,7 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
use_sigmoid=True,
reduction='mean',
loss_weight=1.0),
init_cfg=dict(
type='Normal',
mean=0.,
std=0.01,
bias=0.,
override=dict(name='fc'))):
init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
super(MultiLabelLinearClsHead, self).__init__(
loss=loss, init_cfg=init_cfg)
@ -39,9 +36,7 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
self.in_channels = in_channels
self.num_classes = num_classes
self._init_layers()
def _init_layers(self):
self.fc = nn.Linear(self.in_channels, self.num_classes)
def forward_train(self, x, gt_label):