Change the init_weight for shufflenet models

pull/2/head
lixiaojie 2020-07-08 23:55:45 +08:00 committed by yl-1993
parent 6968ad5b3b
commit c1d0090700
2 changed files with 23 additions and 10 deletions

mmcls/models/backbones/shufflenet_v1.py

@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
-                      kaiming_init)
+                      normal_init)
 from mmcv.runner import load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 
@@ -255,11 +255,17 @@ class ShuffleNetV1(BaseBackbone):
             logger = logging.getLogger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
-            for m in self.modules():
+            for name, m in self.named_modules():
                 if isinstance(m, nn.Conv2d):
-                    kaiming_init(m)
+                    if 'conv1' in name:
+                        normal_init(m, mean=0, std=0.01)
+                    else:
+                        normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
                 elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
+                    constant_init(m.weight, val=1, bias=0.0001)
+                    if isinstance(m, _BatchNorm):
+                        if m.running_mean is not None:
+                            nn.init.constant_(m.running_mean, 0)
         else:
             raise TypeError('pretrained must be a str or None. But received '
                             f'{type(pretrained)}')
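
The new conv rule can be sketched in isolation (assuming mmcv's normal_init helper, which fills a module's weight from N(mean, std) and zeroes its bias; the channel sizes below are made up for illustration and are not part of this commit):

import torch.nn as nn
from mmcv.cnn import normal_init

# Stem conv, matched by the 'conv1' name check: fixed small std.
stem = nn.Conv2d(3, 24, kernel_size=3, stride=2, padding=1)
normal_init(stem, mean=0, std=0.01)

# Every other conv: std shrinks with fan-in, since weight.shape[1]
# equals in_channels // groups for an nn.Conv2d.
conv = nn.Conv2d(24, 48, kernel_size=1)
normal_init(conv, mean=0, std=1.0 / conv.weight.shape[1])  # std = 1/24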

mmcls/models/backbones/shufflenet_v2.py

@@ -3,7 +3,7 @@ import logging
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-from mmcv.cnn import ConvModule, constant_init, kaiming_init
+from mmcv.cnn import ConvModule, constant_init, normal_init
 from mmcv.runner import load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 
@@ -260,11 +260,18 @@ class ShuffleNetV2(BaseBackbone):
             logger = logging.getLogger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    kaiming_init(m)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
+            for name, m in self.named_modules():
+                for name, m in self.named_modules():
+                    if isinstance(m, nn.Conv2d):
+                        if 'conv1' in name:
+                            normal_init(m, mean=0, std=0.01)
+                        else:
+                            normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
+                    elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                        constant_init(m.weight, val=1, bias=0.0001)
+                        if isinstance(m, _BatchNorm):
+                            if m.running_mean is not None:
+                                nn.init.constant_(m.running_mean, 0)
         else:
             raise TypeError('pretrained must be a str or None. But received '
                             f'{type(pretrained)}')
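
The norm-layer branch can likewise be reproduced standalone with plain PyTorch (a sketch, not part of the commit; nn.init.constant_ accepts any tensor, including the running_mean buffer, which is None when track_running_stats=False, hence the guard):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16)
bn.running_mean.fill_(0.5)  # pretend stale statistics sit in the buffer

# Mirrors the guarded reset added above.
if bn.running_mean is not None:
    nn.init.constant_(bn.running_mean, 0)

print(torch.allclose(bn.running_mean, torch.zeros(16)))  # True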