Update resnext init
parent
2295cf56c2
commit
321435e6b4
|
@ -80,11 +80,10 @@ class ResNeXt(nn.Module):
|
||||||
|
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
|
||||||
elif isinstance(m, nn.BatchNorm2d):
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
m.weight.data.fill_(1)
|
nn.init.constant_(m.weight, 1.)
|
||||||
m.bias.data.zero_()
|
nn.init.constant_(m.bias, 0.)
|
||||||
|
|
||||||
def _make_layer(self, block, planes, blocks, stride=1):
|
def _make_layer(self, block, planes, blocks, stride=1):
|
||||||
downsample = None
|
downsample = None
|
||||||
|
|
Loading…
Reference in New Issue