[WIP] Refactoring weights initialization (#270)

* [WIP] Refactoring weights initialization

* fix lint and constant init cfg

* fix pretrained bug

* fix typo

* fix isort

* revise model utils
Miao Zheng 2021-06-10 10:54:34 +08:00 committed by GitHub
parent 5066e32306
commit 4ca21c7d03
21 changed files with 229 additions and 178 deletions
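The thrust of the refactor: instead of each module hand-writing an `init_weights(pretrained=...)` method, modules inherit from mmcv's `BaseModule` and declare their initialization through an `init_cfg` argument; `BaseModule.init_weights()` then applies the configured initializers recursively. A minimal sketch of the new pattern (`ToyBackbone` is a hypothetical class for illustration, not part of this diff):

import torch.nn as nn
from mmcv.runner import BaseModule


class ToyBackbone(BaseModule):  # hypothetical module, for illustration only

    def __init__(self,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(type='Constant', val=1, layer=['_BatchNorm'])
                 ]):
        super(ToyBackbone, self).__init__(init_cfg)
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)


model = ToyBackbone()
model.init_weights()  # dispatches each init_cfg entry to matching layers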


@@ -1,38 +1,36 @@
import logging
from abc import ABCMeta, abstractmethod
import torch.nn as nn
from mmcv.runner import load_checkpoint
from mmcv.runner import BaseModule
class BaseBackbone(nn.Module, metaclass=ABCMeta):
class BaseBackbone(BaseModule, metaclass=ABCMeta):
"""Base backbone.
This class defines the basic functions of a backbone. Any backbone that
inherits this class should at least define its own `forward` function.
"""
def __init__(self):
super(BaseBackbone, self).__init__()
def __init__(self, init_cfg=None):
super(BaseBackbone, self).__init__(init_cfg)
def init_weights(self, pretrained=None):
"""Init backbone weights.
# def init_weights(self, pretrained=None):
# """Init backbone weights.
Args:
pretrained (str | None): If pretrained is a string, then it
initializes backbone weights by loading the pretrained
checkpoint. If pretrained is None, then it follows default
initializer or customized initializer in subclasses.
"""
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
# use default initializer or customized initializer in subclasses
pass
else:
raise TypeError('pretrained must be a str or None.'
f' But received {type(pretrained)}.')
# Args:
# pretrained (str | None): If pretrained is a string, then it
# initializes backbone weights by loading the pretrained
# checkpoint. If pretrained is None, then it follows default
# initializer or customized initializer in subclasses.
# """
# if isinstance(pretrained, str):
# logger = logging.getLogger()
# load_checkpoint(self, pretrained, strict=False, logger=logger)
# elif pretrained is None:
# # use default initializer or customized initializer in subclasses
# pass
# else:
# raise TypeError('pretrained must be a str or None.'
# f' But received {type(pretrained)}.')
@abstractmethod
def forward(self, x):

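The commented-out `pretrained` branch above is subsumed by the `Pretrained` initializer type: the checkpoint is declared in `init_cfg` and loaded when `init_weights()` is called, mirroring the updated tests further down. A sketch:

from mmcls.models.backbones import ResNet

model = ResNet(
    depth=50,
    init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))
model.init_weights()  # the checkpoint is loaded here, no longer in __init__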

@@ -135,8 +135,15 @@ class MobileNetV2(BaseBackbone):
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'),
norm_eval=False,
with_cp=False):
super(MobileNetV2, self).__init__()
with_cp=False,
init_cfg=[
dict(type='Kaiming', layer=['Conv2d']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]):
super(MobileNetV2, self).__init__(init_cfg)
self.widen_factor = widen_factor
self.out_indices = out_indices
for index in out_indices:

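For intuition about the `init_cfg` list above: each dict is dispatched by `type` to an mmcv initializer that visits every submodule whose class name (or base-class name, hence `_BatchNorm` matching `BatchNorm2d`) appears in `layer`. Roughly equivalent to calling `mmcv.cnn.initialize` directly, as in this sketch:

import torch.nn as nn
from mmcv.cnn import initialize

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
initialize(toy, [
    dict(type='Kaiming', layer=['Conv2d']),  # Kaiming-init all conv layers
    dict(type='Constant', val=1,
         layer=['_BatchNorm', 'GroupNorm']),  # set norm weights (gamma) to 1
])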

@@ -1,8 +1,4 @@
import logging
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.cnn import ConvModule
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
@@ -70,8 +66,12 @@ class MobileNetv3(BaseBackbone):
out_indices=(10, ),
frozen_stages=-1,
norm_eval=False,
with_cp=False):
super(MobileNetv3, self).__init__()
with_cp=False,
init_cfg=[
dict(type='Kaiming', layer=['Conv2d']),
dict(type='Constant', val=1, layer=['BatchNorm2d'])
]):
super(MobileNetv3, self).__init__(init_cfg)
assert arch in self.arch_settings
for index in out_indices:
if index not in range(0, len(self.arch_settings[arch])):
@@ -139,18 +139,18 @@ class MobileNetv3(BaseBackbone):
layers.append(layer_name)
return layers
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
# def init_weights(self, pretrained=None):
# if isinstance(pretrained, str):
# logger = logging.getLogger()
# load_checkpoint(self, pretrained, strict=False, logger=logger)
# elif pretrained is None:
# for m in self.modules():
# if isinstance(m, nn.Conv2d):
# kaiming_init(m)
# elif isinstance(m, nn.BatchNorm2d):
# constant_init(m, 1)
# else:
# raise TypeError('pretrained must be a str or None')
def forward(self, x):
x = self.conv1(x)


@@ -97,8 +97,9 @@ class RegNet(ResNet):
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=False,
with_cp=False,
zero_init_residual=True):
super(ResNet, self).__init__()
zero_init_residual=True,
init_cfg=None):
super(ResNet, self).__init__(init_cfg)
# Generate RegNet parameters first
if isinstance(arch, str):


@@ -1,7 +1,7 @@
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
constant_init, kaiming_init)
constant_init)
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
@@ -459,8 +459,15 @@ class ResNet(BaseBackbone):
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=False,
with_cp=False,
zero_init_residual=True):
super(ResNet, self).__init__()
zero_init_residual=True,
init_cfg=[
dict(type='Kaiming', layer=['Conv2d']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]):
super(ResNet, self).__init__(init_cfg)
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
self.depth = depth
@@ -587,21 +594,22 @@ class ResNet(BaseBackbone):
for param in m.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
super(ResNet, self).init_weights(pretrained)
if pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
# def init_weights(self, pretrained=None):
def init_weights(self):
super(ResNet, self).init_weights()
# if pretrained is None:
# for m in self.modules():
# if isinstance(m, nn.Conv2d):
# kaiming_init(m)
# elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
# constant_init(m, 1)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
def forward(self, x):
if self.deep_stem:

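ResNet keeps an `init_weights` override because `zero_init_residual` cannot be expressed declaratively: which norm layer to zero depends on the block type. The pattern is to call `super().init_weights()` first so the configured initializers run, then do the imperative part. Schematically (toy module, for illustration):

import torch.nn as nn
from mmcv.cnn import constant_init
from mmcv.runner import BaseModule


class ToyResNet(BaseModule):  # hypothetical, for illustration only

    def __init__(self, init_cfg=dict(type='Kaiming', layer=['Conv2d'])):
        super(ToyResNet, self).__init__(init_cfg)
        self.conv = nn.Conv2d(8, 8, 3, padding=1)
        self.norm = nn.BatchNorm2d(8)

    def init_weights(self):
        super(ToyResNet, self).init_weights()  # apply init_cfg first
        constant_init(self.norm, 0)  # then the part init_cfg cannot express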

@@ -1,11 +1,8 @@
import logging
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
normal_init)
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.utils import channel_shuffle, make_divisible
@@ -184,8 +181,9 @@ class ShuffleNetV1(BaseBackbone):
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
norm_eval=False,
with_cp=False):
super(ShuffleNetV1, self).__init__()
with_cp=False,
init_cfg=None):
super(ShuffleNetV1, self).__init__(init_cfg)
self.stage_blocks = [4, 8, 4]
self.groups = groups
@@ -250,25 +248,27 @@ class ShuffleNetV1(BaseBackbone):
for param in layer.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'conv1' in name:
normal_init(m, mean=0, std=0.01)
else:
normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m.weight, val=1, bias=0.0001)
if isinstance(m, _BatchNorm):
if m.running_mean is not None:
nn.init.constant_(m.running_mean, 0)
else:
raise TypeError('pretrained must be a str or None. But received '
f'{type(pretrained)}')
# def init_weights(self, pretrained=None):
def init_weights(self):
super(ShuffleNetV1, self).init_weights()
# if isinstance(pretrained, str):
# logger = logging.getLogger()
# load_checkpoint(self, pretrained, strict=False, logger=logger)
# elif pretrained is None:
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'conv1' in name:
normal_init(m, mean=0, std=0.01)
else:
normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, val=1, bias=0.0001)
if isinstance(m, _BatchNorm):
if m.running_mean is not None:
nn.init.constant_(m.running_mean, 0)
# else:
# raise TypeError('pretrained must be a str or None. But received '
# f'{type(pretrained)}')
def make_layer(self, out_channels, num_blocks, first_block=False):
"""Stack ShuffleUnit blocks to make a layer.


@@ -1,10 +1,7 @@
import logging
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, constant_init, normal_init
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.utils import channel_shuffle
@@ -162,8 +159,9 @@ class ShuffleNetV2(BaseBackbone):
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
norm_eval=False,
with_cp=False):
super(ShuffleNetV2, self).__init__()
with_cp=False,
init_cfg=None):
super(ShuffleNetV2, self).__init__(init_cfg)
self.stage_blocks = [4, 8, 4]
for index in out_indices:
if index not in range(0, 4):
@@ -255,25 +253,27 @@ class ShuffleNetV2(BaseBackbone):
for param in m.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'conv1' in name:
normal_init(m, mean=0, std=0.01)
else:
normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m.weight, val=1, bias=0.0001)
if isinstance(m, _BatchNorm):
if m.running_mean is not None:
nn.init.constant_(m.running_mean, 0)
else:
raise TypeError('pretrained must be a str or None. But received '
f'{type(pretrained)}')
# def init_weights(self, pretrained=None):
# if isinstance(pretrained, str):
# logger = logging.getLogger()
# load_checkpoint(self, pretrained, strict=False, logger=logger)
# elif pretrained is None:
def init_weights(self):
super(ShuffleNetV2, self).init_weights()
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'conv1' in name:
normal_init(m, mean=0, std=0.01)
else:
normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, val=1, bias=0.0001)
if isinstance(m, _BatchNorm):
if m.running_mean is not None:
nn.init.constant_(m.running_mean, 0)
# else:
# raise TypeError('pretrained must be a str or None. But received '
# f'{type(pretrained)}')
def forward(self, x):
x = self.conv1(x)


@@ -1,5 +1,5 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init, normal_init
from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
@@ -85,8 +85,13 @@ class VGG(BaseBackbone):
act_cfg=dict(type='ReLU'),
norm_eval=False,
ceil_mode=False,
with_last_pool=True):
super(VGG, self).__init__()
with_last_pool=True,
init_cfg=[
dict(type='Kaiming', layer=['Conv2d']),
dict(type='Constant', val=1., layer=['_BatchNorm']),
dict(type='Normal', std=0.01, layer=['Linear'])
]):
super(VGG, self).__init__(init_cfg)
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for vgg')
assert num_stages >= 1 and num_stages <= 5
@@ -144,16 +149,16 @@ class VGG(BaseBackbone):
nn.Linear(4096, num_classes),
)
def init_weights(self, pretrained=None):
super(VGG, self).init_weights(pretrained)
if pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, _BatchNorm):
constant_init(m, 1)
elif isinstance(m, nn.Linear):
normal_init(m, std=0.01)
# def init_weights(self, pretrained=None):
# super(VGG, self).init_weights(pretrained)
# if pretrained is None:
# for m in self.modules():
# if isinstance(m, nn.Conv2d):
# kaiming_init(m)
# elif isinstance(m, _BatchNorm):
# constant_init(m, 1)
# elif isinstance(m, nn.Linear):
# normal_init(m, std=0.01)
def forward(self, x):
outs = []


@@ -455,11 +455,11 @@ class VisionTransformer(BaseBackbone):
self.init_weights()
def init_weights(self, pretrained=None):
super(VisionTransformer, self).init_weights(pretrained)
if pretrained is None:
# Modified from ClassyVision
nn.init.normal_(self.pos_embed, std=0.02)
def init_weights(self):
super(VisionTransformer, self).init_weights()
# if pretrained is None:
# # Modified from ClassyVision
nn.init.normal_(self.pos_embed, std=0.02)
@property
def norm1(self):

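ViT's override survives for a similar reason: `pos_embed` is a bare `nn.Parameter`, and the layer-matching initializers in `init_cfg` only visit submodules, so it has to be initialized imperatively. A toy sketch of the situation:

import torch
import torch.nn as nn
from mmcv.runner import BaseModule


class ToyViT(BaseModule):  # hypothetical, for illustration only

    def __init__(self, init_cfg=None):
        super(ToyViT, self).__init__(init_cfg)
        # a bare parameter: no layer class for init_cfg entries to match
        self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768))

    def init_weights(self):
        super(ToyViT, self).init_weights()
        nn.init.normal_(self.pos_embed, std=0.02)  # hence the manual init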

@@ -6,9 +6,8 @@ import cv2
import mmcv
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv import color_val
from mmcv.utils import print_log
from mmcv.runner import BaseModule
# TODO import `auto_fp16` from mmcv and delete them from mmcls
try:
@@ -19,11 +18,11 @@ except ImportError:
from mmcls.core import auto_fp16
class BaseClassifier(nn.Module, metaclass=ABCMeta):
class BaseClassifier(BaseModule, metaclass=ABCMeta):
"""Base class for classifiers."""
def __init__(self):
super(BaseClassifier, self).__init__()
def __init__(self, init_cfg=None):
super(BaseClassifier, self).__init__(init_cfg)
self.fp16_enabled = False
@property
@@ -57,9 +56,9 @@ class BaseClassifier(nn.Module, metaclass=ABCMeta):
def simple_test(self, img, **kwargs):
pass
def init_weights(self, pretrained=None):
if pretrained is not None:
print_log(f'load model from: {pretrained}', logger='root')
# def init_weights(self, pretrained=None):
# if pretrained is not None:
# print_log(f'load model from: {pretrained}', logger='root')
def forward_test(self, imgs, **kwargs):
"""


@@ -1,4 +1,4 @@
import torch.nn as nn
import warnings
from ..builder import CLASSIFIERS, build_backbone, build_head, build_neck
from ..utils import BatchCutMixLayer, BatchMixupLayer
@@ -13,8 +13,14 @@ class ImageClassifier(BaseClassifier):
neck=None,
head=None,
pretrained=None,
train_cfg=None):
super(ImageClassifier, self).__init__()
train_cfg=None,
init_cfg=None):
super(ImageClassifier, self).__init__(init_cfg)
if pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated '
'key, please consider using init_cfg')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
self.backbone = build_backbone(backbone)
@@ -35,19 +41,19 @@ class ImageClassifier(BaseClassifier):
if cutmix_cfg is not None:
self.cutmix = BatchCutMixLayer(**cutmix_cfg)
self.init_weights(pretrained=pretrained)
# self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
super(ImageClassifier, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
if self.with_neck:
if isinstance(self.neck, nn.Sequential):
for m in self.neck:
m.init_weights()
else:
self.neck.init_weights()
if self.with_head:
self.head.init_weights()
# def init_weights(self, pretrained=None):
# super(ImageClassifier, self).init_weights(pretrained)
# self.backbone.init_weights(pretrained=pretrained)
# if self.with_neck:
# if isinstance(self.neck, nn.Sequential):
# for m in self.neck:
# m.init_weights()
# else:
# self.neck.init_weights()
# if self.with_head:
# self.head.init_weights()
def extract_feat(self, img):
"""Directly extract features from the backbone + neck."""

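The deprecation shim above keeps old configs working by rewriting `pretrained` into the equivalent `init_cfg` entry. A hedged sketch of the two spellings (checkpoint path is a placeholder):

from mmcls.models import ImageClassifier

# old style: emits the deprecation warning, then is rewritten internally
model = ImageClassifier(backbone=dict(type='ResNet', depth=18),
                        pretrained='path/to/checkpoint.pth')

# new style: declare the checkpoint up front and initialize explicitly
model = ImageClassifier(
    backbone=dict(type='ResNet', depth=18),
    init_cfg=dict(type='Pretrained', checkpoint='path/to/checkpoint.pth'))
model.init_weights()  # loading happens here, no longer inside __init__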

@@ -1,16 +1,16 @@
from abc import ABCMeta, abstractmethod
import torch.nn as nn
from mmcv.runner import BaseModule
class BaseHead(nn.Module, metaclass=ABCMeta):
class BaseHead(BaseModule, metaclass=ABCMeta):
"""Base head."""
def __init__(self):
super(BaseHead, self).__init__()
def __init__(self, init_cfg=None):
super(BaseHead, self).__init__(init_cfg)
def init_weights(self):
pass
# def init_weights(self):
# pass
@abstractmethod
def forward_train(self, x, gt_label, **kwargs):


@@ -21,8 +21,9 @@ class ClsHead(BaseHead):
def __init__(self,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, ),
cal_acc=False):
super(ClsHead, self).__init__()
cal_acc=False,
init_cfg=None):
super(ClsHead, self).__init__(init_cfg=init_cfg)
assert isinstance(loss, dict)
assert isinstance(topk, (int, tuple))


@@ -1,7 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from ..builder import HEADS
from .cls_head import ClsHead
@@ -19,6 +18,13 @@ class LinearClsHead(ClsHead):
def __init__(self, num_classes, in_channels, *args, **kwargs):
super(LinearClsHead, self).__init__(*args, **kwargs)
self.init_cfg = dict(
type='Normal',
mean=0.,
std=0.01,
bias=0.,
override=dict(name='fc'))
self.in_channels = in_channels
self.num_classes = num_classes
@@ -31,8 +37,8 @@ class LinearClsHead(ClsHead):
def _init_layers(self):
self.fc = nn.Linear(self.in_channels, self.num_classes)
def init_weights(self):
normal_init(self.fc, mean=0, std=0.01, bias=0)
# def init_weights(self):
# normal_init(self.fc, mean=0, std=0.01, bias=0)
def simple_test(self, img):
"""Test without augmentation."""

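Note the `override` key in the `init_cfg` that LinearClsHead now sets: the dict carries no `layer` entry, so by itself it would touch nothing; `override=dict(name='fc')` directs the `Normal` initializer at the submodule named `fc` specifically. The same mechanism in isolation, via `mmcv.cnn.initialize` (toy module, for illustration):

import torch.nn as nn
from mmcv.cnn import initialize


class ToyHead(nn.Module):  # hypothetical, for illustration only

    def __init__(self):
        super(ToyHead, self).__init__()
        self.fc = nn.Linear(16, 4)
        self.aux = nn.Linear(16, 4)


head = ToyHead()
initialize(head,
           dict(type='Normal', mean=0., std=0.01, bias=0.,
                override=dict(name='fc')))  # only head.fc is initialized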

@@ -18,8 +18,9 @@ class MultiLabelClsHead(BaseHead):
type='CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=1.0)):
super(MultiLabelClsHead, self).__init__()
loss_weight=1.0),
init_cfg=None):
super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg)
assert isinstance(loss, dict)


@@ -24,8 +24,15 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
type='CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=1.0)):
super(MultiLabelLinearClsHead, self).__init__(loss=loss)
loss_weight=1.0),
init_cfg=dict(
type='Normal',
mean=0.,
std=0.01,
bias=0.,
override=dict(name='fc'))):
super(MultiLabelLinearClsHead, self).__init__(
loss=loss, init_cfg=init_cfg)
if num_classes <= 0:
raise ValueError(


@@ -1,11 +1,12 @@
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from .se_layer import SELayer
class InvertedResidual(nn.Module):
# class InvertedResidual(nn.Module):
class InvertedResidual(BaseModule):
"""Inverted Residual Block.
Args:
@@ -44,8 +45,9 @@ class InvertedResidual(nn.Module):
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
with_cp=False):
super(InvertedResidual, self).__init__()
with_cp=False,
init_cfg=None):
super(InvertedResidual, self).__init__(init_cfg)
self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
assert stride in [1, 2]
self.with_cp = with_cp


@@ -1,9 +1,11 @@
import mmcv
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
class SELayer(nn.Module):
# class SELayer(nn.Module):
class SELayer(BaseModule):
"""Squeeze-and-Excitation Module.
Args:
@@ -24,8 +26,9 @@ class SELayer(nn.Module):
channels,
ratio=16,
conv_cfg=None,
act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
super(SELayer, self).__init__()
act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
init_cfg=None):
super(SELayer, self).__init__(init_cfg)
if isinstance(act_cfg, dict):
act_cfg = (act_cfg, act_cfg)
assert len(act_cfg) == 2


@@ -404,8 +404,11 @@ def test_resnet():
assert check_norm_state(model.modules(), False)
# Test ResNet50 with torchvision pretrained weight
model = ResNet(depth=50, norm_eval=True)
model.init_weights('torchvision://resnet50')
model = ResNet(
depth=50,
norm_eval=True,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))
model.init_weights()
model.train()
assert check_norm_state(model.modules(), False)


@@ -153,8 +153,11 @@ def test_seresnet():
assert check_norm_state(model.modules(), False)
# Test SEResNet50 with torchvision pretrained weight
model = SEResNet(depth=50, norm_eval=True)
model.init_weights('torchvision://resnet50')
model = SEResNet(
depth=50,
norm_eval=True,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))
model.init_weights()
model.train()
assert check_norm_state(model.modules(), False)


@@ -126,6 +126,7 @@ def main():
meta['seed'] = args.seed
model = build_classifier(cfg.model)
model.init_weights()
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2: