mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
Merge branch 'regnet' into 'master'
Add RegNet See merge request open-mmlab/mmclassification!24
This commit is contained in:
commit
a48ffaaa0a
@ -42,4 +42,4 @@ test:pytorch1.3-cuda10:
|
||||
|
||||
test:pat0.6.0dev-cuda9:
|
||||
image: $PARROTS_IMAGE
|
||||
<<: *test_template_def
|
||||
<<: *test_template_def
|
||||
|
@ -1,5 +1,6 @@
|
||||
from .mobilenet_v2 import MobileNetV2
|
||||
from .mobilenet_v3 import MobileNetv3
|
||||
from .regnet import RegNet
|
||||
from .resnet import ResNet, ResNetV1d
|
||||
from .resnext import ResNeXt
|
||||
from .seresnet import SEResNet
|
||||
@ -8,6 +9,6 @@ from .shufflenet_v1 import ShuffleNetV1
|
||||
from .shufflenet_v2 import ShuffleNetV2
|
||||
|
||||
__all__ = [
|
||||
'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNetV1d', 'SEResNet', 'SEResNeXt',
|
||||
'ShuffleNetV1', 'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3'
|
||||
'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNetV1d', 'SEResNet',
|
||||
'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3'
|
||||
]
|
||||
|
312
mmcls/models/backbones/regnet.py
Normal file
312
mmcls/models/backbones/regnet.py
Normal file
@ -0,0 +1,312 @@
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from .resnet import ResNet
|
||||
from .resnext import Bottleneck
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class RegNet(ResNet):
|
||||
"""RegNet backbone.
|
||||
|
||||
More details can be found in `paper <https://arxiv.org/abs/2003.13678>`_ .
|
||||
|
||||
Args:
|
||||
arch (dict): The parameter of RegNets.
|
||||
- w0 (int): initial width
|
||||
- wa (float): slope of width
|
||||
- wm (float): quantization parameter to quantize the width
|
||||
- depth (int): depth of the backbone
|
||||
- group_w (int): width of group
|
||||
- bot_mul (float): bottleneck ratio, i.e. expansion of bottlneck.
|
||||
strides (Sequence[int]): Strides of the first block of each stage.
|
||||
base_channels (int): Base channels after stem layer.
|
||||
in_channels (int): Number of input image channels. Default: 3.
|
||||
dilations (Sequence[int]): Dilation of each stage.
|
||||
out_indices (Sequence[int]): Output from which stages.
|
||||
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
|
||||
layer is the 3x3 conv layer, otherwise the stride-two layer is
|
||||
the first 1x1 conv layer. Default: "pytorch".
|
||||
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
|
||||
not freezing any parameters. Default: -1.
|
||||
norm_cfg (dict): dictionary to construct and config norm layer.
|
||||
Default: dict(type='BN', requires_grad=True).
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
zero_init_residual (bool): whether to use zero init for last norm layer
|
||||
in resblocks to let them behave as identity. Default: True.
|
||||
|
||||
Example:
|
||||
>>> from mmdet.models import RegNet
|
||||
>>> import torch
|
||||
>>> self = RegNet(
|
||||
arch=dict(
|
||||
w0=88,
|
||||
wa=26.31,
|
||||
wm=2.25,
|
||||
group_w=48,
|
||||
depth=25,
|
||||
bot_mul=1.0))
|
||||
>>> self.eval()
|
||||
>>> inputs = torch.rand(1, 3, 32, 32)
|
||||
>>> level_outputs = self.forward(inputs)
|
||||
>>> for level_out in level_outputs:
|
||||
... print(tuple(level_out.shape))
|
||||
(1, 96, 8, 8)
|
||||
(1, 192, 4, 4)
|
||||
(1, 432, 2, 2)
|
||||
(1, 1008, 1, 1)
|
||||
"""
|
||||
arch_settings = {
|
||||
'regnetx_400mf':
|
||||
dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0),
|
||||
'regnetx_800mf':
|
||||
dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0),
|
||||
'regnetx_1.6gf':
|
||||
dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0),
|
||||
'regnetx_3.2gf':
|
||||
dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0),
|
||||
'regnetx_4.0gf':
|
||||
dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0),
|
||||
'regnetx_6.4gf':
|
||||
dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0),
|
||||
'regnetx_8.0gf':
|
||||
dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0),
|
||||
'regnetx_12gf':
|
||||
dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0),
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
arch,
|
||||
in_channels=3,
|
||||
stem_channels=32,
|
||||
base_channels=32,
|
||||
strides=(2, 2, 2, 2),
|
||||
dilations=(1, 1, 1, 1),
|
||||
out_indices=(3, ),
|
||||
style='pytorch',
|
||||
deep_stem=False,
|
||||
avg_down=False,
|
||||
frozen_stages=-1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
zero_init_residual=True):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
# Generate RegNet parameters first
|
||||
if isinstance(arch, str):
|
||||
assert arch in self.arch_settings, \
|
||||
f'"arch": "{arch}" is not one of the' \
|
||||
' arch_settings'
|
||||
arch = self.arch_settings[arch]
|
||||
elif not isinstance(arch, dict):
|
||||
raise TypeError('Expect "arch" to be either a string '
|
||||
f'or a dict, got {type(arch)}')
|
||||
|
||||
widths, num_stages = self.generate_regnet(
|
||||
arch['w0'],
|
||||
arch['wa'],
|
||||
arch['wm'],
|
||||
arch['depth'],
|
||||
)
|
||||
# Convert to per stage format
|
||||
stage_widths, stage_blocks = self.get_stages_from_blocks(widths)
|
||||
# Generate group widths and bot muls
|
||||
group_widths = [arch['group_w'] for _ in range(num_stages)]
|
||||
self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)]
|
||||
# Adjust the compatibility of stage_widths and group_widths
|
||||
stage_widths, group_widths = self.adjust_width_group(
|
||||
stage_widths, self.bottleneck_ratio, group_widths)
|
||||
|
||||
# Group params by stage
|
||||
self.stage_widths = stage_widths
|
||||
self.group_widths = group_widths
|
||||
self.depth = sum(stage_blocks)
|
||||
self.stem_channels = stem_channels
|
||||
self.base_channels = base_channels
|
||||
self.num_stages = num_stages
|
||||
assert num_stages >= 1 and num_stages <= 4
|
||||
self.strides = strides
|
||||
self.dilations = dilations
|
||||
assert len(strides) == len(dilations) == num_stages
|
||||
self.out_indices = out_indices
|
||||
assert max(out_indices) < num_stages
|
||||
self.style = style
|
||||
self.deep_stem = deep_stem
|
||||
if self.deep_stem:
|
||||
raise NotImplementedError(
|
||||
'deep_stem has not been implemented for RegNet')
|
||||
self.avg_down = avg_down
|
||||
self.frozen_stages = frozen_stages
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.with_cp = with_cp
|
||||
self.norm_eval = norm_eval
|
||||
self.zero_init_residual = zero_init_residual
|
||||
self.stage_blocks = stage_blocks[:num_stages]
|
||||
|
||||
self._make_stem_layer(in_channels, stem_channels)
|
||||
|
||||
_in_channels = stem_channels
|
||||
self.res_layers = []
|
||||
for i, num_blocks in enumerate(self.stage_blocks):
|
||||
stride = self.strides[i]
|
||||
dilation = self.dilations[i]
|
||||
group_width = self.group_widths[i]
|
||||
width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i]))
|
||||
stage_groups = width // group_width
|
||||
|
||||
res_layer = self.make_res_layer(
|
||||
block=Bottleneck,
|
||||
num_blocks=num_blocks,
|
||||
in_channels=_in_channels,
|
||||
out_channels=self.stage_widths[i],
|
||||
expansion=1,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
style=self.style,
|
||||
avg_down=self.avg_down,
|
||||
with_cp=self.with_cp,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
base_channels=self.stage_widths[i],
|
||||
groups=stage_groups,
|
||||
width_per_group=group_width)
|
||||
_in_channels = self.stage_widths[i]
|
||||
layer_name = f'layer{i + 1}'
|
||||
self.add_module(layer_name, res_layer)
|
||||
self.res_layers.append(layer_name)
|
||||
|
||||
self._freeze_stages()
|
||||
|
||||
self.feat_dim = stage_widths[-1]
|
||||
|
||||
def _make_stem_layer(self, in_channels, base_channels):
|
||||
self.conv1 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels,
|
||||
base_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False)
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
self.norm_cfg, base_channels, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def generate_regnet(self,
|
||||
initial_width,
|
||||
width_slope,
|
||||
width_parameter,
|
||||
depth,
|
||||
divisor=8):
|
||||
"""Generates per block width from RegNet parameters.
|
||||
|
||||
Args:
|
||||
initial_width ([int]): Initial width of the backbone
|
||||
width_slope ([float]): Slope of the quantized linear function
|
||||
width_parameter ([int]): Parameter used to quantize the width.
|
||||
depth ([int]): Depth of the backbone.
|
||||
divisor (int, optional): The divisor of channels. Defaults to 8.
|
||||
|
||||
Returns:
|
||||
list, int: return a list of widths of each stage and the number of
|
||||
stages
|
||||
"""
|
||||
assert width_slope >= 0
|
||||
assert initial_width > 0
|
||||
assert width_parameter > 1
|
||||
assert initial_width % divisor == 0
|
||||
widths_cont = np.arange(depth) * width_slope + initial_width
|
||||
ks = np.round(
|
||||
np.log(widths_cont / initial_width) / np.log(width_parameter))
|
||||
widths = initial_width * np.power(width_parameter, ks)
|
||||
widths = np.round(np.divide(widths, divisor)) * divisor
|
||||
num_stages = len(np.unique(widths))
|
||||
widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
|
||||
return widths, num_stages
|
||||
|
||||
@staticmethod
|
||||
def quantize_float(number, divisor):
|
||||
"""Converts a float to closest non-zero int divisible by divior.
|
||||
|
||||
Args:
|
||||
number (int): Original number to be quantized.
|
||||
divisor (int): Divisor used to quantize the number.
|
||||
|
||||
Returns:
|
||||
int: quantized number that is divisible by devisor.
|
||||
"""
|
||||
return int(round(number / divisor) * divisor)
|
||||
|
||||
def adjust_width_group(self, widths, bottleneck_ratio, groups):
|
||||
"""Adjusts the compatibility of widths and groups.
|
||||
|
||||
Args:
|
||||
widths (list[int]): Width of each stage.
|
||||
bottleneck_ratio (float): Bottleneck ratio.
|
||||
groups (int): number of groups in each stage
|
||||
|
||||
Returns:
|
||||
tuple(list): The adjusted widths and groups of each stage.
|
||||
"""
|
||||
bottleneck_width = [
|
||||
int(w * b) for w, b in zip(widths, bottleneck_ratio)
|
||||
]
|
||||
groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)]
|
||||
bottleneck_width = [
|
||||
self.quantize_float(w_bot, g)
|
||||
for w_bot, g in zip(bottleneck_width, groups)
|
||||
]
|
||||
widths = [
|
||||
int(w_bot / b)
|
||||
for w_bot, b in zip(bottleneck_width, bottleneck_ratio)
|
||||
]
|
||||
return widths, groups
|
||||
|
||||
def get_stages_from_blocks(self, widths):
|
||||
"""Gets widths/stage_blocks of network at each stage
|
||||
|
||||
Args:
|
||||
widths (list[int]): Width in each stage.
|
||||
|
||||
Returns:
|
||||
tuple(list): width and depth of each stage
|
||||
"""
|
||||
width_diff = [
|
||||
width != width_prev
|
||||
for width, width_prev in zip(widths + [0], [0] + widths)
|
||||
]
|
||||
stage_widths = [
|
||||
width for width, diff in zip(widths, width_diff[:-1]) if diff
|
||||
]
|
||||
stage_blocks = np.diff([
|
||||
depth for depth, diff in zip(range(len(width_diff)), width_diff)
|
||||
if diff
|
||||
]).tolist()
|
||||
return stage_widths, stage_blocks
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
outs = []
|
||||
for i, layer_name in enumerate(self.res_layers):
|
||||
res_layer = getattr(self, layer_name)
|
||||
x = res_layer(x)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
|
||||
if len(outs) == 1:
|
||||
return outs[0]
|
||||
else:
|
||||
return tuple(outs)
|
91
tests/test_backbones/test_regnet.py
Normal file
91
tests/test_backbones/test_regnet.py
Normal file
@ -0,0 +1,91 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.models.backbones import RegNet
|
||||
|
||||
regnet_test_data = [
|
||||
('regnetx_400mf',
|
||||
dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22,
|
||||
bot_mul=1.0), [32, 64, 160, 384]),
|
||||
('regnetx_800mf',
|
||||
dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16,
|
||||
bot_mul=1.0), [64, 128, 288, 672]),
|
||||
('regnetx_1.6gf',
|
||||
dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18,
|
||||
bot_mul=1.0), [72, 168, 408, 912]),
|
||||
('regnetx_3.2gf',
|
||||
dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25,
|
||||
bot_mul=1.0), [96, 192, 432, 1008]),
|
||||
('regnetx_4.0gf',
|
||||
dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23,
|
||||
bot_mul=1.0), [80, 240, 560, 1360]),
|
||||
('regnetx_6.4gf',
|
||||
dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17,
|
||||
bot_mul=1.0), [168, 392, 784, 1624]),
|
||||
('regnetx_8.0gf',
|
||||
dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23,
|
||||
bot_mul=1.0), [80, 240, 720, 1920]),
|
||||
('regnetx_12gf',
|
||||
dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19,
|
||||
bot_mul=1.0), [224, 448, 896, 2240]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('arch_name,arch,out_channels', regnet_test_data)
|
||||
def test_regnet_backbone(arch_name, arch, out_channels):
|
||||
with pytest.raises(AssertionError):
|
||||
# ResNeXt depth should be in [50, 101, 152]
|
||||
RegNet(arch_name + '233')
|
||||
|
||||
# output the last feature map
|
||||
model = RegNet(arch_name)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert isinstance(feat, torch.Tensor)
|
||||
assert feat.shape == (1, out_channels[-1], 7, 7)
|
||||
|
||||
# output feature map of all stages
|
||||
model = RegNet(arch_name, out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == (1, out_channels[0], 56, 56)
|
||||
assert feat[1].shape == (1, out_channels[1], 28, 28)
|
||||
assert feat[2].shape == (1, out_channels[2], 14, 14)
|
||||
assert feat[3].shape == (1, out_channels[3], 7, 7)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('arch_name,arch,out_channels', regnet_test_data)
|
||||
def test_custom_arch(arch_name, arch, out_channels):
|
||||
# output the last feature map
|
||||
model = RegNet(arch)
|
||||
model.init_weights()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert isinstance(feat, torch.Tensor)
|
||||
assert feat.shape == (1, out_channels[-1], 7, 7)
|
||||
|
||||
# output feature map of all stages
|
||||
model = RegNet(arch, out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == (1, out_channels[0], 56, 56)
|
||||
assert feat[1].shape == (1, out_channels[1], 28, 28)
|
||||
assert feat[2].shape == (1, out_channels[2], 14, 14)
|
||||
assert feat[3].shape == (1, out_channels[3], 7, 7)
|
||||
|
||||
|
||||
def test_exception():
|
||||
# arch must be a str or dict
|
||||
with pytest.raises(TypeError):
|
||||
_ = RegNet(50)
|
Loading…
x
Reference in New Issue
Block a user