complete the docstring & import from mmcv & range check for frozen_stages and out_indices
parent b01a00c0c8 | commit 05e5173b6b
@@ -4,14 +4,30 @@ from collections import OrderedDict
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
+from mmcv.cnn.weight_init import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint

 from .base_backbone import BaseBackbone
-from .weight_init import constant_init, kaiming_init


 def conv3x3(inplanes, planes, stride=1, padding=1, bias=False, groups=1):
-    """3x3 convolution with padding
+    """3x3 convolution
+
+    Applies a 2D convolution with the kernel_size 3x3 over an input signal
+    composed of several input planes.
+
+    Args:
+        inplanes (int): Number of channels of the input feature maps.
+        planes (int): Number of channels produced by the convolution.
+        stride (int or tuple, optional): Stride of the convolution.
+            Default is 1.
+        padding (int): Controls the amount of implicit zero-padding on
+            both sides of the input for each dimension.
+            Default is 1.
+        bias (bool, optional): Whether to add a learnable bias to the output.
+            Default is False.
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default is 1.
     """
     return nn.Conv2d(
         inplanes,
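As a quick illustration (not part of the diff, and assuming the conv3x3 defined above is in scope), the wrapper is an ordinary nn.Conv2d with a fixed 3x3 kernel; with the default stride=1 and padding=1 it preserves spatial size:

    import torch

    conv = conv3x3(24, 48)                     # 24 -> 48 channels
    x = torch.randn(1, 24, 56, 56)
    assert conv(x).shape == (1, 48, 56, 56)    # 56x56 preserved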
@@ -24,33 +40,61 @@ def conv3x3(inplanes, planes, stride=1, padding=1, bias=False, groups=1):


 def conv1x1(inplanes, planes, groups=1):
-    """1x1 convolution with padding
-    - Normal pointwise convolution when groups == 1
-    - Grouped pointwise convolution when groups > 1
+    """1x1 convolution
+
+    Applies a 2D convolution with the kernel_size 1x1 over an input signal
+    composed of several input planes.
+
+    Args:
+        inplanes (int): Number of channels of the input feature maps.
+        planes (int): Number of channels produced by the convolution.
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default: 1.
     """
     return nn.Conv2d(inplanes, planes, kernel_size=1, groups=groups, stride=1)


 def channel_shuffle(x, groups):
-    batchsize, num_channels, height, width = x.data.size()
-    assert (num_channels % groups == 0)
+    """Channel Shuffle operation.
+
+    This function enables cross-group information flow for multiple group
+    convolution layers.
+
+    Args:
+        x (Tensor): The input tensor.
+        groups (int): The number of groups to divide the input tensor
+            in the channel dimension.
+
+    Returns:
+        x (Tensor): The output tensor after channel shuffle operation.
+    """
+    batchsize, num_channels, height, width = x.size()
+    assert (num_channels % groups == 0), 'num_channels should ' \
+        'be divisible by groups'
     channels_per_group = num_channels // groups

     # reshape
     x = x.view(batchsize, groups, channels_per_group, height, width)

     # transpose
     # - contiguous() required if transpose() is used before view().
     #   See https://github.com/pytorch/pytorch/issues/764
     x = torch.transpose(x, 1, 2).contiguous()

     # flatten
     x = x.view(batchsize, -1, height, width)

     return x


 def _make_divisible(v, divisor, min_value=None):
+    """Make divisible function.
+
+    This function ensures that all layers have a channel number that is
+    divisible by divisor.
+
+    Args:
+        v (int): The original channel number.
+        divisor (int): The divisor to fully divide the channel number.
+        min_value (int, optional): The minimum value of the output channel.
+
+    Returns:
+        new_v (int): The modified output channel number.
+    """
     if min_value is None:
         min_value = divisor
     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
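A short sketch of the two helpers documented above (illustrative only, assuming the functions as defined in this file): channel_shuffle only permutes channels across groups, and _make_divisible rounds a channel count to the nearest multiple of the divisor:

    import torch

    x = torch.randn(2, 12, 8, 8)             # 12 channels = 3 groups of 4
    out = channel_shuffle(x, groups=3)
    assert out.shape == x.shape              # shape is unchanged
    # after view -> transpose -> flatten, output channel 1 comes from
    # input channel 4 (the first channel of the second group)
    assert torch.equal(out[:, 1], x[:, 4])

    assert _make_divisible(120, 8) == 120    # already a multiple of 8
    assert _make_divisible(100, 8) == 104    # rounded to a multiple of 8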
@@ -61,6 +105,26 @@ def _make_divisible(v, divisor, min_value=None):


 class ShuffleUnit(nn.Module):
+    """ShuffleUnit block.
+
+    ShuffleNet unit with pointwise group convolution (GConv) and channel
+    shuffle.
+
+    Args:
+        inplanes (int): The input channels of the ShuffleUnit.
+        planes (int): The output channels of the ShuffleUnit.
+        groups (int, optional): The number of groups to be used in grouped 1x1
+            convolutions in each ShuffleUnit.
+        first_block (bool, optional): Whether it is the first ShuffleUnit of a
+            sequence of ShuffleUnits. If True, use the grouped 1x1 convolution.
+        combine (str, optional): The way to combine the input and output
+            branches.
+        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+
+    Returns:
+        out: output tensor
+    """

     def __init__(self,
                  inplanes,
@@ -69,7 +133,6 @@ class ShuffleUnit(nn.Module):
                  first_block=True,
                  combine='add',
                  with_cp=False):
-
         super(ShuffleUnit, self).__init__()
         self.inplanes = inplanes
         self.planes = planes
@@ -82,18 +145,15 @@
         if self.combine == 'add':
             self.depthwise_stride = 1
             self._combine_func = self._add
+            assert inplanes == planes, 'inplanes must be equal to ' \
+                'planes when combine is add.'
         elif self.combine == 'concat':
             self.depthwise_stride = 2
             self._combine_func = self._concat
             self.planes -= self.inplanes
         else:
-            raise ValueError("Cannot combine tensors with \"{}\" "
-                             "Only \"add\" and \"concat\" are "
-                             "supported".format(self.combine))
-
-        if combine == 'add':
-            assert inplanes == planes, \
-                'inplanes must be equal to outplanes when combine is add'
+            raise ValueError(f'Cannot combine tensors with {self.combine}. '
+                             f'Only "add" and "concat" are supported.')

         self.first_1x1_groups = self.groups if first_block else 1
         self.g_conv_1x1_compress = self._make_grouped_conv1x1(
@@ -108,15 +168,16 @@
             self.bottleneck_channels,
             stride=self.depthwise_stride,
             groups=self.bottleneck_channels)
-        self.bn_after_depthwise = \
-            nn.BatchNorm2d(self.bottleneck_channels)
-
-        self.g_conv_1x1_expand = \
-            self._make_grouped_conv1x1(self.bottleneck_channels,
-                                       self.planes,
-                                       self.groups,
-                                       batch_norm=True,
-                                       relu=False)
+        self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)
+
+        self.g_conv_1x1_expand = self._make_grouped_conv1x1(
+            self.bottleneck_channels,
+            self.planes,
+            self.groups,
+            batch_norm=True,
+            relu=False)

         self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
         self.relu = nn.ReLU(inplace=True)
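To make the two combine modes concrete, a hypothetical usage sketch (it assumes the unit's bottleneck width is planes // 4 as in the reference ShuffleNet v1 implementation, so the channel counts below keep the grouped convolutions valid):

    import torch

    x = torch.randn(1, 240, 28, 28)

    # 'add': stride-1 residual unit, so inplanes must equal planes
    unit_add = ShuffleUnit(240, 240, groups=3, first_block=True, combine='add')
    assert unit_add(x).shape == (1, 240, 28, 28)

    # 'concat': stride-2 unit; the branch emits planes - inplanes channels
    # and is concatenated with the avg-pooled identity, halving H and W
    unit_cat = ShuffleUnit(240, 480, groups=3, first_block=True,
                           combine='concat')
    assert unit_cat(x).shape == (1, 480, 14, 14)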
@@ -183,10 +244,11 @@ class ShuffleNetv1(BaseBackbone):
     """ShuffleNetv1 backbone.

     Args:
-        groups (int): number of groups to be used in grouped
-            1x1 convolutions in each ShuffleUnit. Default is 3 for best
-            performance according to original paper.
-        widen_factor (float): Config of widen_factor.
+        groups (int, optional): The number of groups to be used in grouped
+            1x1 convolutions in each ShuffleUnit. Default is 3 for the best
+            performance according to the original paper.
+        widen_factor (float, optional): Width multiplier, which adjusts the
+            number of channels in each layer by this amount. Default is 1.0.
         out_indices (Sequence[int]): Output from which stages.
         frozen_stages (int): Stages to be frozen (all param fixed). -1 means
             not freezing any parameters.
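A minimal usage sketch for the documented arguments (illustrative; it assumes the usual mm-style backbone contract in which forward returns one feature map per entry of out_indices):

    import torch

    model = ShuffleNetv1(groups=3, widen_factor=1.0, out_indices=(0, 1, 2))
    model.init_weights()
    model.train()

    x = torch.randn(1, 3, 224, 224)
    outs = model(x)        # tuple with one tensor per requested stage
    # with groups=3 and widen_factor=1.0 the stage widths are 240/480/960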
@@ -208,6 +270,15 @@ class ShuffleNetv1(BaseBackbone):
         super(ShuffleNetv1, self).__init__()
         blocks = [3, 7, 3]
         self.groups = groups
+
+        for indice in out_indices:
+            if indice not in range(0, 4):
+                raise ValueError(f'the item in out_indices must be in '
+                                 f'range(0, 4). But received {indice}')
+
+        if frozen_stages not in [-1, 1, 2, 3]:
+            raise ValueError(f'frozen_stages must be in [-1, 1, 2, 3]. '
+                             f'But received {frozen_stages}')
         self.out_indices = out_indices
         self.frozen_stages = frozen_stages
         self.bn_eval = bn_eval
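These checks fail fast at construction time. A brief demonstration, mirroring the tests added at the bottom of this commit:

    import pytest

    with pytest.raises(ValueError):
        ShuffleNetv1(frozen_stages=4)      # only -1, 1, 2, 3 are accepted

    with pytest.raises(ValueError):
        ShuffleNetv1(out_indices=(4, ))    # items must lie in range(0, 4)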
@@ -225,8 +296,9 @@
         elif groups == 8:
             channels = [384, 768, 1536]
         else:
-            raise ValueError("{} groups is not supported for "
-                             "1x1 Grouped Convolutions".format(groups))
+            raise ValueError(f'{groups} groups is not supported for 1x1 '
+                             f'Grouped Convolutions')

         channels = [_make_divisible(ch * widen_factor, 8) for ch in channels]
+
         self.inplanes = int(24 * widen_factor)
@@ -252,6 +324,20 @@ class ShuffleNetv1(BaseBackbone):
             raise TypeError('pretrained must be a str or None')

     def _make_layer(self, outplanes, blocks, first_block=True, with_cp=False):
+        """Stack n bottleneck modules where n is inferred from the depth of
+        the network.
+
+        Args:
+            outplanes (int): number of output channels.
+            blocks (int): number of blocks to be built.
+            first_block (bool, optional): Whether it is the first ShuffleUnit
+                of a sequence of ShuffleUnits. If True, use the grouped 1x1
+                convolution.
+            with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+                will save some memory while slowing down the training speed.
+
+        Returns: a Module consisting of n sequential ShuffleUnits.
+        """
         layers = []
         for i in range(blocks):
             if i == 0:
@@ -272,6 +358,7 @@ class ShuffleNetv1(BaseBackbone):
                     first_block=True,
                     combine='add',
                     with_cp=with_cp))
+
         self.inplanes = outplanes

         return nn.Sequential(*layers)
@@ -311,7 +398,7 @@ class ShuffleNetv1(BaseBackbone):
             for param in self.conv1.parameters():
                 param.requires_grad = False
         for i in range(1, self.frozen_stages + 1):
-            mod = getattr(self, 'layer{}'.format(i))
-            mod.eval()
-            for param in mod.parameters():
+            layer = getattr(self, f'layer{i}')
+            layer.eval()
+            for param in layer.parameters():
                 param.requires_grad = False
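The refactor above only swaps str.format for an f-string; behavior is unchanged. A hypothetical check of the freezing contract (this assumes the enclosing method is the backbone's _freeze_stages and that it is invoked from train(), as in other mm-series backbones; neither detail is visible in this hunk):

    model = ShuffleNetv1(frozen_stages=1)
    model.train()          # assumed to trigger the freezing logic
    assert all(not p.requires_grad for p in model.layer1.parameters())
    assert all(p.requires_grad for p in model.layer2.parameters())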
@@ -1,66 +0,0 @@
-# Copyright (c) Open-MMLab. All rights reserved.
-import numpy as np
-import torch.nn as nn
-
-
-def constant_init(module, val, bias=0):
-    if hasattr(module, 'weight') and module.weight is not None:
-        nn.init.constant_(module.weight, val)
-    if hasattr(module, 'bias') and module.bias is not None:
-        nn.init.constant_(module.bias, bias)
-
-
-def xavier_init(module, gain=1, bias=0, distribution='normal'):
-    assert distribution in ['uniform', 'normal']
-    if distribution == 'uniform':
-        nn.init.xavier_uniform_(module.weight, gain=gain)
-    else:
-        nn.init.xavier_normal_(module.weight, gain=gain)
-    if hasattr(module, 'bias') and module.bias is not None:
-        nn.init.constant_(module.bias, bias)
-
-
-def normal_init(module, mean=0, std=1, bias=0):
-    nn.init.normal_(module.weight, mean, std)
-    if hasattr(module, 'bias') and module.bias is not None:
-        nn.init.constant_(module.bias, bias)
-
-
-def uniform_init(module, a=0, b=1, bias=0):
-    nn.init.uniform_(module.weight, a, b)
-    if hasattr(module, 'bias') and module.bias is not None:
-        nn.init.constant_(module.bias, bias)
-
-
-def kaiming_init(module,
-                 a=0,
-                 mode='fan_out',
-                 nonlinearity='relu',
-                 bias=0,
-                 distribution='normal'):
-    assert distribution in ['uniform', 'normal']
-    if distribution == 'uniform':
-        nn.init.kaiming_uniform_(
-            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
-    else:
-        nn.init.kaiming_normal_(
-            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
-    if hasattr(module, 'bias') and module.bias is not None:
-        nn.init.constant_(module.bias, bias)
-
-
-def caffe2_xavier_init(module, bias=0):
-    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
-    # Acknowledgment to FAIR's internal code
-    kaiming_init(
-        module,
-        a=1,
-        mode='fan_in',
-        nonlinearity='leaky_relu',
-        distribution='uniform')
-
-
-def bias_init_with_prob(prior_prob):
-    """initialize conv/fc bias value according to a given probability"""
-    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
-    return bias_init
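This local copy of the init helpers is deleted because mmcv ships identical ones, which the backbone now imports directly (see the first hunk). Equivalent usage with the mmcv versions:

    import torch.nn as nn
    from mmcv.cnn.weight_init import constant_init, kaiming_init

    conv = nn.Conv2d(3, 24, kernel_size=3)
    bn = nn.BatchNorm2d(24)
    kaiming_init(conv)      # He init (fan_out, relu) on weights; bias -> 0
    constant_init(bn, 1)    # weight -> 1, bias -> 0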
@@ -37,7 +37,7 @@ def test_shufflenetv1_shuffleuint():
         ShuffleUnit(24, 16, groups=3, first_block=True, combine='test')

     with pytest.raises(ValueError):
-        # in_channels must be divisible by groups
+        # inplanes must be divisible by groups
         ShuffleUnit(64, 64, groups=3, first_block=True, combine='add')

     with pytest.raises(AssertionError):
@@ -66,6 +66,14 @@ def test_shufflenetv1_shuffleuint():

 def test_shufflenetv1_backbone():

+    with pytest.raises(ValueError):
+        # frozen_stages must be in [-1, 1, 2, 3]
+        ShuffleNetv1(frozen_stages=10)
+
+    with pytest.raises(ValueError):
+        # each item in out_indices must be in [0, 1, 2, 3]
+        ShuffleNetv1(out_indices=[5])
+
     with pytest.raises(ValueError):
         # groups must be in [1, 2, 3, 4, 8]
         ShuffleNetv1(groups=10)