Merge branch 'regnet' into 'master'

Add RegNet

See merge request open-mmlab/mmclassification!24
This commit is contained in:
chenkai 2020-07-01 14:21:46 +08:00
commit a48ffaaa0a
4 changed files with 407 additions and 3 deletions

View File

@ -42,4 +42,4 @@ test:pytorch1.3-cuda10:
test:pat0.6.0dev-cuda9: test:pat0.6.0dev-cuda9:
image: $PARROTS_IMAGE image: $PARROTS_IMAGE
<<: *test_template_def <<: *test_template_def

View File

@ -1,5 +1,6 @@
from .mobilenet_v2 import MobileNetV2 from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetv3 from .mobilenet_v3 import MobileNetv3
from .regnet import RegNet
from .resnet import ResNet, ResNetV1d from .resnet import ResNet, ResNetV1d
from .resnext import ResNeXt from .resnext import ResNeXt
from .seresnet import SEResNet from .seresnet import SEResNet
@ -8,6 +9,6 @@ from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2 from .shufflenet_v2 import ShuffleNetV2
__all__ = [ __all__ = [
'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNetV1d', 'SEResNet', 'SEResNeXt', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNetV1d', 'SEResNet',
'ShuffleNetV1', 'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3' 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3'
] ]

View File

@ -0,0 +1,312 @@
import numpy as np
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from .resnet import ResNet
from .resnext import Bottleneck
@BACKBONES.register_module()
class RegNet(ResNet):
"""RegNet backbone.
More details can be found in `paper <https://arxiv.org/abs/2003.13678>`_ .
Args:
arch (dict): The parameter of RegNets.
- w0 (int): initial width
- wa (float): slope of width
- wm (float): quantization parameter to quantize the width
- depth (int): depth of the backbone
- group_w (int): width of group
- bot_mul (float): bottleneck ratio, i.e. expansion of bottlneck.
strides (Sequence[int]): Strides of the first block of each stage.
base_channels (int): Base channels after stem layer.
in_channels (int): Number of input image channels. Default: 3.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer. Default: "pytorch".
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters. Default: -1.
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN', requires_grad=True).
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity. Default: True.
Example:
>>> from mmdet.models import RegNet
>>> import torch
>>> self = RegNet(
arch=dict(
w0=88,
wa=26.31,
wm=2.25,
group_w=48,
depth=25,
bot_mul=1.0))
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 96, 8, 8)
(1, 192, 4, 4)
(1, 432, 2, 2)
(1, 1008, 1, 1)
"""
arch_settings = {
'regnetx_400mf':
dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0),
'regnetx_800mf':
dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0),
'regnetx_1.6gf':
dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0),
'regnetx_3.2gf':
dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0),
'regnetx_4.0gf':
dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0),
'regnetx_6.4gf':
dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0),
'regnetx_8.0gf':
dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0),
'regnetx_12gf':
dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0),
}
def __init__(self,
arch,
in_channels=3,
stem_channels=32,
base_channels=32,
strides=(2, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(3, ),
style='pytorch',
deep_stem=False,
avg_down=False,
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=False,
with_cp=False,
zero_init_residual=True):
super(ResNet, self).__init__()
# Generate RegNet parameters first
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'"arch": "{arch}" is not one of the' \
' arch_settings'
arch = self.arch_settings[arch]
elif not isinstance(arch, dict):
raise TypeError('Expect "arch" to be either a string '
f'or a dict, got {type(arch)}')
widths, num_stages = self.generate_regnet(
arch['w0'],
arch['wa'],
arch['wm'],
arch['depth'],
)
# Convert to per stage format
stage_widths, stage_blocks = self.get_stages_from_blocks(widths)
# Generate group widths and bot muls
group_widths = [arch['group_w'] for _ in range(num_stages)]
self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)]
# Adjust the compatibility of stage_widths and group_widths
stage_widths, group_widths = self.adjust_width_group(
stage_widths, self.bottleneck_ratio, group_widths)
# Group params by stage
self.stage_widths = stage_widths
self.group_widths = group_widths
self.depth = sum(stage_blocks)
self.stem_channels = stem_channels
self.base_channels = base_channels
self.num_stages = num_stages
assert num_stages >= 1 and num_stages <= 4
self.strides = strides
self.dilations = dilations
assert len(strides) == len(dilations) == num_stages
self.out_indices = out_indices
assert max(out_indices) < num_stages
self.style = style
self.deep_stem = deep_stem
if self.deep_stem:
raise NotImplementedError(
'deep_stem has not been implemented for RegNet')
self.avg_down = avg_down
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.with_cp = with_cp
self.norm_eval = norm_eval
self.zero_init_residual = zero_init_residual
self.stage_blocks = stage_blocks[:num_stages]
self._make_stem_layer(in_channels, stem_channels)
_in_channels = stem_channels
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
stride = self.strides[i]
dilation = self.dilations[i]
group_width = self.group_widths[i]
width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i]))
stage_groups = width // group_width
res_layer = self.make_res_layer(
block=Bottleneck,
num_blocks=num_blocks,
in_channels=_in_channels,
out_channels=self.stage_widths[i],
expansion=1,
stride=stride,
dilation=dilation,
style=self.style,
avg_down=self.avg_down,
with_cp=self.with_cp,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
base_channels=self.stage_widths[i],
groups=stage_groups,
width_per_group=group_width)
_in_channels = self.stage_widths[i]
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self._freeze_stages()
self.feat_dim = stage_widths[-1]
def _make_stem_layer(self, in_channels, base_channels):
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
base_channels,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, base_channels, postfix=1)
self.add_module(self.norm1_name, norm1)
self.relu = nn.ReLU(inplace=True)
def generate_regnet(self,
initial_width,
width_slope,
width_parameter,
depth,
divisor=8):
"""Generates per block width from RegNet parameters.
Args:
initial_width ([int]): Initial width of the backbone
width_slope ([float]): Slope of the quantized linear function
width_parameter ([int]): Parameter used to quantize the width.
depth ([int]): Depth of the backbone.
divisor (int, optional): The divisor of channels. Defaults to 8.
Returns:
list, int: return a list of widths of each stage and the number of
stages
"""
assert width_slope >= 0
assert initial_width > 0
assert width_parameter > 1
assert initial_width % divisor == 0
widths_cont = np.arange(depth) * width_slope + initial_width
ks = np.round(
np.log(widths_cont / initial_width) / np.log(width_parameter))
widths = initial_width * np.power(width_parameter, ks)
widths = np.round(np.divide(widths, divisor)) * divisor
num_stages = len(np.unique(widths))
widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
return widths, num_stages
@staticmethod
def quantize_float(number, divisor):
"""Converts a float to closest non-zero int divisible by divior.
Args:
number (int): Original number to be quantized.
divisor (int): Divisor used to quantize the number.
Returns:
int: quantized number that is divisible by devisor.
"""
return int(round(number / divisor) * divisor)
def adjust_width_group(self, widths, bottleneck_ratio, groups):
"""Adjusts the compatibility of widths and groups.
Args:
widths (list[int]): Width of each stage.
bottleneck_ratio (float): Bottleneck ratio.
groups (int): number of groups in each stage
Returns:
tuple(list): The adjusted widths and groups of each stage.
"""
bottleneck_width = [
int(w * b) for w, b in zip(widths, bottleneck_ratio)
]
groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)]
bottleneck_width = [
self.quantize_float(w_bot, g)
for w_bot, g in zip(bottleneck_width, groups)
]
widths = [
int(w_bot / b)
for w_bot, b in zip(bottleneck_width, bottleneck_ratio)
]
return widths, groups
def get_stages_from_blocks(self, widths):
"""Gets widths/stage_blocks of network at each stage
Args:
widths (list[int]): Width in each stage.
Returns:
tuple(list): width and depth of each stage
"""
width_diff = [
width != width_prev
for width, width_prev in zip(widths + [0], [0] + widths)
]
stage_widths = [
width for width, diff in zip(widths, width_diff[:-1]) if diff
]
stage_blocks = np.diff([
depth for depth, diff in zip(range(len(width_diff)), width_diff)
if diff
]).tolist()
return stage_widths, stage_blocks
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)

View File

@ -0,0 +1,91 @@
import pytest
import torch
from mmcls.models.backbones import RegNet
regnet_test_data = [
('regnetx_400mf',
dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22,
bot_mul=1.0), [32, 64, 160, 384]),
('regnetx_800mf',
dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16,
bot_mul=1.0), [64, 128, 288, 672]),
('regnetx_1.6gf',
dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18,
bot_mul=1.0), [72, 168, 408, 912]),
('regnetx_3.2gf',
dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25,
bot_mul=1.0), [96, 192, 432, 1008]),
('regnetx_4.0gf',
dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23,
bot_mul=1.0), [80, 240, 560, 1360]),
('regnetx_6.4gf',
dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17,
bot_mul=1.0), [168, 392, 784, 1624]),
('regnetx_8.0gf',
dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23,
bot_mul=1.0), [80, 240, 720, 1920]),
('regnetx_12gf',
dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19,
bot_mul=1.0), [224, 448, 896, 2240]),
]
@pytest.mark.parametrize('arch_name,arch,out_channels', regnet_test_data)
def test_regnet_backbone(arch_name, arch, out_channels):
with pytest.raises(AssertionError):
# ResNeXt depth should be in [50, 101, 152]
RegNet(arch_name + '233')
# output the last feature map
model = RegNet(arch_name)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert isinstance(feat, torch.Tensor)
assert feat.shape == (1, out_channels[-1], 7, 7)
# output feature map of all stages
model = RegNet(arch_name, out_indices=(0, 1, 2, 3))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == (1, out_channels[0], 56, 56)
assert feat[1].shape == (1, out_channels[1], 28, 28)
assert feat[2].shape == (1, out_channels[2], 14, 14)
assert feat[3].shape == (1, out_channels[3], 7, 7)
@pytest.mark.parametrize('arch_name,arch,out_channels', regnet_test_data)
def test_custom_arch(arch_name, arch, out_channels):
# output the last feature map
model = RegNet(arch)
model.init_weights()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert isinstance(feat, torch.Tensor)
assert feat.shape == (1, out_channels[-1], 7, 7)
# output feature map of all stages
model = RegNet(arch, out_indices=(0, 1, 2, 3))
model.init_weights()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == (1, out_channels[0], 56, 56)
assert feat[1].shape == (1, out_channels[1], 28, 28)
assert feat[2].shape == (1, out_channels[2], 14, 14)
assert feat[3].shape == (1, out_channels[3], 7, 7)
def test_exception():
# arch must be a str or dict
with pytest.raises(TypeError):
_ = RegNet(50)