# mmpretrain/mmcls/models/backbones/shufflenet_v2.py
# (source listing metadata: 329 lines, 11 KiB, Python)

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from torch.nn.modules.batchnorm import _BatchNorm
from .base_backbone import BaseBackbone
def channel_shuffle(x, groups):
    """Interleave channels across groups (the ShuffleNet channel shuffle).

    The channel dimension is viewed as ``(groups, channels_per_group)``,
    the two axes are swapped, and the result is flattened back. This lets
    information flow between the groups of consecutive group convolutions.

    Args:
        x (Tensor): Input tensor; unpacked as ``(N, C, H, W)``.
        groups (int): Number of groups ``C`` is split into. ``C`` must be
            divisible by it.

    Returns:
        Tensor: Tensor of the same shape as ``x`` with shuffled channels.
    """
    n, c, h, w = x.size()
    assert (c % groups == 0), 'num_channels should be divisible by groups'
    per_group = c // groups
    grouped = x.view(n, groups, per_group, h, w)
    # contiguous() is required before view() after a transpose.
    swapped = grouped.transpose(1, 2).contiguous()
    return swapped.view(n, -1, h, w)
def make_divisible(value, divisor, min_value=None):
    """Round a channel count to the nearest multiple of ``divisor``.

    The value is rounded to the nearest multiple, floored at ``min_value``.
    If rounding lost more than 10% of the original value, one extra
    ``divisor`` is added back.

    Args:
        value (int): The original channel number.
        divisor (int): The divisor the result must be a multiple of.
        min_value (int, optional): Lower bound of the result. Defaults to
            ``divisor`` when None.

    Returns:
        int: The adjusted channel number.
    """
    floor = divisor if min_value is None else min_value
    nearest = int(value + divisor / 2) // divisor * divisor
    result = max(floor, nearest)
    # Guard against rounding down by more than 10%.
    return result + divisor if result < 0.9 * value else result
class InvertedResidual(nn.Module):
    """Basic building unit of the ShuffleNetV2 backbone.

    With ``stride == 1`` the input is split in half along dim 1: the first
    half is passed through untouched and the second half goes through
    ``branch2``. Otherwise the full input is fed to both ``branch1`` and
    ``branch2``. The two results are concatenated along dim 1 and
    channel-shuffled with 2 groups.

    Args:
        inplanes (int): The input channels of the block.
        planes (int): The output channels of the block; each branch
            produces ``planes // 2`` channels.
        stride (int): Stride of the depthwise 3x3 convolutions. Default: 1.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        self.with_cp = with_cp
        half = planes // 2
        # With stride 1 the untouched half must already have ``half``
        # channels, so the input must hold exactly two halves.
        assert not (self.stride == 1 and inplanes != half * 2)

        def conv_unit(in_ch, out_ch, kernel, conv_stride, depthwise, act):
            # "Same" padding for the odd kernel sizes used here (1 or 3);
            # depthwise convs use one group per input channel.
            return ConvModule(
                in_ch,
                out_ch,
                kernel_size=kernel,
                stride=conv_stride,
                padding=kernel // 2,
                groups=in_ch if depthwise else 1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act)

        downsample = self.stride > 1
        if downsample:
            self.branch1 = nn.Sequential(
                conv_unit(inplanes, inplanes, 3, self.stride,
                          depthwise=True, act=None),
                conv_unit(inplanes, half, 1, 1, depthwise=False, act=act_cfg),
            )
        else:
            self.branch1 = nn.Sequential()

        branch2_in = inplanes if downsample else half
        self.branch2 = nn.Sequential(
            conv_unit(branch2_in, half, 1, 1, depthwise=False, act=act_cfg),
            conv_unit(half, half, 3, self.stride, depthwise=True, act=None),
            conv_unit(half, half, 1, 1, depthwise=False, act=act_cfg))

    def forward(self, x):

        def _inner_forward(inp):
            if self.stride != 1:
                merged = torch.cat((self.branch1(inp), self.branch2(inp)),
                                   dim=1)
            else:
                passthrough, processed = inp.chunk(2, dim=1)
                merged = torch.cat((passthrough, self.branch2(processed)),
                                   dim=1)
            return channel_shuffle(merged, 2)

        use_checkpoint = self.with_cp and x.requires_grad
        if not use_checkpoint:
            return _inner_forward(x)
        # Recompute activations during backward to save memory.
        return cp.checkpoint(_inner_forward, x)
class ShuffleNetv2(BaseBackbone):
    """ShuffleNetv2 backbone.

    Args:
        groups (int): Not used by any layer of this backbone. It is only
            stored as ``self.groups`` so existing configs keep working.
            Default: 3.
        widen_factor (float): Width multiplier that selects the per-stage
            channel numbers. Must be one of 0.5, 1.0, 1.5 or 2.0.
            Default: 1.0.
        out_indices (Sequence[int]): Output from which stages. Indices
            0-2 select the outputs of ``layer1``-``layer3``; index 3 selects
            the output of the final 1x1 convolution ``conv2``.
            Default: (0, 1, 2).
        frozen_stages (int): Stages to be frozen (params fixed and the
            modules kept in eval mode). -1 freezes nothing, 0 freezes only
            ``conv1``, and k >= 1 additionally freezes ``layer1``-``layer{k}``.
            Default: -1.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 groups=3,
                 widen_factor=1.0,
                 out_indices=(0, 1, 2),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 with_cp=False):
        super(ShuffleNetv2, self).__init__()
        self.stage_blocks = [4, 8, 4]
        # Unused by this architecture; kept for config compatibility.
        self.groups = groups
        self.out_indices = out_indices
        # Index ``len(self.stage_blocks)`` refers to the ``conv2`` output.
        assert max(out_indices) <= len(self.stage_blocks)
        self.frozen_stages = frozen_stages
        # Allow freezing every stage, including the last one.
        assert frozen_stages <= len(self.stage_blocks)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        if widen_factor == 0.5:
            channels = [48, 96, 192, 1024]
        elif widen_factor == 1.0:
            channels = [116, 232, 464, 1024]
        elif widen_factor == 1.5:
            channels = [176, 352, 704, 1024]
        elif widen_factor == 2.0:
            channels = [244, 488, 976, 2048]
        else:
            raise ValueError(f'widen_factor must in [0.5, 1.0, 1.5, 2.0]. '
                             f'But received {widen_factor}.')

        self.inplanes = 24
        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.inplanes,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage modules are registered as ``layer1``..``layerN`` and their
        # names are kept in ``self.layers`` for ordered lookup in forward().
        self.layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            layer = self._make_layer(channels[i], num_blocks)
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, layer)
            self.layers.append(layer_name)

        output_channels = channels[-1]
        # Final 1x1 conv; evaluated in forward() only when index
        # ``len(self.stage_blocks)`` is in ``out_indices``. When it is not
        # requested its parameters receive no gradient (distributed wrappers
        # may then need unused-parameter detection).
        self.conv2 = ConvModule(
            in_channels=self.inplanes,
            out_channels=output_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def _make_layer(self, planes, num_blocks):
        """Stack blocks to make a layer.

        The first block uses stride 2 and the rest stride 1. Updates
        ``self.inplanes`` to ``planes``.

        Args:
            planes (int): Output channels of every block in the layer.
            num_blocks (int): Number of blocks.

        Returns:
            nn.Sequential: The stacked blocks.
        """
        layers = []
        for i in range(num_blocks):
            stride = 2 if i == 0 else 1
            layers.append(
                InvertedResidual(
                    inplanes=self.inplanes,
                    planes=planes,
                    stride=stride,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            self.inplanes = planes
        return nn.Sequential(*layers)

    def _freeze_stages(self):
        """Freeze ``conv1`` and ``layer1``..``layer{frozen_stages}``."""
        if self.frozen_stages >= 0:
            # eval() also stops the stem's norm layer from updating its
            # running statistics, not just its learnable parameters.
            self.conv1.eval()
            for param in self.conv1.parameters():
                param.requires_grad = False
        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize weights, or load them from a checkpoint.

        Args:
            pretrained (str, optional): Path/URL of a checkpoint to load
                non-strictly. If None, conv layers get Kaiming init and
                norm layers get constant init (weight 1).

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            # Local imports keep module import independent of mmcv.runner.
            import logging

            from mmcv.runner import load_checkpoint
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the backbone.

        Returns:
            Tensor | tuple[Tensor]: A single tensor when exactly one output
            is requested, otherwise a tuple ordered by stage.
        """
        x = self.conv1(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)
        # The final 1x1 conv is computed only when it is requested.
        if len(self.layers) in self.out_indices:
            outs.append(self.conv2(x))
        if len(outs) == 1:
            return outs[0]
        return tuple(outs)

    def train(self, mode=True):
        """Set train/eval mode while keeping frozen parts in eval mode.

        Returns:
            ShuffleNetv2: ``self``, matching ``nn.Module.train``.
        """
        super(ShuffleNetv2, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # _BatchNorm covers BatchNorm1d/2d/3d and SyncBatchNorm.
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self