support bias in init_weights in ClsHead
parent
a465185fc9
commit
e4f09cecf5
|
@ -25,13 +25,13 @@ class ClsHead(nn.Module):
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
self.fc_cls = nn.Linear(in_channels, num_classes)
|
self.fc_cls = nn.Linear(in_channels, num_classes)
|
||||||
|
|
||||||
def init_weights(self, init_linear='normal', std=0.01):
|
def init_weights(self, init_linear='normal', std=0.01, bias=0.):
|
||||||
assert init_linear in ['normal', 'kaiming'], \
|
assert init_linear in ['normal', 'kaiming'], \
|
||||||
"Undefined init_linear: {}".format(init_linear)
|
"Undefined init_linear: {}".format(init_linear)
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.Linear):
|
if isinstance(m, nn.Linear):
|
||||||
if init_linear == 'normal':
|
if init_linear == 'normal':
|
||||||
normal_init(m, std=std)
|
normal_init(m, std=std, bias=bias)
|
||||||
else:
|
else:
|
||||||
kaiming_init(m, mode='fan_in', nonlinearity='relu')
|
kaiming_init(m, mode='fan_in', nonlinearity='relu')
|
||||||
elif isinstance(m,
|
elif isinstance(m,
|
||||||
|
|
Loading…
Reference in New Issue