support bias in init_weights in ClsHead

pull/13/head
xiaohangzhan 2020-07-08 23:18:54 +08:00
parent a465185fc9
commit e4f09cecf5
1 changed files with 2 additions and 2 deletions

View File

@ -25,13 +25,13 @@ class ClsHead(nn.Module):
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc_cls = nn.Linear(in_channels, num_classes)
def init_weights(self, init_linear='normal', std=0.01):
def init_weights(self, init_linear='normal', std=0.01, bias=0.):
assert init_linear in ['normal', 'kaiming'], \
"Undefined init_linear: {}".format(init_linear)
for m in self.modules():
if isinstance(m, nn.Linear):
if init_linear == 'normal':
normal_init(m, std=std)
normal_init(m, std=std, bias=bias)
else:
kaiming_init(m, mode='fan_in', nonlinearity='relu')
elif isinstance(m,