Change the init_weight for shufflenet models
parent
6968ad5b3b
commit
c1d0090700
|
@ -4,7 +4,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
|
||||
kaiming_init)
|
||||
normal_init)
|
||||
from mmcv.runner import load_checkpoint
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
|
@ -255,11 +255,17 @@ class ShuffleNetV1(BaseBackbone):
|
|||
logger = logging.getLogger()
|
||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
elif pretrained is None:
|
||||
for m in self.modules():
|
||||
for name, m in self.named_modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m)
|
||||
if 'conv1' in name:
|
||||
normal_init(m, mean=0, std=0.01)
|
||||
else:
|
||||
normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
constant_init(m.weight, val=1, bias=0.0001)
|
||||
if isinstance(m, _BatchNorm):
|
||||
if m.running_mean is not None:
|
||||
nn.init.constant_(m.running_mean, 0)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None. But received '
|
||||
f'{type(pretrained)}')
|
||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import ConvModule, constant_init, kaiming_init
|
||||
from mmcv.cnn import ConvModule, constant_init, normal_init
|
||||
from mmcv.runner import load_checkpoint
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
|
@ -260,11 +260,18 @@ class ShuffleNetV2(BaseBackbone):
|
|||
logger = logging.getLogger()
|
||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
elif pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
for name, m in self.named_modules():
|
||||
for name, m in self.named_modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
if 'conv1' in name:
|
||||
normal_init(m, mean=0, std=0.01)
|
||||
else:
|
||||
normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m.weight, val=1, bias=0.0001)
|
||||
if isinstance(m, _BatchNorm):
|
||||
if m.running_mean is not None:
|
||||
nn.init.constant_(m.running_mean, 0)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None. But received '
|
||||
f'{type(pretrained)}')
|
||||
|
|
Loading…
Reference in New Issue