mirror of https://github.com/open-mmlab/mmcv.git
Merge branch 'yl-1993-cnn' into cnn
commit 6243822627
mmcv/cnn/__init__.py
@@ -1,7 +1,9 @@
+from .alexnet import AlexNet
+from .vgg import VGG, make_vgg_layer
 from .resnet import ResNet, make_res_layer
 from .weight_init import xavier_init, normal_init, uniform_init, kaiming_init
 
 __all__ = [
-    'ResNet', 'make_res_layer', 'xavier_init', 'normal_init', 'uniform_init',
-    'kaiming_init'
+    'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
+    'xavier_init', 'normal_init', 'uniform_init', 'kaiming_init'
]
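A quick sketch of what the widened package surface allows, assuming this revision of mmcv is importable; the layer and model arguments below are illustrative, not from the commit:

# Illustrative only: exercising the exports added to mmcv.cnn.
import torch.nn as nn
from mmcv.cnn import AlexNet, VGG, xavier_init

conv = nn.Conv2d(3, 16, kernel_size=3)
xavier_init(conv)                  # re-exported initializer for weight/bias
alexnet = AlexNet(num_classes=-1)  # headless backbone, no classifier
vgg = VGG(16)                      # depth must be one of {11, 13, 16, 19}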
mmcv/cnn/alexnet.py
@@ -0,0 +1,61 @@
+import logging
+
+import torch.nn as nn
+
+from ..runner import load_checkpoint
+
+
+class AlexNet(nn.Module):
+    """AlexNet backbone.
+
+    Args:
+        num_classes (int): number of classes for classification.
+    """
+
+    def __init__(self, num_classes=-1):
+        super(AlexNet, self).__init__()
+        self.num_classes = num_classes
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(64, 192, kernel_size=5, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(192, 384, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(384, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+        )
+        if self.num_classes > 0:
+            self.classifier = nn.Sequential(
+                nn.Dropout(),
+                nn.Linear(256 * 6 * 6, 4096),
+                nn.ReLU(inplace=True),
+                nn.Dropout(),
+                nn.Linear(4096, 4096),
+                nn.ReLU(inplace=True),
+                nn.Linear(4096, num_classes),
+            )
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            # use default initializer
+            pass
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+
+        x = self.features(x)
+        if self.num_classes > 0:
+            x = x.view(x.size(0), 256 * 6 * 6)
+            x = self.classifier(x)
+
+        return x
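A minimal forward-pass sketch for the new backbone; the 224x224 input is my assumption, chosen so the feature map flattens to the 256 * 6 * 6 the classifier head hard-codes:

import torch
from mmcv.cnn import AlexNet

# Headless mode (num_classes=-1): returns the raw conv feature map.
model = AlexNet(num_classes=-1)
model.init_weights()                          # default init, no checkpoint
feats = model(torch.randn(1, 3, 224, 224))    # -> (1, 256, 6, 6)

# Classification mode: inputs must yield a 6x6 map (224x224 works),
# since forward() flattens with a fixed 256 * 6 * 6.
clf = AlexNet(num_classes=10)
logits = clf(torch.randn(1, 3, 224, 224))     # -> (1, 10)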
mmcv/cnn/resnet.py
@@ -28,7 +28,8 @@ class BasicBlock(nn.Module):
                  stride=1,
                  dilation=1,
                  downsample=None,
-                 style='pytorch'):
+                 style='pytorch',
+                 with_cp=False):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
         self.bn1 = nn.BatchNorm2d(planes)
@@ -38,6 +39,7 @@ class BasicBlock(nn.Module):
         self.downsample = downsample
         self.stride = stride
         self.dilation = dilation
+        assert not with_cp
 
     def forward(self, x):
         residual = x
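The new with_cp flag is accepted but rejected at construction time (assert not with_cp), apparently reserving the keyword for gradient checkpointing before BasicBlock actually supports it. A minimal sketch of the trade-off the flag is meant to enable, using torch.utils.checkpoint; this is my illustration, not code from this commit:

import torch
import torch.utils.checkpoint as cp

def _inner_forward(x):
    # stand-in for the conv/bn/relu body of a residual block
    return x * 2

x = torch.randn(2, 8, requires_grad=True)
# Trades compute for memory: activations inside _inner_forward are
# recomputed during backward instead of being stored in the forward pass.
out = cp.checkpoint(_inner_forward, x)
out.sum().backward()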
mmcv/cnn/vgg.py
@@ -0,0 +1,158 @@
+import logging
+
+import torch.nn as nn
+
+from ..runner import load_checkpoint
+
+
+def conv3x3(in_planes, out_planes, dilation=1, bias=False):
+    "3x3 convolution with padding"
+    return nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        padding=dilation,
+        dilation=dilation,
+        bias=bias)
+
+
+def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False):
+    layers = []
+    for _ in range(num_blocks):
+        layers.append(conv3x3(inplanes, planes, dilation, not with_bn))
+        if with_bn:
+            layers.append(nn.BatchNorm2d(planes))
+        layers.append(nn.ReLU(inplace=True))
+        inplanes = planes
+    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
+
+    return nn.Sequential(*layers)
+
+
+class VGG(nn.Module):
+    """VGG backbone.
+
+    Args:
+        depth (int): Depth of vgg, from {11, 13, 16, 19}.
+        with_bn (bool): Use BatchNorm or not.
+        num_classes (int): number of classes for classification.
+        num_stages (int): VGG stages, normally 5.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Stages to be frozen (all params fixed). -1 means
+            not freezing any parameters.
+        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+            running stats (mean and var).
+        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+    """
+
+    arch_settings = {
+        11: (1, 1, 2, 2, 2),
+        13: (2, 2, 2, 2, 2),
+        16: (2, 2, 3, 3, 3),
+        19: (2, 2, 4, 4, 4)
+    }
+
+    def __init__(self,
+                 depth,
+                 with_bn=False,
+                 num_classes=-1,
+                 num_stages=5,
+                 dilations=(1, 1, 1, 1, 1),
+                 out_indices=(0, 1, 2, 3, 4),
+                 frozen_stages=-1,
+                 bn_eval=True,
+                 bn_frozen=False):
+        super(VGG, self).__init__()
+        if depth not in self.arch_settings:
+            raise KeyError('invalid depth {} for vgg'.format(depth))
+        assert num_stages >= 1 and num_stages <= 5
+        stage_blocks = self.arch_settings[depth]
+        stage_blocks = stage_blocks[:num_stages]
+        assert len(dilations) == num_stages
+        assert max(out_indices) < num_stages
+
+        self.num_classes = num_classes
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.bn_eval = bn_eval
+        self.bn_frozen = bn_frozen
+
+        self.inplanes = 3
+        self.vgg_layers = []
+        for i, num_blocks in enumerate(stage_blocks):
+            dilation = dilations[i]
+            planes = 64 * 2**i if i < 4 else 512
+            vgg_layer = make_vgg_layer(
+                self.inplanes,
+                planes,
+                num_blocks,
+                dilation=dilation,
+                with_bn=with_bn)
+            self.inplanes = planes
+            layer_name = 'layer{}'.format(i + 1)
+            self.add_module(layer_name, vgg_layer)
+            self.vgg_layers.append(layer_name)
+
+        if self.num_classes > 0:
+            self.classifier = nn.Sequential(
+                nn.Linear(512 * 7 * 7, 4096),
+                nn.ReLU(True),
+                nn.Dropout(),
+                nn.Linear(4096, 4096),
+                nn.ReLU(True),
+                nn.Dropout(),
+                nn.Linear(4096, num_classes),
+            )
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    nn.init.kaiming_normal_(
+                        m.weight, mode='fan_out', nonlinearity='relu')
+                    if m.bias is not None:
+                        nn.init.constant_(m.bias, 0)
+                elif isinstance(m, nn.BatchNorm2d):
+                    nn.init.constant_(m.weight, 1)
+                    nn.init.constant_(m.bias, 0)
+                elif isinstance(m, nn.Linear):
+                    nn.init.normal_(m.weight, 0, 0.01)
+                    nn.init.constant_(m.bias, 0)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        outs = []
+        for i, layer_name in enumerate(self.vgg_layers):
+            vgg_layer = getattr(self, layer_name)
+            x = vgg_layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        if self.num_classes > 0:
+            x = x.view(x.size(0), -1)
+            x = self.classifier(x)
+            outs.append(x)
+        if len(outs) == 1:
+            return outs[0]
+        else:
+            return tuple(outs)
+
+    def train(self, mode=True):
+        super(VGG, self).train(mode)
+        if self.bn_eval:
+            for m in self.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
+                    if self.bn_frozen:
+                        for params in m.parameters():
+                            params.requires_grad = False
+        if mode and self.frozen_stages >= 0:
+            for i in range(1, self.frozen_stages + 1):
+                mod = getattr(self, 'layer{}'.format(i))
+                mod.eval()
+                for param in mod.parameters():
+                    param.requires_grad = False
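A usage sketch for the new VGG backbone; the depth, input size, and argument choices below are mine, not from the commit. Each stage ends in a 2x2 max-pool, so with out_indices=(2, 3, 4) the forward pass returns a tuple of feature maps at strides 8, 16 and 32:

import torch
from mmcv.cnn import VGG

model = VGG(16, out_indices=(2, 3, 4), frozen_stages=1)
model.init_weights()   # Kaiming for convs, N(0, 0.01) for linears
model.train()          # layer1 stays in eval mode with grads disabled

outs = model(torch.randn(1, 3, 224, 224))
for out in outs:
    print(out.shape)   # (1, 256, 28, 28), (1, 512, 14, 14), (1, 512, 7, 7)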