From c1d009070023d043cd61014786ca00f8e628b1a3 Mon Sep 17 00:00:00 2001
From: lixiaojie
Date: Wed, 8 Jul 2020 23:55:45 +0800
Subject: [PATCH] Change the init_weight for shufflenet models

---
 mmcls/models/backbones/shufflenet_v1.py | 14 ++++++++++----
 mmcls/models/backbones/shufflenet_v2.py | 19 +++++++++++++------
 2 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/mmcls/models/backbones/shufflenet_v1.py b/mmcls/models/backbones/shufflenet_v1.py
index 3e3e76fe..aafc20d8 100644
--- a/mmcls/models/backbones/shufflenet_v1.py
+++ b/mmcls/models/backbones/shufflenet_v1.py
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
-                      kaiming_init)
+                      normal_init)
 from mmcv.runner import load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 
@@ -255,11 +255,17 @@ class ShuffleNetV1(BaseBackbone):
             logger = logging.getLogger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
-            for m in self.modules():
+            for name, m in self.named_modules():
                 if isinstance(m, nn.Conv2d):
-                    kaiming_init(m)
+                    if 'conv1' in name:
+                        normal_init(m, mean=0, std=0.01)
+                    else:
+                        normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
                 elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
+                    constant_init(m, val=1, bias=0.0001)
+                    if isinstance(m, _BatchNorm):
+                        if m.running_mean is not None:
+                            nn.init.constant_(m.running_mean, 0)
         else:
             raise TypeError('pretrained must be a str or None. But received '
                             f'{type(pretrained)}')
diff --git a/mmcls/models/backbones/shufflenet_v2.py b/mmcls/models/backbones/shufflenet_v2.py
index 83252608..f3b251d1 100644
--- a/mmcls/models/backbones/shufflenet_v2.py
+++ b/mmcls/models/backbones/shufflenet_v2.py
@@ -3,7 +3,7 @@ import logging
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-from mmcv.cnn import ConvModule, constant_init, kaiming_init
+from mmcv.cnn import ConvModule, constant_init, normal_init
 from mmcv.runner import load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 
@@ -260,11 +260,18 @@ class ShuffleNetV2(BaseBackbone):
             logger = logging.getLogger()
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    kaiming_init(m)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
+            for name, m in self.named_modules():
+                if isinstance(m, nn.Conv2d):
+                    if 'conv1' in name:
+                        normal_init(m, mean=0, std=0.01)
+                    else:
+                        normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, val=1, bias=0.0001)
+                    if isinstance(m, _BatchNorm):
+                        if m.running_mean is not None:
+                            nn.init.constant_(m.running_mean, 0)
         else:
             raise TypeError('pretrained must be a str or None. But received '
                             f'{type(pretrained)}')
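
For reviewers, a minimal standalone sketch of the initialization scheme this
patch introduces, written with plain torch.nn.init calls under the assumption
that mmcv's normal_init/constant_init initialize a module's weight and bias
tensors; the helper name demo_init and the two-layer net below are
hypothetical, not part of mmcls:

    import torch.nn as nn
    from torch.nn.modules.batchnorm import _BatchNorm

    def demo_init(model: nn.Module) -> None:
        # Hypothetical demo of the patch's init scheme, not mmcls code.
        for name, m in model.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'conv1' in name:
                    # Stem conv: small Gaussian, std=0.01.
                    nn.init.normal_(m.weight, mean=0, std=0.01)
                else:
                    # Other convs: std scales as 1 / input channels
                    # (weight shape is [out_ch, in_ch // groups, kH, kW]).
                    nn.init.normal_(m.weight, mean=0,
                                    std=1.0 / m.weight.shape[1])
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                # Norm layers: weight=1, small positive bias.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0.0001)
                if isinstance(m, _BatchNorm) and m.running_mean is not None:
                    nn.init.constant_(m.running_mean, 0)

    # Usage on a toy model (hypothetical):
    net = nn.Sequential(nn.Conv2d(3, 24, 3, bias=False), nn.BatchNorm2d(24))
    demo_init(net)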