From 867d5a500391c4aa993b53145045cc9f56e1312b Mon Sep 17 00:00:00 2001
From: yl-1993
Date: Tue, 9 Oct 2018 21:08:03 +0800
Subject: [PATCH 1/6] consistent param list of block

---
 mmcv/cnn/resnet.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mmcv/cnn/resnet.py b/mmcv/cnn/resnet.py
index 6e1c62863..794206776 100644
--- a/mmcv/cnn/resnet.py
+++ b/mmcv/cnn/resnet.py
@@ -28,7 +28,8 @@ class BasicBlock(nn.Module):
                  stride=1,
                  dilation=1,
                  downsample=None,
-                 style='pytorch'):
+                 style='pytorch',
+                 with_cp=False):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
         self.bn1 = nn.BatchNorm2d(planes)
@@ -38,6 +39,7 @@ class BasicBlock(nn.Module):
         self.downsample = downsample
         self.stride = stride
         self.dilation = dilation
+        assert not with_cp

     def forward(self, x):
         residual = x
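The new `with_cp` argument is mmcv's gradient-checkpointing flag. `BasicBlock` accepts it only so that `BasicBlock` and `Bottleneck` share one parameter list, and `assert not with_cp` makes the unsupported combination fail loudly instead of being silently ignored. For reference, a checkpointed block forward would look roughly like the sketch below; this is hypothetical and not part of this patch, and it assumes the block stores `self.with_cp` and has the usual `conv2`/`bn2` pair:

    import torch.utils.checkpoint as cp

    def forward(self, x):
        # Recompute the inner activations during the backward pass instead
        # of storing them, trading extra compute for lower memory use.
        def _inner(x):
            out = self.relu(self.bn1(self.conv1(x)))
            return self.bn2(self.conv2(out))

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner, x)
        else:
            out = _inner(x)
        residual = x if self.downsample is None else self.downsample(x)
        return self.relu(out + residual)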
From 64959bd7725b91e137a9a3712ab2e8661f5ae01e Mon Sep 17 00:00:00 2001
From: yl-1993
Date: Tue, 9 Oct 2018 21:14:09 +0800
Subject: [PATCH 2/6] add alexnet & vgg

---
 mmcv/cnn/__init__.py |   6 +-
 mmcv/cnn/alexnet.py  |  64 +++++++++++++++++
 mmcv/cnn/vgg.py      | 162 +++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 231 insertions(+), 1 deletion(-)
 create mode 100644 mmcv/cnn/alexnet.py
 create mode 100644 mmcv/cnn/vgg.py

diff --git a/mmcv/cnn/__init__.py b/mmcv/cnn/__init__.py
index 2be57468e..ba1d6bac4 100644
--- a/mmcv/cnn/__init__.py
+++ b/mmcv/cnn/__init__.py
@@ -1,3 +1,7 @@
+from .alexnet import AlexNet
+from .vgg import VGG, make_vgg_layer
 from .resnet import ResNet, make_res_layer

-__all__ = ['ResNet', 'make_res_layer']
+__all__ = ['AlexNet',
+           'VGG', 'make_vgg_layer',
+           'ResNet', 'make_res_layer']
diff --git a/mmcv/cnn/alexnet.py b/mmcv/cnn/alexnet.py
new file mode 100644
index 000000000..e2e6825e1
--- /dev/null
+++ b/mmcv/cnn/alexnet.py
@@ -0,0 +1,64 @@
+import logging
+import math
+
+import torch.nn as nn
+
+from ..runner import load_checkpoint
+
+
+class AlexNet(nn.Module):
+    """AlexNet backbone.
+
+    Args:
+        num_classes (int): number of classes for classification.
+    """
+
+
+    def __init__(self,
+                 num_classes=-1):
+        super(AlexNet, self).__init__()
+        self.num_classes = num_classes
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(64, 192, kernel_size=5, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(192, 384, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(384, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+        )
+        if self.num_classes > 0:
+            self.classifier = nn.Sequential(
+                nn.Dropout(),
+                nn.Linear(256 * 6 * 6, 4096),
+                nn.ReLU(inplace=True),
+                nn.Dropout(),
+                nn.Linear(4096, 4096),
+                nn.ReLU(inplace=True),  # caffe has dropout
+                nn.Linear(4096, num_classes),
+            )
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            # use default initializer
+            pass
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+
+        x = self.features(x)
+        if self.num_classes > 0:
+            x = x.view(x.size(0), 256 * 6 * 6)
+            x = self.classifier(x)
+
+        return x
diff --git a/mmcv/cnn/vgg.py b/mmcv/cnn/vgg.py
new file mode 100644
index 000000000..30a2f1497
--- /dev/null
+++ b/mmcv/cnn/vgg.py
@@ -0,0 +1,162 @@
+import logging
+import math
+
+import torch.nn as nn
+
+from ..runner import load_checkpoint
+
+
+def conv3x3(in_planes, out_planes, dilation=1, bias=False):
+    "3x3 convolution with padding"
+    return nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        padding=dilation,
+        dilation=dilation,
+        bias=bias)
+
+
+def make_vgg_layer(inplanes,
+                   planes,
+                   num_blocks,
+                   dilation=1,
+                   with_bn=False):
+    layers = []
+    for _ in range(num_blocks):
+        layers.append(
+            conv3x3(inplanes, planes, dilation, not with_bn))
+        if with_bn:
+            layers.append(nn.BatchNorm2d(planes))
+        layers.append(nn.ReLU(inplace=True))
+        inplanes = planes
+    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
+
+    return nn.Sequential(*layers)
+
+
+class VGG(nn.Module):
+    """VGG backbone.
+
+    Args:
+        depth (int): Depth of vgg, from {11, 13, 16, 19}.
+        with_bn (bool): Use BatchNorm or not.
+        num_classes (int): number of classes for classification.
+        num_stages (int): VGG stages, normally 5.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+            not freezing any parameters.
+        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+            running stats (mean and var).
+        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+    """
+
+    arch_settings = {
+        11: (1, 1, 2, 2, 2),
+        13: (2, 2, 2, 2, 2),
+        16: (2, 2, 3, 3, 3),
+        19: (2, 2, 4, 4, 4)
+    }
+
+    def __init__(self,
+                 depth,
+                 with_bn=False,
+                 num_classes=-1,
+                 num_stages=5,
+                 dilations=(1, 1, 1, 1, 1),
+                 out_indices=(0, 1, 2, 3, 4),
+                 frozen_stages=-1,
+                 bn_eval=True,
+                 bn_frozen=False):
+        super(VGG, self).__init__()
+        if depth not in self.arch_settings:
+            raise KeyError('invalid depth {} for vgg'.format(depth))
+        assert num_stages >= 1 and num_stages <= 5
+        stage_blocks = self.arch_settings[depth]
+        stage_blocks = stage_blocks[:num_stages]
+        assert len(dilations) == num_stages
+        assert max(out_indices) < num_stages
+
+        self.num_classes = num_classes
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.bn_eval = bn_eval
+        self.bn_frozen = bn_frozen
+
+        self.inplanes = 3
+        self.vgg_layers = []
+        for i, num_blocks in enumerate(stage_blocks):
+            dilation = dilations[i]
+            planes = 64 * 2**i if i < 4 else 512
+            vgg_layer = make_vgg_layer(
+                self.inplanes,
+                planes,
+                num_blocks,
+                dilation=dilation,
+                with_bn=with_bn)
+            self.inplanes = planes
+            layer_name = 'layer{}'.format(i + 1)
+            self.add_module(layer_name, vgg_layer)
+            self.vgg_layers.append(layer_name)
+
+        if self.num_classes > 0:
+            self.classifier = nn.Sequential(
+                nn.Linear(512 * 7 * 7, 4096),
+                nn.ReLU(True),
+                nn.Dropout(),
+                nn.Linear(4096, 4096),
+                nn.ReLU(True),
+                nn.Dropout(),
+                nn.Linear(4096, num_classes),
+            )
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                    if m.bias is not None:
+                        nn.init.constant_(m.bias, 0)
+                elif isinstance(m, nn.BatchNorm2d):
+                    nn.init.constant_(m.weight, 1)
+                    nn.init.constant_(m.bias, 0)
+                elif isinstance(m, nn.Linear):
+                    nn.init.normal_(m.weight, 0, 0.01)
+                    nn.init.constant_(m.bias, 0)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        for i, layer_name in enumerate(self.vgg_layers):
+            vgg_layer = getattr(self, layer_name)
+            x = vgg_layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        if self.num_classes > 0:
+            x = x.view(x.size(0), -1)
+            x = self.classifier(x)
+            outs.append(x)
+        if len(outs) == 1:
+            return outs[0]
+        else:
+            return tuple(outs)
+
+    def train(self, mode=True):
+        super(VGG, self).train(mode)
+        if self.bn_eval:
+            for m in self.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
+                    if self.bn_frozen:
+                        for params in m.parameters():
+                            params.requires_grad = False
+        if mode and self.frozen_stages >= 0:
+            for i in range(1, self.frozen_stages + 1):
+                mod = getattr(self, 'layer{}'.format(i))
+                mod.eval()
+                for param in mod.parameters():
+                    param.requires_grad = False

From 8332ddbcbe0f7cc8ab2f5ed01253cedcc32b16ee Mon Sep 17 00:00:00 2001
From: yl-1993
Date: Tue, 9 Oct 2018 21:20:13 +0800
Subject: [PATCH 3/6] rm unused

---
 mmcv/cnn/alexnet.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mmcv/cnn/alexnet.py b/mmcv/cnn/alexnet.py
index e2e6825e1..dc3bdc3bf 100644
--- a/mmcv/cnn/alexnet.py
+++ b/mmcv/cnn/alexnet.py
@@ -1,5 +1,4 @@
 import logging
-import math

 import torch.nn as nn

@@ -40,7 +39,7 @@ class AlexNet(nn.Module):
             nn.ReLU(inplace=True),
             nn.Dropout(),
             nn.Linear(4096, 4096),
-            nn.ReLU(inplace=True),  # caffe has dropout
+            nn.ReLU(inplace=True),
             nn.Linear(4096, num_classes),
         )
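At this point the AlexNet backbone is usable end to end; VGG still has the `outs` bug fixed in the next patch. A quick smoke test, not part of the series, assuming a 224x224 input since the classifier flattens to 256 * 6 * 6:

    import torch
    from mmcv.cnn import AlexNet

    # num_classes > 0 builds the fully connected head; -1 (the default)
    # returns the 256 x 6 x 6 convolutional feature map instead.
    model = AlexNet(num_classes=10)
    model.init_weights()  # default init; pass a checkpoint path to load weights
    x = torch.rand(1, 3, 224, 224)
    assert model(x).shape == (1, 10)

    features = AlexNet()(x)  # classifier branch is skipped entirely
    assert features.shape == (1, 256, 6, 6)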
From f63f2a974881e54ae23b276bd50a4a116574d50a Mon Sep 17 00:00:00 2001
From: yl-1993
Date: Tue, 9 Oct 2018 22:41:18 +0800
Subject: [PATCH 4/6] fix flake8 error

---
 mmcv/cnn/alexnet.py | 4 +---
 mmcv/cnn/vgg.py     | 7 +++++--
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/mmcv/cnn/alexnet.py b/mmcv/cnn/alexnet.py
index dc3bdc3bf..1230ee575 100644
--- a/mmcv/cnn/alexnet.py
+++ b/mmcv/cnn/alexnet.py
@@ -12,9 +12,7 @@ class AlexNet(nn.Module):
         num_classes (int): number of classes for classification.
     """

-
-    def __init__(self,
-                 num_classes=-1):
+    def __init__(self, num_classes=-1):
         super(AlexNet, self).__init__()
         self.num_classes = num_classes
         self.features = nn.Sequential(
diff --git a/mmcv/cnn/vgg.py b/mmcv/cnn/vgg.py
index 30a2f1497..7d4700854 100644
--- a/mmcv/cnn/vgg.py
+++ b/mmcv/cnn/vgg.py
@@ -1,5 +1,4 @@
 import logging
-import math

 import torch.nn as nn

@@ -118,7 +117,10 @@ class VGG(nn.Module):
         elif pretrained is None:
             for m in self.modules():
                 if isinstance(m, nn.Conv2d):
-                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                    nn.init.kaiming_normal_(
+                        m.weight,
+                        mode='fan_out',
+                        nonlinearity='relu')
                     if m.bias is not None:
                         nn.init.constant_(m.bias, 0)
                 elif isinstance(m, nn.BatchNorm2d):
@@ -131,6 +133,7 @@ class VGG(nn.Module):
             raise TypeError('pretrained must be a str or None')

     def forward(self, x):
+        outs = []
         for i, layer_name in enumerate(self.vgg_layers):
             vgg_layer = getattr(self, layer_name)
             x = vgg_layer(x)

From 64afdd0afa1fd3f56c44c0e130ba9053055aa77b Mon Sep 17 00:00:00 2001
From: yl-1993
Date: Tue, 9 Oct 2018 23:12:54 +0800
Subject: [PATCH 5/6] put newline as last character

---
 mmcv/cnn/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmcv/cnn/__init__.py b/mmcv/cnn/__init__.py
index e283d4295..a329f3225 100644
--- a/mmcv/cnn/__init__.py
+++ b/mmcv/cnn/__init__.py
@@ -7,4 +7,4 @@ __all__ = [
     'AlexNet', 'VGG', 'make_vgg_layer',
     'ResNet', 'make_res_layer', 'xavier_init', 'normal_init', 'uniform_init',
     'kaiming_init'
-]
\ No newline at end of file
+]
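With the `outs = []` fix from PATCH 4 in place, VGG's forward works in both modes. A short sketch, again assuming a 224x224 input to match the 512 * 7 * 7 classifier; note that a single entry in `out_indices` returns a bare tensor rather than a tuple, and that in classification mode the stage outputs are returned alongside the logits:

    import torch
    from mmcv.cnn import VGG

    x = torch.rand(1, 3, 224, 224)

    # Feature extraction: num_classes=-1, pick which of the 5 stages to emit.
    vgg = VGG(depth=16, out_indices=(2, 4))
    vgg.init_weights()  # Kaiming init for convs, N(0, 0.01) for linears
    c3, c5 = vgg(x)
    assert c3.shape == (1, 256, 28, 28)  # after stage 3 (three 2x2 pools)
    assert c5.shape == (1, 512, 7, 7)    # after stage 5 (five 2x2 pools)

    # Classification: num_classes > 0 appends the flattened FC head output.
    clf = VGG(depth=11, num_classes=10, out_indices=(4,))
    feat, logits = clf(x)
    assert logits.shape == (1, 10)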
From 4b2abfcf245bc1c4b8de12245e0fb7a8b301bd2b Mon Sep 17 00:00:00 2001
From: Kai Chen
Date: Tue, 9 Oct 2018 23:46:36 +0800
Subject: [PATCH 6/6] minor format adjustment

---
 mmcv/cnn/__init__.py |  5 ++---
 mmcv/cnn/vgg.py      | 23 ++++++++---------------
 2 files changed, 10 insertions(+), 18 deletions(-)

diff --git a/mmcv/cnn/__init__.py b/mmcv/cnn/__init__.py
index a329f3225..9684ee3b0 100644
--- a/mmcv/cnn/__init__.py
+++ b/mmcv/cnn/__init__.py
@@ -4,7 +4,6 @@ from .resnet import ResNet, make_res_layer
 from .weight_init import xavier_init, normal_init, uniform_init, kaiming_init

 __all__ = [
-    'AlexNet', 'VGG', 'make_vgg_layer',
-    'ResNet', 'make_res_layer', 'xavier_init', 'normal_init', 'uniform_init',
-    'kaiming_init'
+    'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
+    'xavier_init', 'normal_init', 'uniform_init', 'kaiming_init'
 ]
diff --git a/mmcv/cnn/vgg.py b/mmcv/cnn/vgg.py
index 7d4700854..f5939e20e 100644
--- a/mmcv/cnn/vgg.py
+++ b/mmcv/cnn/vgg.py
@@ -16,15 +16,10 @@ def conv3x3(in_planes, out_planes, dilation=1, bias=False):
         bias=bias)


-def make_vgg_layer(inplanes,
-                   planes,
-                   num_blocks,
-                   dilation=1,
-                   with_bn=False):
+def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False):
     layers = []
     for _ in range(num_blocks):
-        layers.append(
-            conv3x3(inplanes, planes, dilation, not with_bn))
+        layers.append(conv3x3(inplanes, planes, dilation, not with_bn))
         if with_bn:
             layers.append(nn.BatchNorm2d(planes))
         layers.append(nn.ReLU(inplace=True))
@@ -89,11 +84,11 @@ class VGG(nn.Module):
             dilation = dilations[i]
             planes = 64 * 2**i if i < 4 else 512
             vgg_layer = make_vgg_layer(
-                self.inplanes,
-                planes,
-                num_blocks,
-                dilation=dilation,
-                with_bn=with_bn)
+                self.inplanes,
+                planes,
+                num_blocks,
+                dilation=dilation,
+                with_bn=with_bn)
             self.inplanes = planes
             layer_name = 'layer{}'.format(i + 1)
             self.add_module(layer_name, vgg_layer)
@@ -118,9 +113,7 @@ class VGG(nn.Module):
             for m in self.modules():
                 if isinstance(m, nn.Conv2d):
                     nn.init.kaiming_normal_(
-                        m.weight,
-                        mode='fan_out',
-                        nonlinearity='relu')
+                        m.weight, mode='fan_out', nonlinearity='relu')
                     if m.bias is not None:
                         nn.init.constant_(m.bias, 0)
                 elif isinstance(m, nn.BatchNorm2d):
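The freezing options introduced in PATCH 2 interact with `train()` as follows: the overridden `train()` re-applies the policy on every call, so frozen stages stay in eval mode and keep `requires_grad=False` even after `model.train()`. A small sketch of the behavior as implemented above:

    from mmcv.cnn import VGG

    model = VGG(depth=16, with_bn=True, frozen_stages=2, bn_eval=True)
    model.init_weights()
    model.train()

    # bn_eval=True puts every BatchNorm layer back into eval mode, freezing
    # its running mean/var; bn_frozen=True would additionally stop gradients
    # for BN weight and bias. frozen_stages=2 disables layer1 and layer2.
    for i in (1, 2):
        stage = getattr(model, 'layer{}'.format(i))
        assert not stage.training
        assert all(not p.requires_grad for p in stage.parameters())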