[WIP] Refactoring weights initialization (#270)
* [WIP] Refactoring weights initialization
* fix lint and constant init cfg
* fix pretrained bug
* fix typo
* fix isort
* revise model utils
parent 5066e32306
commit 4ca21c7d03
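The diff below replaces the per-model `init_weights(pretrained=...)` methods with the `init_cfg` mechanism of `mmcv.runner.BaseModule`. A minimal sketch of what this means for callers, assuming the usual mmcls builder API (the checkpoint URL is taken from the test changes further down, not a recommendation):

from mmcls.models import build_backbone

# Old style: weights were loaded inside init_weights().
# model = build_backbone(dict(type='ResNet', depth=50))
# model.init_weights(pretrained='torchvision://resnet50')

# New style: initialization is declared in the config and applied by
# BaseModule.init_weights().
model = build_backbone(
    dict(
        type='ResNet',
        depth=50,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')))
model.init_weights()  # loads the checkpoint declared in init_cfg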
@@ -1,38 +1,36 @@
-import logging
 from abc import ABCMeta, abstractmethod
 
-import torch.nn as nn
-from mmcv.runner import load_checkpoint
+from mmcv.runner import BaseModule
 
 
-class BaseBackbone(nn.Module, metaclass=ABCMeta):
+class BaseBackbone(BaseModule, metaclass=ABCMeta):
     """Base backbone.
 
     This class defines the basic functions of a backbone. Any backbone that
     inherits this class should at least define its own `forward` function.
     """
 
-    def __init__(self):
-        super(BaseBackbone, self).__init__()
+    def __init__(self, init_cfg=None):
+        super(BaseBackbone, self).__init__(init_cfg)
 
-    def init_weights(self, pretrained=None):
-        """Init backbone weights.
+    # def init_weights(self, pretrained=None):
+    #     """Init backbone weights.
 
-        Args:
-            pretrained (str | None): If pretrained is a string, then it
-                initializes backbone weights by loading the pretrained
-                checkpoint. If pretrained is None, then it follows default
-                initializer or customized initializer in subclasses.
-        """
-        if isinstance(pretrained, str):
-            logger = logging.getLogger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            # use default initializer or customized initializer in subclasses
-            pass
-        else:
-            raise TypeError('pretrained must be a str or None.'
-                            f' But received {type(pretrained)}.')
+    #     Args:
+    #         pretrained (str | None): If pretrained is a string, then it
+    #             initializes backbone weights by loading the pretrained
+    #             checkpoint. If pretrained is None, then it follows default
+    #             initializer or customized initializer in subclasses.
+    #     """
+    #     if isinstance(pretrained, str):
+    #         logger = logging.getLogger()
+    #         load_checkpoint(self, pretrained, strict=False, logger=logger)
+    #     elif pretrained is None:
+    #         # use default initializer or customized initializer in subclasses
+    #         pass
+    #     else:
+    #         raise TypeError('pretrained must be a str or None.'
+    #                         f' But received {type(pretrained)}.')
 
     @abstractmethod
     def forward(self, x):
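With `BaseBackbone` now deriving from `BaseModule`, subclasses simply forward `init_cfg` to the base class. A hypothetical toy subclass (not part of this PR, assuming the usual mmcls module layout) showing the pattern:

import torch.nn as nn

from mmcls.models.backbones.base_backbone import BaseBackbone


class TinyBackbone(BaseBackbone):
    """Hypothetical backbone illustrating init_cfg forwarding."""

    def __init__(self, init_cfg=dict(type='Kaiming', layer=['Conv2d'])):
        super(TinyBackbone, self).__init__(init_cfg)
        self.conv = nn.Conv2d(3, 16, 3, padding=1)

    def forward(self, x):
        # BaseBackbone requires subclasses to implement forward().
        return (self.conv(x), )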
@@ -135,8 +135,15 @@ class MobileNetV2(BaseBackbone):
                  norm_cfg=dict(type='BN'),
                  act_cfg=dict(type='ReLU6'),
                  norm_eval=False,
-                 with_cp=False):
-        super(MobileNetV2, self).__init__()
+                 with_cp=False,
+                 init_cfg=[
+                     dict(type='Kaiming', layer=['Conv2d']),
+                     dict(
+                         type='Constant',
+                         val=1,
+                         layer=['_BatchNorm', 'GroupNorm'])
+                 ]):
+        super(MobileNetV2, self).__init__(init_cfg)
         self.widen_factor = widen_factor
         self.out_indices = out_indices
         for index in out_indices:
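MobileNetV2 now declares its default initialization (Kaiming for convs, constant 1 for norm layers) through `init_cfg`. A short usage sketch; the checkpoint path below is a placeholder:

from mmcls.models.backbones import MobileNetV2

# Default: Kaiming for Conv2d, constant 1 for _BatchNorm/GroupNorm.
model = MobileNetV2(widen_factor=1.0)
model.init_weights()

# Overriding the default, e.g. to start from an existing checkpoint
# (placeholder path, replace with a real file).
pretrained = MobileNetV2(
    widen_factor=1.0,
    init_cfg=dict(type='Pretrained', checkpoint='path/to/mobilenet_v2.pth'))
pretrained.init_weights()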
@@ -1,8 +1,4 @@
-import logging
-
 import torch.nn as nn
-from mmcv.cnn import ConvModule, constant_init, kaiming_init
-from mmcv.runner import load_checkpoint
-from torch.nn.modules.batchnorm import _BatchNorm
+from mmcv.cnn import ConvModule
 
 from ..builder import BACKBONES
@@ -70,8 +66,12 @@ class MobileNetv3(BaseBackbone):
                  out_indices=(10, ),
                  frozen_stages=-1,
                  norm_eval=False,
-                 with_cp=False):
-        super(MobileNetv3, self).__init__()
+                 with_cp=False,
+                 init_cfg=[
+                     dict(type='Kaiming', layer=['Conv2d']),
+                     dict(type='Constant', val=1, layer=['BatchNorm2d'])
+                 ]):
+        super(MobileNetv3, self).__init__(init_cfg)
         assert arch in self.arch_settings
         for index in out_indices:
             if index not in range(0, len(self.arch_settings[arch])):
@@ -139,18 +139,18 @@ class MobileNetv3(BaseBackbone):
             layers.append(layer_name)
         return layers
 
-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
-            logger = logging.getLogger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    kaiming_init(m)
-                elif isinstance(m, nn.BatchNorm2d):
-                    constant_init(m, 1)
-        else:
-            raise TypeError('pretrained must be a str or None')
+    # def init_weights(self, pretrained=None):
+    #     if isinstance(pretrained, str):
+    #         logger = logging.getLogger()
+    #         load_checkpoint(self, pretrained, strict=False, logger=logger)
+    #     elif pretrained is None:
+    #         for m in self.modules():
+    #             if isinstance(m, nn.Conv2d):
+    #                 kaiming_init(m)
+    #             elif isinstance(m, nn.BatchNorm2d):
+    #                 constant_init(m, 1)
+    #     else:
+    #         raise TypeError('pretrained must be a str or None')
 
     def forward(self, x):
         x = self.conv1(x)
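The commented-out loop above (kaiming_init for Conv2d, constant_init for BatchNorm2d) is what the declarative `init_cfg` entries replace. A rough stand-alone equivalent, assuming the installed mmcv ships `mmcv.cnn.initialize` (the helper that `BaseModule.init_weights()` relies on):

import torch.nn as nn
from mmcv.cnn import initialize

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# Declarative form of the removed loop:
#   nn.Conv2d      -> kaiming_init(m)
#   nn.BatchNorm2d -> constant_init(m, 1)
initialize(toy, [
    dict(type='Kaiming', layer=['Conv2d']),
    dict(type='Constant', val=1, layer=['BatchNorm2d'])
])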
@@ -97,8 +97,9 @@ class RegNet(ResNet):
                  norm_cfg=dict(type='BN', requires_grad=True),
                  norm_eval=False,
                  with_cp=False,
-                 zero_init_residual=True):
-        super(ResNet, self).__init__()
+                 zero_init_residual=True,
+                 init_cfg=None):
+        super(ResNet, self).__init__(init_cfg)
 
         # Generate RegNet parameters first
         if isinstance(arch, str):
@@ -1,7 +1,7 @@
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
-                      constant_init, kaiming_init)
+                      constant_init)
 from mmcv.utils.parrots_wrapper import _BatchNorm
 
 from ..builder import BACKBONES
@@ -459,8 +459,15 @@ class ResNet(BaseBackbone):
                  norm_cfg=dict(type='BN', requires_grad=True),
                  norm_eval=False,
                  with_cp=False,
-                 zero_init_residual=True):
-        super(ResNet, self).__init__()
+                 zero_init_residual=True,
+                 init_cfg=[
+                     dict(type='Kaiming', layer=['Conv2d']),
+                     dict(
+                         type='Constant',
+                         val=1,
+                         layer=['_BatchNorm', 'GroupNorm'])
+                 ]):
+        super(ResNet, self).__init__(init_cfg)
         if depth not in self.arch_settings:
             raise KeyError(f'invalid depth {depth} for resnet')
         self.depth = depth
@@ -587,21 +594,22 @@ class ResNet(BaseBackbone):
             for param in m.parameters():
                 param.requires_grad = False
 
-    def init_weights(self, pretrained=None):
-        super(ResNet, self).init_weights(pretrained)
-        if pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    kaiming_init(m)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
+    # def init_weights(self, pretrained=None):
+    def init_weights(self):
+        super(ResNet, self).init_weights()
+        # if pretrained is None:
+        #     for m in self.modules():
+        #         if isinstance(m, nn.Conv2d):
+        #             kaiming_init(m)
+        #         elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+        #             constant_init(m, 1)
 
-            if self.zero_init_residual:
-                for m in self.modules():
-                    if isinstance(m, Bottleneck):
-                        constant_init(m.norm3, 0)
-                    elif isinstance(m, BasicBlock):
-                        constant_init(m.norm2, 0)
+        if self.zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    constant_init(m.norm3, 0)
+                elif isinstance(m, BasicBlock):
+                    constant_init(m.norm2, 0)
 
     def forward(self, x):
         if self.deep_stem:
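ResNet keeps an overridden `init_weights()` because zero-initialising the last norm layer of every residual block depends on `zero_init_residual` and is not expressed in `init_cfg`; the override calls `super().init_weights()` first so the declarative part still runs. Usage stays simple:

from mmcls.models.backbones import ResNet

model = ResNet(depth=50, zero_init_residual=True)
# Applies the default Kaiming/Constant init_cfg, then zeroes norm3/norm2
# of Bottleneck/BasicBlock as shown above.
model.init_weights()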
@@ -1,11 +1,8 @@
-import logging
-
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
                       normal_init)
-from mmcv.runner import load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 
 from mmcls.models.utils import channel_shuffle, make_divisible
@@ -184,8 +181,9 @@ class ShuffleNetV1(BaseBackbone):
                  norm_cfg=dict(type='BN'),
                  act_cfg=dict(type='ReLU'),
                  norm_eval=False,
-                 with_cp=False):
-        super(ShuffleNetV1, self).__init__()
+                 with_cp=False,
+                 init_cfg=None):
+        super(ShuffleNetV1, self).__init__(init_cfg)
         self.stage_blocks = [4, 8, 4]
         self.groups = groups
 
@@ -250,25 +248,27 @@ class ShuffleNetV1(BaseBackbone):
             for param in layer.parameters():
                 param.requires_grad = False
 
-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
-            logger = logging.getLogger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for name, m in self.named_modules():
-                if isinstance(m, nn.Conv2d):
-                    if 'conv1' in name:
-                        normal_init(m, mean=0, std=0.01)
-                    else:
-                        normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m.weight, val=1, bias=0.0001)
-                    if isinstance(m, _BatchNorm):
-                        if m.running_mean is not None:
-                            nn.init.constant_(m.running_mean, 0)
-        else:
-            raise TypeError('pretrained must be a str or None. But received '
-                            f'{type(pretrained)}')
+    # def init_weights(self, pretrained=None):
+    def init_weights(self):
+        super(ShuffleNetV1, self).init_weights()
+        # if isinstance(pretrained, str):
+        #     logger = logging.getLogger()
+        #     load_checkpoint(self, pretrained, strict=False, logger=logger)
+        # elif pretrained is None:
+        for name, m in self.named_modules():
+            if isinstance(m, nn.Conv2d):
+                if 'conv1' in name:
+                    normal_init(m, mean=0, std=0.01)
+                else:
+                    normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
+            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                constant_init(m.weight, val=1, bias=0.0001)
+                if isinstance(m, _BatchNorm):
+                    if m.running_mean is not None:
+                        nn.init.constant_(m.running_mean, 0)
+        # else:
+        #     raise TypeError('pretrained must be a str or None. But received '
+        #                     f'{type(pretrained)}')
 
     def make_layer(self, out_channels, num_blocks, first_block=False):
         """Stack ShuffleUnit blocks to make a layer.
@@ -1,10 +1,7 @@
-import logging
-
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import ConvModule, constant_init, normal_init
-from mmcv.runner import load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 
 from mmcls.models.utils import channel_shuffle
@@ -162,8 +159,9 @@ class ShuffleNetV2(BaseBackbone):
                  norm_cfg=dict(type='BN'),
                  act_cfg=dict(type='ReLU'),
                  norm_eval=False,
-                 with_cp=False):
-        super(ShuffleNetV2, self).__init__()
+                 with_cp=False,
+                 init_cfg=None):
+        super(ShuffleNetV2, self).__init__(init_cfg)
         self.stage_blocks = [4, 8, 4]
         for index in out_indices:
             if index not in range(0, 4):
@@ -255,25 +253,27 @@ class ShuffleNetV2(BaseBackbone):
             for param in m.parameters():
                 param.requires_grad = False
 
-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
-            logger = logging.getLogger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for name, m in self.named_modules():
-                if isinstance(m, nn.Conv2d):
-                    if 'conv1' in name:
-                        normal_init(m, mean=0, std=0.01)
-                    else:
-                        normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m.weight, val=1, bias=0.0001)
-                    if isinstance(m, _BatchNorm):
-                        if m.running_mean is not None:
-                            nn.init.constant_(m.running_mean, 0)
-        else:
-            raise TypeError('pretrained must be a str or None. But received '
-                            f'{type(pretrained)}')
+    # def init_weights(self, pretrained=None):
+    #     if isinstance(pretrained, str):
+    #         logger = logging.getLogger()
+    #         load_checkpoint(self, pretrained, strict=False, logger=logger)
+    #     elif pretrained is None:
+    def init_weights(self):
+        super(ShuffleNetV2, self).init_weights()
+        for name, m in self.named_modules():
+            if isinstance(m, nn.Conv2d):
+                if 'conv1' in name:
+                    normal_init(m, mean=0, std=0.01)
+                else:
+                    normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
+            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                constant_init(m.weight, val=1, bias=0.0001)
+                if isinstance(m, _BatchNorm):
+                    if m.running_mean is not None:
+                        nn.init.constant_(m.running_mean, 0)
+        # else:
+        #     raise TypeError('pretrained must be a str or None. But received '
+        #                     f'{type(pretrained)}')
 
     def forward(self, x):
         x = self.conv1(x)
@@ -1,5 +1,5 @@
 import torch.nn as nn
-from mmcv.cnn import ConvModule, constant_init, kaiming_init, normal_init
+from mmcv.cnn import ConvModule
 from mmcv.utils.parrots_wrapper import _BatchNorm
 
 from ..builder import BACKBONES
@@ -85,8 +85,13 @@ class VGG(BaseBackbone):
                  act_cfg=dict(type='ReLU'),
                  norm_eval=False,
                  ceil_mode=False,
-                 with_last_pool=True):
-        super(VGG, self).__init__()
+                 with_last_pool=True,
+                 init_cfg=[
+                     dict(type='Kaiming', layer=['Conv2d']),
+                     dict(type='Constant', val=1., layer=['_BatchNorm']),
+                     dict(type='Normal', std=0.01, layer=['Linear'])
+                 ]):
+        super(VGG, self).__init__(init_cfg)
         if depth not in self.arch_settings:
             raise KeyError(f'invalid depth {depth} for vgg')
         assert num_stages >= 1 and num_stages <= 5
@@ -144,16 +149,16 @@ class VGG(BaseBackbone):
             nn.Linear(4096, num_classes),
         )
 
-    def init_weights(self, pretrained=None):
-        super(VGG, self).init_weights(pretrained)
-        if pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    kaiming_init(m)
-                elif isinstance(m, _BatchNorm):
-                    constant_init(m, 1)
-                elif isinstance(m, nn.Linear):
-                    normal_init(m, std=0.01)
+    # def init_weights(self, pretrained=None):
+    #     super(VGG, self).init_weights(pretrained)
+    #     if pretrained is None:
+    #         for m in self.modules():
+    #             if isinstance(m, nn.Conv2d):
+    #                 kaiming_init(m)
+    #             elif isinstance(m, _BatchNorm):
+    #                 constant_init(m, 1)
+    #             elif isinstance(m, nn.Linear):
+    #                 normal_init(m, std=0.01)
 
     def forward(self, x):
         outs = []
@@ -455,11 +455,11 @@ class VisionTransformer(BaseBackbone):
 
         self.init_weights()
 
-    def init_weights(self, pretrained=None):
-        super(VisionTransformer, self).init_weights(pretrained)
-        if pretrained is None:
-            # Modified from ClassyVision
-            nn.init.normal_(self.pos_embed, std=0.02)
+    def init_weights(self):
+        super(VisionTransformer, self).init_weights()
+        # if pretrained is None:
+        #     # Modified from ClassyVision
+        nn.init.normal_(self.pos_embed, std=0.02)
 
     @property
     def norm1(self):
@@ -6,9 +6,8 @@ import cv2
 import mmcv
 import torch
 import torch.distributed as dist
-import torch.nn as nn
 from mmcv import color_val
-from mmcv.utils import print_log
+from mmcv.runner import BaseModule
 
 # TODO import `auto_fp16` from mmcv and delete them from mmcls
 try:
@@ -19,11 +18,11 @@ except ImportError:
     from mmcls.core import auto_fp16
 
 
-class BaseClassifier(nn.Module, metaclass=ABCMeta):
+class BaseClassifier(BaseModule, metaclass=ABCMeta):
     """Base class for classifiers."""
 
-    def __init__(self):
-        super(BaseClassifier, self).__init__()
+    def __init__(self, init_cfg=None):
+        super(BaseClassifier, self).__init__(init_cfg)
         self.fp16_enabled = False
 
     @property
@@ -57,9 +56,9 @@ class BaseClassifier(nn.Module, metaclass=ABCMeta):
     def simple_test(self, img, **kwargs):
         pass
 
-    def init_weights(self, pretrained=None):
-        if pretrained is not None:
-            print_log(f'load model from: {pretrained}', logger='root')
+    # def init_weights(self, pretrained=None):
+    #     if pretrained is not None:
+    #         print_log(f'load model from: {pretrained}', logger='root')
 
     def forward_test(self, imgs, **kwargs):
         """
@@ -1,4 +1,4 @@
-import torch.nn as nn
+import warnings
 
 from ..builder import CLASSIFIERS, build_backbone, build_head, build_neck
 from ..utils import BatchCutMixLayer, BatchMixupLayer
@@ -13,8 +13,14 @@ class ImageClassifier(BaseClassifier):
                  neck=None,
                  head=None,
                  pretrained=None,
-                 train_cfg=None):
-        super(ImageClassifier, self).__init__()
+                 train_cfg=None,
+                 init_cfg=None):
+        super(ImageClassifier, self).__init__(init_cfg)
+
+        if pretrained is not None:
+            warnings.warn('DeprecationWarning: pretrained is a deprecated \
+                key, please consider using init_cfg')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
 
         self.backbone = build_backbone(backbone)
 
@@ -35,19 +41,19 @@ class ImageClassifier(BaseClassifier):
         if cutmix_cfg is not None:
             self.cutmix = BatchCutMixLayer(**cutmix_cfg)
 
-        self.init_weights(pretrained=pretrained)
+        # self.init_weights(pretrained=pretrained)
 
-    def init_weights(self, pretrained=None):
-        super(ImageClassifier, self).init_weights(pretrained)
-        self.backbone.init_weights(pretrained=pretrained)
-        if self.with_neck:
-            if isinstance(self.neck, nn.Sequential):
-                for m in self.neck:
-                    m.init_weights()
-            else:
-                self.neck.init_weights()
-        if self.with_head:
-            self.head.init_weights()
+    # def init_weights(self, pretrained=None):
+    #     super(ImageClassifier, self).init_weights(pretrained)
+    #     self.backbone.init_weights(pretrained=pretrained)
+    #     if self.with_neck:
+    #         if isinstance(self.neck, nn.Sequential):
+    #             for m in self.neck:
+    #                 m.init_weights()
+    #         else:
+    #             self.neck.init_weights()
+    #     if self.with_head:
+    #         self.head.init_weights()
 
     def extract_feat(self, img):
         """Directly extract features from the backbone + neck."""
@@ -1,16 +1,16 @@
 from abc import ABCMeta, abstractmethod
 
-import torch.nn as nn
+from mmcv.runner import BaseModule
 
 
-class BaseHead(nn.Module, metaclass=ABCMeta):
+class BaseHead(BaseModule, metaclass=ABCMeta):
     """Base head."""
 
-    def __init__(self):
-        super(BaseHead, self).__init__()
+    def __init__(self, init_cfg=None):
+        super(BaseHead, self).__init__(init_cfg)
 
-    def init_weights(self):
-        pass
+    # def init_weights(self):
+    #     pass
 
     @abstractmethod
     def forward_train(self, x, gt_label, **kwargs):
@@ -21,8 +21,9 @@ class ClsHead(BaseHead):
     def __init__(self,
                  loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                  topk=(1, ),
-                 cal_acc=False):
-        super(ClsHead, self).__init__()
+                 cal_acc=False,
+                 init_cfg=None):
+        super(ClsHead, self).__init__(init_cfg=init_cfg)
 
         assert isinstance(loss, dict)
         assert isinstance(topk, (int, tuple))
@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from mmcv.cnn import normal_init
 
 from ..builder import HEADS
 from .cls_head import ClsHead
@@ -19,6 +18,13 @@ class LinearClsHead(ClsHead):
 
     def __init__(self, num_classes, in_channels, *args, **kwargs):
         super(LinearClsHead, self).__init__(*args, **kwargs)
+        self.init_cfg = dict(
+            type='Normal',
+            mean=0.,
+            std=0.01,
+            bias=0.,
+            override=dict(name='fc'))
+
         self.in_channels = in_channels
         self.num_classes = num_classes
 
@@ -31,8 +37,8 @@ class LinearClsHead(ClsHead):
     def _init_layers(self):
         self.fc = nn.Linear(self.in_channels, self.num_classes)
 
-    def init_weights(self):
-        normal_init(self.fc, mean=0, std=0.01, bias=0)
+    # def init_weights(self):
+    #     normal_init(self.fc, mean=0, std=0.01, bias=0)
 
     def simple_test(self, img):
         """Test without augmentation."""
@@ -18,8 +18,9 @@ class MultiLabelClsHead(BaseHead):
                      type='CrossEntropyLoss',
                      use_sigmoid=True,
                      reduction='mean',
-                     loss_weight=1.0)):
-        super(MultiLabelClsHead, self).__init__()
+                     loss_weight=1.0),
+                 init_cfg=None):
+        super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg)
 
         assert isinstance(loss, dict)
 
@@ -24,8 +24,15 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
                      type='CrossEntropyLoss',
                      use_sigmoid=True,
                      reduction='mean',
-                     loss_weight=1.0)):
-        super(MultiLabelLinearClsHead, self).__init__(loss=loss)
+                     loss_weight=1.0),
+                 init_cfg=dict(
+                     type='Normal',
+                     mean=0.,
+                     std=0.01,
+                     bias=0.,
+                     override=dict(name='fc'))):
+        super(MultiLabelLinearClsHead, self).__init__(
+            loss=loss, init_cfg=init_cfg)
 
         if num_classes <= 0:
             raise ValueError(
@@ -1,11 +1,12 @@
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
 
 from .se_layer import SELayer
 
 
-class InvertedResidual(nn.Module):
+# class InvertedResidual(nn.Module):
+class InvertedResidual(BaseModule):
     """Inverted Residual Block.
 
     Args:
@@ -44,8 +45,9 @@ class InvertedResidual(nn.Module):
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
                  act_cfg=dict(type='ReLU'),
-                 with_cp=False):
-        super(InvertedResidual, self).__init__()
+                 with_cp=False,
+                 init_cfg=None):
+        super(InvertedResidual, self).__init__(init_cfg)
         self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
         assert stride in [1, 2]
         self.with_cp = with_cp
@@ -1,9 +1,11 @@
 import mmcv
 import torch.nn as nn
 from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
 
 
-class SELayer(nn.Module):
+# class SELayer(nn.Module):
+class SELayer(BaseModule):
     """Squeeze-and-Excitation Module.
 
     Args:
@@ -24,8 +26,9 @@ class SELayer(nn.Module):
                  channels,
                  ratio=16,
                  conv_cfg=None,
-                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
-        super(SELayer, self).__init__()
+                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
+                 init_cfg=None):
+        super(SELayer, self).__init__(init_cfg)
         if isinstance(act_cfg, dict):
             act_cfg = (act_cfg, act_cfg)
         assert len(act_cfg) == 2
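`InvertedResidual` and `SELayer` switch to `BaseModule` so that init_cfg-driven initialization also reaches modules used as sub-blocks rather than as registered backbones. A hypothetical custom block (not part of this PR) following the same pattern:

import torch.nn as nn
from mmcv.runner import BaseModule


class DepthwiseBlock(BaseModule):
    """Hypothetical sub-block cooperating with init_cfg-based init."""

    def __init__(self, channels, init_cfg=None):
        super(DepthwiseBlock, self).__init__(init_cfg)
        self.depthwise = nn.Conv2d(
            channels, channels, 3, padding=1, groups=channels)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        return self.bn(self.depthwise(x))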
@@ -404,8 +404,11 @@ def test_resnet():
     assert check_norm_state(model.modules(), False)
 
     # Test ResNet50 with torchvision pretrained weight
-    model = ResNet(depth=50, norm_eval=True)
-    model.init_weights('torchvision://resnet50')
+    model = ResNet(
+        depth=50,
+        norm_eval=True,
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))
+    model.init_weights()
     model.train()
     assert check_norm_state(model.modules(), False)
 
@@ -153,8 +153,11 @@ def test_seresnet():
     assert check_norm_state(model.modules(), False)
 
     # Test SEResNet50 with torchvision pretrained weight
-    model = SEResNet(depth=50, norm_eval=True)
-    model.init_weights('torchvision://resnet50')
+    model = SEResNet(
+        depth=50,
+        norm_eval=True,
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))
+    model.init_weights()
     model.train()
     assert check_norm_state(model.modules(), False)
 
@@ -126,6 +126,7 @@ def main():
     meta['seed'] = args.seed
 
     model = build_classifier(cfg.model)
+    model.init_weights()
 
     datasets = [build_dataset(cfg.data.train)]
     if len(cfg.workflow) == 2: