[Feature] Add Res2Net backbone and converted weights. (#465)

* Add Res2Net from mmdet, and change it to mmcls style.

* Align structure with official repo

* Support `deep_stem` and `avg_down` option

* Add Res2Net configs

* Add metafile & README and update model zoo

* Add unit tests

* Improve docstring.

* Improve according to comments.
Ma Zerun 2021-10-20 16:34:22 +08:00, committed by GitHub
parent f68f17e9bb
commit 77a3834531
17 changed files with 605 additions and 5 deletions

configs/_base_/models/res2net101-w26-s4.py Normal file
@@ -0,0 +1,18 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='Res2Net',
depth=101,
scales=4,
base_width=26,
deep_stem=False,
avg_down=False,
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))
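The six `_base_` model configs added below are identical apart from the backbone's `depth`, `scales`, and `base_width`. As a quick sanity check, a config like this can be exercised directly through the mmcls registry; a minimal sketch, assuming an mmcls v0.x installation (the shapes mirror the `Res2Net` class docstring example):

```python
# Minimal sketch (assumes mmcls v0.x): build the Res2Net backbone defined
# above and inspect the per-stage feature shapes.
import torch
from mmcls.models import build_backbone

backbone = build_backbone(
    dict(type='Res2Net', depth=101, scales=4, base_width=26,
         deep_stem=False, avg_down=False, out_indices=(0, 1, 2, 3)))
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224))
for feat in feats:
    print(tuple(feat.shape))
# (1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)
# The last stage's 2048 channels match the head's `in_channels`.
```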

configs/_base_/models/res2net50-w14-s8.py Normal file
@@ -0,0 +1,18 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='Res2Net',
depth=50,
scales=8,
base_width=14,
deep_stem=False,
avg_down=False,
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

configs/_base_/models/res2net50-w26-s4.py Normal file
@@ -0,0 +1,18 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='Res2Net',
depth=50,
scales=4,
base_width=26,
deep_stem=False,
avg_down=False,
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

configs/_base_/models/res2net50-w26-s6.py Normal file
@@ -0,0 +1,18 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='Res2Net',
depth=50,
scales=6,
base_width=26,
deep_stem=False,
avg_down=False,
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

configs/_base_/models/res2net50-w26-s8.py Normal file
@@ -0,0 +1,18 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='Res2Net',
depth=50,
scales=8,
base_width=26,
deep_stem=False,
avg_down=False,
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

configs/_base_/models/res2net50-w48-s2.py Normal file
@@ -0,0 +1,18 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='Res2Net',
depth=50,
scales=2,
base_width=48,
deep_stem=False,
avg_down=False,
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

configs/res2net/README.md Normal file
@@ -0,0 +1,30 @@
# Res2Net: A New Multi-scale Backbone Architecture
<!-- {Res2Net} -->
## Introduction
<!-- [ALGORITHM] -->
```latex
@article{gao2019res2net,
title={Res2Net: A New Multi-scale Backbone Architecture},
author={Gao, Shang-Hua and Cheng, Ming-Ming and Zhao, Kai and Zhang, Xin-Yu and Yang, Ming-Hsuan and Torr, Philip},
journal={IEEE TPAMI},
year={2021},
doi={10.1109/TPAMI.2019.2938758},
}
```
## Pretrained models
The pre-trained models are converted from the [official repo](https://github.com/Res2Net/Res2Net-PretrainedModels).
### ImageNet-1k
| Model | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
|:---------------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:--------:|
| Res2Net-50-14w-8s\* | 224x224 | 25.06 | 4.22 | 78.14 | 93.85 | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth)|
| Res2Net-50-26w-8s\* | 224x224 | 48.40 | 8.39 | 79.20 | 94.36 | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth)|
| Res2Net-101-26w-4s\* | 224x224 | 45.21 | 8.12 | 79.19 | 94.44 | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth)|
*Models with \* are converted from other repos.*
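To try one of the converted checkpoints, the usual mmcls inference API should suffice; a minimal sketch, assuming an mmcls v0.x installation (the demo image path is an example):

```python
# Minimal sketch: single-image inference with a converted Res2Net checkpoint.
from mmcls.apis import inference_model, init_model

config = 'configs/res2net/res2net101-w26-s4_8xb32_in1k.py'
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth'

model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo/demo.JPEG')  # any test image works
print(result['pred_class'], result['pred_score'])
```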

configs/res2net/metafile.yml Normal file
@@ -0,0 +1,67 @@
Collections:
- Name: Res2Net
Metadata:
Training Data: ImageNet-1k
Training Techniques:
- SGD with Momentum
- Weight Decay
Architecture:
- Batch Normalization
- Convolution
- Global Average Pooling
- ReLU
- Res2Net Block
Paper:
Title: 'Res2Net: A New Multi-scale Backbone Architecture'
URL: https://arxiv.org/pdf/1904.01169.pdf
README: configs/res2net/README.md
Models:
- Name: res2net50-w14-s8_3rdparty_8xb32_in1k
Metadata:
FLOPs: 4220000000
Parameters: 25060000
In Collection: Res2Net
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.14
Top 5 Accuracy: 93.85
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth
Converted From:
Weights: https://1drv.ms/u/s!AkxDDnOtroRPdOTqhF8ne_aakDI?e=EVb8Ri
Code: https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net.py#L221
Config: configs/res2net/res2net50-w14-s8_8xb32_in1k.py
- Name: res2net50-w26-s8_3rdparty_8xb32_in1k
Metadata:
FLOPs: 8390000000
Parameters: 48400000
In Collection: Res2Net
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 79.20
Top 5 Accuracy: 94.36
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth
Converted From:
Weights: https://1drv.ms/u/s!AkxDDnOtroRPdTrAd_Afzc26Z7Q?e=slYqsR
Code: https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net.py#L201
Config: configs/res2net/res2net50-w26-s8_8xb32_in1k.py
- Name: res2net101-w26-s4_3rdparty_8xb32_in1k
Metadata:
FLOPs: 8120000000
Parameters: 45210000
In Collection: Res2Net
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 79.19
Top 5 Accuracy: 94.44
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth
Converted From:
Weights: https://1drv.ms/u/s!AkxDDnOtroRPcJRgTLkahL0cFYw?e=nwbnic
Code: https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net.py#L181
Config: configs/res2net/res2net101-w26-s4_8xb32_in1k.py

configs/res2net/res2net101-w26-s4_8xb32_in1k.py Normal file
@@ -0,0 +1,5 @@
_base_ = [
'../_base_/models/res2net101-w26-s4.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
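The `_base_` fragments above are merged into a single runtime config. A minimal sketch of inspecting the merged result, assuming mmcv is installed and the script runs from the repository root:

```python
# Minimal sketch: resolve the _base_ composition and check the model dict.
from mmcv import Config

cfg = Config.fromfile('configs/res2net/res2net101-w26-s4_8xb32_in1k.py')
print(cfg.model.backbone)
# -> {'type': 'Res2Net', 'depth': 101, 'scales': 4, 'base_width': 26, ...}
```

Training should then follow the standard entry point, e.g. `python tools/train.py configs/res2net/res2net101-w26-s4_8xb32_in1k.py`.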

configs/res2net/res2net50-w14-s8_8xb32_in1k.py Normal file
@@ -0,0 +1,5 @@
_base_ = [
'../_base_/models/res2net50-w14-s8.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]

configs/res2net/res2net50-w26-s8_8xb32_in1k.py Normal file
@@ -0,0 +1,5 @@
_base_ = [
'../_base_/models/res2net50-w26-s8.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]

docs/model_zoo.md
@@ -32,6 +32,9 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| ResNet-50 | 25.56 | 4.12 | 76.55 | 93.15 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_b32x8_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.log.json) |
| ResNet-101 | 44.55 | 7.85 | 78.18 | 94.03 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_b32x8_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.log.json) |
| ResNet-152 | 60.19 | 11.58 | 78.63 | 94.16 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_b32x8_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.log.json) |
| Res2Net-50-14w-8s\* | 25.06 | 4.22 | 78.14 | 93.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/res2net/res2net50-w14-s8_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth) &#124; [log]()|
| Res2Net-50-26w-8s\* | 48.40 | 8.39 | 79.20 | 94.36 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/res2net/res2net50-w26-s8_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth) &#124; [log]()|
| Res2Net-101-26w-4s\* | 45.21 | 8.12 | 79.19 | 94.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/res2net/res2net101-w26-s4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth) &#124; [log]()|
| ResNeSt-50\* | 27.48 | 5.41 | 81.13 | 95.59 | | [model](https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth) &#124; [log]() |
| ResNeSt-101\* | 48.28 | 10.27 | 82.32 | 96.24 | | [model](https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth) &#124; [log]() |
| ResNeSt-200\* | 70.2 | 17.53 | 82.41 | 96.22 | | [model](https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth) &#124; [log]() |

mmcls/models/backbones/__init__.py
@@ -5,6 +5,7 @@ from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .regnet import RegNet
from .repvgg import RepVGG
from .res2net import Res2Net
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1d
from .resnet_cifar import ResNet_CIFAR
@@ -23,5 +24,5 @@ __all__ = [
    'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
    'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
    'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
    'SwinTransformer', 'TNT', 'TIMMBackbone', 'Res2Net', 'RepVGG'
]

mmcls/models/backbones/res2net.py Normal file
@@ -0,0 +1,306 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import ModuleList, Sequential
from ..builder import BACKBONES
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
class Bottle2neck(_Bottleneck):
expansion = 4
def __init__(self,
in_channels,
out_channels,
scales=4,
base_width=26,
base_channels=64,
stage_type='normal',
**kwargs):
"""Bottle2neck block for Res2Net."""
super(Bottle2neck, self).__init__(in_channels, out_channels, **kwargs)
assert scales > 1, 'Res2Net degenerates to ResNet when scales = 1.'
mid_channels = out_channels // self.expansion
width = int(math.floor(mid_channels * (base_width / base_channels)))
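        # Width of each split group; e.g. mid_channels=64, base_width=26,
        # base_channels=64 gives width=26, so conv1 expands the input to
        # width * scales = 104 channels for the 26w-4s setting.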
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, width * scales, postfix=1)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.out_channels, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.in_channels,
width * scales,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
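        # In the first block of a stage ('stage' type) the 3x3 convs may be
        # strided, so the untouched last split is downsampled by average
        # pooling instead.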
if stage_type == 'stage':
self.pool = nn.AvgPool2d(
kernel_size=3, stride=self.conv2_stride, padding=1)
self.convs = ModuleList()
self.bns = ModuleList()
for i in range(scales - 1):
self.convs.append(
build_conv_layer(
self.conv_cfg,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
bias=False))
self.bns.append(
build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1])
self.conv3 = build_conv_layer(
self.conv_cfg,
width * scales,
self.out_channels,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
self.stage_type = stage_type
self.scales = scales
self.width = width
delattr(self, 'conv2')
delattr(self, self.norm2_name)
def forward(self, x):
"""Forward function."""
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
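            # Split the conv1 output into `scales` groups of `width` channels.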
spx = torch.split(out, self.width, 1)
sp = self.convs[0](spx[0].contiguous())
sp = self.relu(self.bns[0](sp))
out = sp
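            # Hierarchical connection: in 'normal' blocks each split is added
            # to the previous split's output before its own 3x3 conv,
            # enlarging the receptive field scale by scale.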
for i in range(1, self.scales - 1):
if self.stage_type == 'stage':
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp.contiguous())
sp = self.relu(self.bns[i](sp))
out = torch.cat((out, sp), 1)
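            # The last split bypasses the 3x3 convs: it is passed through
            # unchanged in 'normal' blocks, or average-pooled in 'stage'
            # blocks to match the downsampled spatial size.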
if self.stage_type == 'normal' and self.scales != 1:
out = torch.cat((out, spx[self.scales - 1]), 1)
elif self.stage_type == 'stage' and self.scales != 1:
out = torch.cat((out, self.pool(spx[self.scales - 1])), 1)
out = self.conv3(out)
out = self.norm3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
class Res2Layer(Sequential):
"""Res2Layer to build Res2Net style backbone.
Args:
block (nn.Module): Block used to build the layer.
in_channels (int): Input channels of the layer.
out_channels (int): Output channels of the layer.
num_blocks (int): number of blocks.
stride (int): stride of the first block. Default: 1
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottle2neck. Defaults to True.
conv_cfg (dict): dictionary to construct and config conv layer.
Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
scales (int): Scales used in Res2Net. Default: 4
base_width (int): Basic width of each scale. Default: 26
"""
def __init__(self,
block,
in_channels,
out_channels,
num_blocks,
stride=1,
avg_down=True,
conv_cfg=None,
norm_cfg=dict(type='BN'),
scales=4,
base_width=26,
**kwargs):
self.block = block
downsample = None
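        # Build the shortcut when the shape changes: AvgPool + 1x1 conv
        # (avg_down=True) keeps all activations, while a strided 1x1 conv
        # (avg_down=False) matches the plain ResNet downsample.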
if stride != 1 or in_channels != out_channels:
if avg_down:
downsample = nn.Sequential(
nn.AvgPool2d(
kernel_size=stride,
stride=stride,
ceil_mode=True,
count_include_pad=False),
build_conv_layer(
conv_cfg,
in_channels,
out_channels,
kernel_size=1,
stride=1,
bias=False),
build_norm_layer(norm_cfg, out_channels)[1],
)
else:
downsample = nn.Sequential(
build_conv_layer(
conv_cfg,
in_channels,
out_channels,
kernel_size=1,
stride=stride,
bias=False),
build_norm_layer(norm_cfg, out_channels)[1],
)
layers = []
layers.append(
block(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
downsample=downsample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
scales=scales,
base_width=base_width,
stage_type='stage',
**kwargs))
in_channels = out_channels
for _ in range(1, num_blocks):
layers.append(
block(
in_channels=in_channels,
out_channels=out_channels,
stride=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
scales=scales,
base_width=base_width,
**kwargs))
super(Res2Layer, self).__init__(*layers)
@BACKBONES.register_module()
class Res2Net(ResNet):
"""Res2Net backbone.
A PyTorch implementation of: `Res2Net: A New Multi-scale Backbone
Architecture <https://arxiv.org/pdf/1904.01169.pdf>`_
Args:
depth (int): Depth of Res2Net, choose from {50, 101, 152}.
scales (int): Scales used in Res2Net. Defaults to 4.
base_width (int): Basic width of each scale. Defaults to 26.
in_channels (int): Number of input image channels. Defaults to 3.
num_stages (int): Number of Res2Net stages. Defaults to 4.
strides (Sequence[int]): Strides of the first block of each stage.
Defaults to ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Defaults to ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages.
Defaults to ``(3, )``.
style (str): "pytorch" or "caffe". If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer. Defaults to "pytorch".
deep_stem (bool): Replace the 7x7 conv in the input stem with three 3x3 convs.
Defaults to True.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottle2neck. Defaults to True.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to ``dict(type='BN', requires_grad=True)``.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Defaults to True.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
Example:
>>> from mmcls.models import Res2Net
>>> import torch
>>> model = Res2Net(depth=50,
... scales=4,
... base_width=26,
... out_indices=(0, 1, 2, 3))
>>> model.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = model.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 256, 8, 8)
(1, 512, 4, 4)
(1, 1024, 2, 2)
(1, 2048, 1, 1)
"""
arch_settings = {
50: (Bottle2neck, (3, 4, 6, 3)),
101: (Bottle2neck, (3, 4, 23, 3)),
152: (Bottle2neck, (3, 8, 36, 3))
}
def __init__(self,
scales=4,
base_width=26,
style='pytorch',
deep_stem=True,
avg_down=True,
init_cfg=None,
**kwargs):
self.scales = scales
self.base_width = base_width
super(Res2Net, self).__init__(
style=style,
deep_stem=deep_stem,
avg_down=avg_down,
init_cfg=init_cfg,
**kwargs)
def make_res_layer(self, **kwargs):
return Res2Layer(
scales=self.scales,
base_width=self.base_width,
base_channels=self.base_channels,
**kwargs)

mmcls/models/backbones/resnet.py
@@ -396,10 +396,8 @@ class ResNet(BaseBackbone):
Default: ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Default: ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages.
Default: ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.

model-index.yml
@@ -2,6 +2,7 @@ Import:
- configs/fp16/metafile.yml
- configs/mobilenet_v2/metafile.yml
- configs/resnet/metafile.yml
- configs/res2net/metafile.yml
- configs/resnext/metafile.yml
- configs/seresnet/metafile.yml
- configs/shufflenet_v1/metafile.yml

tests/test_models/test_backbones/test_res2net.py Normal file
@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import Res2Net
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_res2net():
# Only support depth 50, 101 and 152
with pytest.raises(KeyError):
Res2Net(depth=18)
# test the feature map size when depth is 50
# and deep_stem=True, avg_down=True
model = Res2Net(
depth=50, out_indices=(0, 1, 2, 3), deep_stem=True, avg_down=True)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model.stem(imgs)
assert feat.shape == (1, 64, 112, 112)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == (1, 256, 56, 56)
assert feat[1].shape == (1, 512, 28, 28)
assert feat[2].shape == (1, 1024, 14, 14)
assert feat[3].shape == (1, 2048, 7, 7)
# test the feature map size when depth is 101
# and deep_stem=False, avg_down=False
model = Res2Net(
depth=101, out_indices=(0, 1, 2, 3), deep_stem=False, avg_down=False)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model.conv1(imgs)
assert feat.shape == (1, 64, 112, 112)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == (1, 256, 56, 56)
assert feat[1].shape == (1, 512, 28, 28)
assert feat[2].shape == (1, 1024, 14, 14)
assert feat[3].shape == (1, 2048, 7, 7)
# Test Res2Net with first stage frozen
frozen_stages = 1
model = Res2Net(depth=50, frozen_stages=frozen_stages, deep_stem=False)
model.init_weights()
model.train()
assert check_norm_state([model.norm1], False)
for param in model.conv1.parameters():
assert param.requires_grad is False
for i in range(1, frozen_stages + 1):
layer = getattr(model, f'layer{i}')
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for param in layer.parameters():
assert param.requires_grad is False