From 77a3834531d2c869fda2e83fa65e17d9c377f5eb Mon Sep 17 00:00:00 2001
From: Ma Zerun
Date: Wed, 20 Oct 2021 16:34:22 +0800
Subject: [PATCH] [Feature] Add Res2Net backbone and converted weights. (#465)

* Add Res2Net from mmdet, and change it to mmcls style.
* Align structure with official repo
* Support `deep_stem` and `avg_down` options
* Add Res2Net configs
* Add metafile & README and update model zoo
* Add unit tests
* Improve docstring.
* Improve according to comments.
---
 configs/_base_/models/res2net101-w26-s4.py    |  18 ++
 configs/_base_/models/res2net50-w14-s8.py     |  18 ++
 configs/_base_/models/res2net50-w26-s4.py     |  18 ++
 configs/_base_/models/res2net50-w26-s6.py     |  18 ++
 configs/_base_/models/res2net50-w26-s8.py     |  18 ++
 configs/_base_/models/res2net50-w48-s2.py     |  18 ++
 configs/res2net/README.md                     |  30 ++
 configs/res2net/metafile.yml                  |  67 ++++
 .../res2net/res2net101-w26-s4_8xb32_in1k.py   |   5 +
 .../res2net/res2net50-w14-s8_8xb32_in1k.py    |   5 +
 .../res2net/res2net50-w26-s8_8xb32_in1k.py    |   5 +
 docs/model_zoo.md                             |   3 +
 mmcls/models/backbones/__init__.py            |   3 +-
 mmcls/models/backbones/res2net.py             | 306 ++++++++++++++++++
 mmcls/models/backbones/resnet.py              |   6 +-
 model-index.yml                               |   1 +
 .../test_backbones/test_res2net.py            |  71 ++++
 17 files changed, 605 insertions(+), 5 deletions(-)
 create mode 100644 configs/_base_/models/res2net101-w26-s4.py
 create mode 100644 configs/_base_/models/res2net50-w14-s8.py
 create mode 100644 configs/_base_/models/res2net50-w26-s4.py
 create mode 100644 configs/_base_/models/res2net50-w26-s6.py
 create mode 100644 configs/_base_/models/res2net50-w26-s8.py
 create mode 100644 configs/_base_/models/res2net50-w48-s2.py
 create mode 100644 configs/res2net/README.md
 create mode 100644 configs/res2net/metafile.yml
 create mode 100644 configs/res2net/res2net101-w26-s4_8xb32_in1k.py
 create mode 100644 configs/res2net/res2net50-w14-s8_8xb32_in1k.py
 create mode 100644 configs/res2net/res2net50-w26-s8_8xb32_in1k.py
 create mode 100644 mmcls/models/backbones/res2net.py
 create mode 100644 tests/test_models/test_backbones/test_res2net.py

diff --git a/configs/_base_/models/res2net101-w26-s4.py b/configs/_base_/models/res2net101-w26-s4.py
new file mode 100644
index 000000000..3bf64c508
--- /dev/null
+++ b/configs/_base_/models/res2net101-w26-s4.py
@@ -0,0 +1,18 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='Res2Net',
+        depth=101,
+        scales=4,
+        base_width=26,
+        deep_stem=False,
+        avg_down=False,
+    ),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=2048,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/res2net50-w14-s8.py b/configs/_base_/models/res2net50-w14-s8.py
new file mode 100644
index 000000000..5875142c3
--- /dev/null
+++ b/configs/_base_/models/res2net50-w14-s8.py
@@ -0,0 +1,18 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='Res2Net',
+        depth=50,
+        scales=8,
+        base_width=14,
+        deep_stem=False,
+        avg_down=False,
+    ),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=2048,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/res2net50-w26-s4.py b/configs/_base_/models/res2net50-w26-s4.py
new file mode 100644
index 000000000..be8fdb585
--- /dev/null
+++ b/configs/_base_/models/res2net50-w26-s4.py
@@ -0,0 +1,18 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='Res2Net',
+        depth=50,
+        scales=4,
+        base_width=26,
+        deep_stem=False,
+        avg_down=False,
+    ),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=2048,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/res2net50-w26-s6.py b/configs/_base_/models/res2net50-w26-s6.py
new file mode 100644
index 000000000..281b136a6
--- /dev/null
+++ b/configs/_base_/models/res2net50-w26-s6.py
@@ -0,0 +1,18 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='Res2Net',
+        depth=50,
+        scales=6,
+        base_width=26,
+        deep_stem=False,
+        avg_down=False,
+    ),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=2048,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/res2net50-w26-s8.py b/configs/_base_/models/res2net50-w26-s8.py
new file mode 100644
index 000000000..b4f62f3ed
--- /dev/null
+++ b/configs/_base_/models/res2net50-w26-s8.py
@@ -0,0 +1,18 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='Res2Net',
+        depth=50,
+        scales=8,
+        base_width=26,
+        deep_stem=False,
+        avg_down=False,
+    ),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=2048,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/res2net50-w48-s2.py b/configs/_base_/models/res2net50-w48-s2.py
new file mode 100644
index 000000000..8675c91fa
--- /dev/null
+++ b/configs/_base_/models/res2net50-w48-s2.py
@@ -0,0 +1,18 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='Res2Net',
+        depth=50,
+        scales=2,
+        base_width=48,
+        deep_stem=False,
+        avg_down=False,
+    ),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=2048,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/res2net/README.md b/configs/res2net/README.md
new file mode 100644
index 000000000..befe4ba6d
--- /dev/null
+++ b/configs/res2net/README.md
@@ -0,0 +1,30 @@
+# Res2Net: A New Multi-scale Backbone Architecture
+
+## Introduction
+
+```latex
+@article{gao2019res2net,
+  title={Res2Net: A New Multi-scale Backbone Architecture},
+  author={Gao, Shang-Hua and Cheng, Ming-Ming and Zhao, Kai and Zhang, Xin-Yu and Yang, Ming-Hsuan and Torr, Philip},
+  journal={IEEE TPAMI},
+  year={2021},
+  doi={10.1109/TPAMI.2019.2938758},
+}
+```
+
+## Pretrained models
+
+The pre-trained models are converted from the [official repo](https://github.com/Res2Net/Res2Net-PretrainedModels).
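+
+A converted checkpoint can be sanity-checked through MMClassification's standard inference APIs. A minimal sketch (the checkpoint URL is the Res2Net-50-26w-8s entry from the table below; any of the listed weights works the same way):
+
+```python
+from mmcls.apis import inference_model, init_model
+
+config = 'configs/res2net/res2net50-w26-s8_8xb32_in1k.py'
+checkpoint = 'https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth'
+
+# Build the classifier from the config and load the converted weights.
+model = init_model(config, checkpoint, device='cpu')
+
+# Classify a single image; the result holds the predicted class and score.
+result = inference_model(model, 'demo/demo.JPEG')
+print(result['pred_class'], result['pred_score'])
+```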
+ +### ImageNet 1k + +| Model | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download | +|:---------------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:--------:| +| Res2Net-50-14w-8s\* | 224x224 | 25.06 | 4.22 | 78.14 | 93.85 | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth)| +| Res2Net-50-26w-8s\* | 224x224 | 48.40 | 8.39 | 79.20 | 94.36 | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth)| +| Res2Net-101-26w-4s\* | 224x224 | 45.21 | 8.12 | 79.19 | 94.44 | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth)| + +*Models with \* are converted from other repos.* diff --git a/configs/res2net/metafile.yml b/configs/res2net/metafile.yml new file mode 100644 index 000000000..dfcda7329 --- /dev/null +++ b/configs/res2net/metafile.yml @@ -0,0 +1,67 @@ +Collections: + - Name: Res2Net + Metadata: + Training Data: ImageNet-1k + Training Techniques: + - SGD with Momentum + - Weight Decay + Architecture: + - Batch Normalization + - Convolution + - Global Average Pooling + - ReLU + - Res2Net Block + Paper: + Title: 'Res2Net: A New Multi-scale Backbone Architecture' + URL: https://arxiv.org/pdf/1904.01169.pdf + README: configs/res2net/README.md + +Models: + - Name: res2net50-w14-s8_3rdparty_8xb32_in1k + Metadata: + FLOPs: 4220000000 + Parameters: 25060000 + In Collection: Res2Net + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 78.14 + Top 5 Accuracy: 93.85 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth + Converted From: + Weights: https://1drv.ms/u/s!AkxDDnOtroRPdOTqhF8ne_aakDI?e=EVb8Ri + Code: https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net.py#L221 + Config: configs/res2net/res2net50-w14-s8_8xb32_in1k.py + - Name: res2net50-w26-s8_3rdparty_8xb32_in1k + Metadata: + FLOPs: 8390000000 + Parameters: 48400000 + In Collection: Res2Net + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 79.20 + Top 5 Accuracy: 94.36 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth + Converted From: + Weights: https://1drv.ms/u/s!AkxDDnOtroRPdTrAd_Afzc26Z7Q?e=slYqsR + Code: https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net.py#L201 + Config: configs/res2net/res2net50-w26-s8_8xb32_in1k.py + - Name: res2net101-w26-s4_3rdparty_8xb32_in1k + Metadata: + FLOPs: 8120000000 + Parameters: 45210000 + In Collection: Res2Net + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 79.19 + Top 5 Accuracy: 94.44 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth + Converted From: + Weights: https://1drv.ms/u/s!AkxDDnOtroRPcJRgTLkahL0cFYw?e=nwbnic + Code: https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net.py#L181 + Config: configs/res2net/res2net101-w26-s4_8xb32_in1k.py diff --git a/configs/res2net/res2net101-w26-s4_8xb32_in1k.py b/configs/res2net/res2net101-w26-s4_8xb32_in1k.py new file mode 100644 index 000000000..7ebe9e94d --- /dev/null +++ b/configs/res2net/res2net101-w26-s4_8xb32_in1k.py @@ -0,0 +1,5 @@ +_base_ = [ + 
'../_base_/models/res2net101-w26-s4.py', + '../_base_/datasets/imagenet_bs32_pil_resize.py', + '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py' +] diff --git a/configs/res2net/res2net50-w14-s8_8xb32_in1k.py b/configs/res2net/res2net50-w14-s8_8xb32_in1k.py new file mode 100644 index 000000000..56cc02e3b --- /dev/null +++ b/configs/res2net/res2net50-w14-s8_8xb32_in1k.py @@ -0,0 +1,5 @@ +_base_ = [ + '../_base_/models/res2net50-w14-s8.py', + '../_base_/datasets/imagenet_bs32_pil_resize.py', + '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py' +] diff --git a/configs/res2net/res2net50-w26-s8_8xb32_in1k.py b/configs/res2net/res2net50-w26-s8_8xb32_in1k.py new file mode 100644 index 000000000..d7dcbeb91 --- /dev/null +++ b/configs/res2net/res2net50-w26-s8_8xb32_in1k.py @@ -0,0 +1,5 @@ +_base_ = [ + '../_base_/models/res2net50-w26-s8.py', + '../_base_/datasets/imagenet_bs32_pil_resize.py', + '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py' +] diff --git a/docs/model_zoo.md b/docs/model_zoo.md index f54ef8155..e1942fe7c 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -32,6 +32,9 @@ The ResNet family models below are trained by standard data augmentations, i.e., | ResNet-50 | 25.56 | 4.12 | 76.55 | 93.15 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_b32x8_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.log.json) | | ResNet-101 | 44.55 | 7.85 | 78.18 | 94.03 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_b32x8_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.log.json) | | ResNet-152 | 60.19 | 11.58 | 78.63 | 94.16 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_b32x8_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.log.json) | +| Res2Net-50-14w-8s\* | 25.06 | 4.22 | 78.14 | 93.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/res2net/res2net50-w14-s8_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth) | [log]()| +| Res2Net-50-26w-8s\* | 48.40 | 8.39 | 79.20 | 94.36 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/res2net/res2net50-w26-s8_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth) | [log]()| +| Res2Net-101-26w-4s\* | 45.21 | 8.12 | 79.19 | 94.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/res2net/res2net101-w26-s4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth) | [log]()| | ResNeSt-50\* | 27.48 | 5.41 | 81.13 | 95.59 | | [model](https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth) | 
[log]() | | ResNeSt-101\* | 48.28 | 10.27 | 82.32 | 96.24 | | [model](https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth) | [log]() | | ResNeSt-200\* | 70.2 | 17.53 | 82.41 | 96.22 | | [model](https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth) | [log]() | diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py index 866dae5b2..526d3450e 100644 --- a/mmcls/models/backbones/__init__.py +++ b/mmcls/models/backbones/__init__.py @@ -5,6 +5,7 @@ from .mobilenet_v2 import MobileNetV2 from .mobilenet_v3 import MobileNetV3 from .regnet import RegNet from .repvgg import RepVGG +from .res2net import Res2Net from .resnest import ResNeSt from .resnet import ResNet, ResNetV1d from .resnet_cifar import ResNet_CIFAR @@ -23,5 +24,5 @@ __all__ = [ 'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer', - 'SwinTransformer', 'TNT', 'RepVGG', 'TIMMBackbone' + 'SwinTransformer', 'TNT', 'TIMMBackbone', 'Res2Net', 'RepVGG' ] diff --git a/mmcls/models/backbones/res2net.py b/mmcls/models/backbones/res2net.py new file mode 100644 index 000000000..491b6f471 --- /dev/null +++ b/mmcls/models/backbones/res2net.py @@ -0,0 +1,306 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.runner import ModuleList, Sequential + +from ..builder import BACKBONES +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottle2neck(_Bottleneck): + expansion = 4 + + def __init__(self, + in_channels, + out_channels, + scales=4, + base_width=26, + base_channels=64, + stage_type='normal', + **kwargs): + """Bottle2neck block for Res2Net.""" + super(Bottle2neck, self).__init__(in_channels, out_channels, **kwargs) + assert scales > 1, 'Res2Net degenerates to ResNet when scales = 1.' 
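+
+        # Rather than one 3x3 conv on `mid_channels`, Res2Net splits the
+        # features into `scales` groups of `width` channels each and wires
+        # them up hierarchically. `width` scales with `base_width`: e.g. in
+        # the 26w-4s setting with base_channels=64 and mid_channels=64,
+        # width = floor(64 * 26 / 64) = 26, so conv1 outputs 26 * 4 channels.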
+        mid_channels = out_channels // self.expansion
+        width = int(math.floor(mid_channels * (base_width / base_channels)))
+
+        self.norm1_name, norm1 = build_norm_layer(
+            self.norm_cfg, width * scales, postfix=1)
+        self.norm3_name, norm3 = build_norm_layer(
+            self.norm_cfg, self.out_channels, postfix=3)
+
+        self.conv1 = build_conv_layer(
+            self.conv_cfg,
+            self.in_channels,
+            width * scales,
+            kernel_size=1,
+            stride=self.conv1_stride,
+            bias=False)
+        self.add_module(self.norm1_name, norm1)
+
+        # Only the first block of a stage ('stage' type) needs the pooling
+        # branch; it downsamples the last split to match conv2_stride.
+        if stage_type == 'stage':
+            self.pool = nn.AvgPool2d(
+                kernel_size=3, stride=self.conv2_stride, padding=1)
+
+        self.convs = ModuleList()
+        self.bns = ModuleList()
+        for i in range(scales - 1):
+            self.convs.append(
+                build_conv_layer(
+                    self.conv_cfg,
+                    width,
+                    width,
+                    kernel_size=3,
+                    stride=self.conv2_stride,
+                    padding=self.dilation,
+                    dilation=self.dilation,
+                    bias=False))
+            self.bns.append(
+                build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1])
+
+        self.conv3 = build_conv_layer(
+            self.conv_cfg,
+            width * scales,
+            self.out_channels,
+            kernel_size=1,
+            bias=False)
+        self.add_module(self.norm3_name, norm3)
+
+        self.stage_type = stage_type
+        self.scales = scales
+        self.width = width
+        # The parent Bottleneck's single 3x3 conv is replaced by
+        # `self.convs`, so drop it along with its norm layer.
+        delattr(self, 'conv2')
+        delattr(self, self.norm2_name)
+
+    def forward(self, x):
+        """Forward function."""
+
+        def _inner_forward(x):
+            identity = x
+
+            out = self.conv1(x)
+            out = self.norm1(out)
+            out = self.relu(out)
+
+            # Split into `scales` groups of `width` channels; each group
+            # except the first receives the output of the previous one.
+            spx = torch.split(out, self.width, 1)
+            sp = self.convs[0](spx[0].contiguous())
+            sp = self.relu(self.bns[0](sp))
+            out = sp
+            for i in range(1, self.scales - 1):
+                if self.stage_type == 'stage':
+                    sp = spx[i]
+                else:
+                    sp = sp + spx[i]
+                sp = self.convs[i](sp.contiguous())
+                sp = self.relu(self.bns[i](sp))
+                out = torch.cat((out, sp), 1)
+
+            # The last split is concatenated unchanged in 'normal' blocks,
+            # or average-pooled in 'stage' blocks to match the stride.
+            if self.stage_type == 'normal' and self.scales != 1:
+                out = torch.cat((out, spx[self.scales - 1]), 1)
+            elif self.stage_type == 'stage' and self.scales != 1:
+                out = torch.cat((out, self.pool(spx[self.scales - 1])), 1)
+
+            out = self.conv3(out)
+            out = self.norm3(out)
+
+            if self.downsample is not None:
+                identity = self.downsample(x)
+
+            out += identity
+
+            return out
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        out = self.relu(out)
+
+        return out
+
+
+class Res2Layer(Sequential):
+    """Res2Layer to build Res2Net style backbone.
+
+    Args:
+        block (nn.Module): Block used to build the layer.
+        in_channels (int): Input channels of the first block.
+        out_channels (int): Output channels of each block.
+        num_blocks (int): Number of blocks.
+        stride (int): Stride of the first block. Defaults to 1.
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottle2neck. Defaults to True.
+        conv_cfg (dict): Dictionary to construct and config conv layer.
+            Defaults to None.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Defaults to ``dict(type='BN')``.
+        scales (int): Scales used in Res2Net. Defaults to 4.
+        base_width (int): Basic width of each scale. Defaults to 26.
+    """
+
+    def __init__(self,
+                 block,
+                 in_channels,
+                 out_channels,
+                 num_blocks,
+                 stride=1,
+                 avg_down=True,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 scales=4,
+                 base_width=26,
+                 **kwargs):
+        self.block = block
+
+        downsample = None
+        if stride != 1 or in_channels != out_channels:
+            if avg_down:
+                downsample = nn.Sequential(
+                    nn.AvgPool2d(
+                        kernel_size=stride,
+                        stride=stride,
+                        ceil_mode=True,
+                        count_include_pad=False),
+                    build_conv_layer(
+                        conv_cfg,
+                        in_channels,
+                        out_channels,
+                        kernel_size=1,
+                        stride=1,
+                        bias=False),
+                    build_norm_layer(norm_cfg, out_channels)[1],
+                )
+            else:
+                downsample = nn.Sequential(
+                    build_conv_layer(
+                        conv_cfg,
+                        in_channels,
+                        out_channels,
+                        kernel_size=1,
+                        stride=stride,
+                        bias=False),
+                    build_norm_layer(norm_cfg, out_channels)[1],
+                )
+
+        layers = []
+        layers.append(
+            block(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                stride=stride,
+                downsample=downsample,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                scales=scales,
+                base_width=base_width,
+                stage_type='stage',
+                **kwargs))
+        in_channels = out_channels
+        for _ in range(1, num_blocks):
+            layers.append(
+                block(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    stride=1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    scales=scales,
+                    base_width=base_width,
+                    **kwargs))
+        super(Res2Layer, self).__init__(*layers)
+
+
+@BACKBONES.register_module()
+class Res2Net(ResNet):
+    """Res2Net backbone.
+
+    A PyTorch implementation of: `Res2Net: A New Multi-scale Backbone
+    Architecture <https://arxiv.org/abs/1904.01169>`_
+
+    Args:
+        depth (int): Depth of Res2Net, choose from {50, 101, 152}.
+        scales (int): Scales used in Res2Net. Defaults to 4.
+        base_width (int): Basic width of each scale. Defaults to 26.
+        in_channels (int): Number of input image channels. Defaults to 3.
+        num_stages (int): Number of Res2Net stages. Defaults to 4.
+        strides (Sequence[int]): Strides of the first block of each stage.
+            Defaults to ``(1, 2, 2, 2)``.
+        dilations (Sequence[int]): Dilation of each stage.
+            Defaults to ``(1, 1, 1, 1)``.
+        out_indices (Sequence[int]): Output from which stages.
+            Defaults to ``(3, )``.
+        style (str): "pytorch" or "caffe". If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer. Defaults to "pytorch".
+        deep_stem (bool): Replace the 7x7 conv in the input stem with three
+            3x3 convs. Defaults to True.
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottle2neck. Defaults to True.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Defaults to -1.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Defaults to ``dict(type='BN', requires_grad=True)``.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Defaults to False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Defaults to False.
+        zero_init_residual (bool): Whether to use zero init for last norm
+            layer in resblocks to let them behave as identity.
+            Defaults to True.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Defaults to None.
+
+    Example:
+        >>> from mmcls.models import Res2Net
+        >>> import torch
+        >>> model = Res2Net(depth=50,
+        ...                 scales=4,
+        ...                 base_width=26,
+        ...                 out_indices=(0, 1, 2, 3))
+        >>> model.eval()
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = model.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 256, 8, 8)
+        (1, 512, 4, 4)
+        (1, 1024, 2, 2)
+        (1, 2048, 1, 1)
+    """
+
+    arch_settings = {
+        50: (Bottle2neck, (3, 4, 6, 3)),
+        101: (Bottle2neck, (3, 4, 23, 3)),
+        152: (Bottle2neck, (3, 8, 36, 3))
+    }
+
+    def __init__(self,
+                 scales=4,
+                 base_width=26,
+                 style='pytorch',
+                 deep_stem=True,
+                 avg_down=True,
+                 init_cfg=None,
+                 **kwargs):
+        self.scales = scales
+        self.base_width = base_width
+        super(Res2Net, self).__init__(
+            style=style,
+            deep_stem=deep_stem,
+            avg_down=avg_down,
+            init_cfg=init_cfg,
+            **kwargs)
+
+    def make_res_layer(self, **kwargs):
+        return Res2Layer(
+            scales=self.scales,
+            base_width=self.base_width,
+            base_channels=self.base_channels,
+            **kwargs)
diff --git a/mmcls/models/backbones/resnet.py b/mmcls/models/backbones/resnet.py
index 6b0d5890d..35dbf98a4 100644
--- a/mmcls/models/backbones/resnet.py
+++ b/mmcls/models/backbones/resnet.py
@@ -396,10 +396,8 @@ class ResNet(BaseBackbone):
             Default: ``(1, 2, 2, 2)``.
         dilations (Sequence[int]): Dilation of each stage.
             Default: ``(1, 1, 1, 1)``.
-        out_indices (Sequence[int]): Output from which stages. If only one
-            stage is specified, a single tensor (feature map) is returned,
-            otherwise multiple stages are specified, a tuple of tensors will
-            be returned. Default: ``(3, )``.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: ``(3, )``.
         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
             layer is the 3x3 conv layer, otherwise the stride-two layer is
             the first 1x1 conv layer.
diff --git a/model-index.yml b/model-index.yml
index 45248d8bc..814a3c854 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -2,6 +2,7 @@ Import:
   - configs/fp16/metafile.yml
   - configs/mobilenet_v2/metafile.yml
   - configs/resnet/metafile.yml
+  - configs/res2net/metafile.yml
   - configs/resnext/metafile.yml
   - configs/seresnet/metafile.yml
   - configs/shufflenet_v1/metafile.yml
diff --git a/tests/test_models/test_backbones/test_res2net.py b/tests/test_models/test_backbones/test_res2net.py
new file mode 100644
index 000000000..173d3e628
--- /dev/null
+++ b/tests/test_models/test_backbones/test_res2net.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+from mmcls.models.backbones import Res2Net
+
+
+def check_norm_state(modules, train_state):
+    """Check if norm layer is in correct train state."""
+    for mod in modules:
+        if isinstance(mod, _BatchNorm):
+            if mod.training != train_state:
+                return False
+    return True
+
+
+def test_res2net():
+    # Only depths 50, 101 and 152 are supported.
+    with pytest.raises(KeyError):
+        Res2Net(depth=18)
+
+    # Test the feature map sizes when depth is 50
+    # with deep_stem=True and avg_down=True.
+    model = Res2Net(
+        depth=50, out_indices=(0, 1, 2, 3), deep_stem=True, avg_down=True)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model.stem(imgs)
+    assert feat.shape == (1, 64, 112, 112)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == (1, 256, 56, 56)
+    assert feat[1].shape == (1, 512, 28, 28)
+    assert feat[2].shape == (1, 1024, 14, 14)
+    assert feat[3].shape == (1, 2048, 7, 7)
+
+    # Test the feature map sizes when depth is 101
+    # with deep_stem=False and avg_down=False.
+    model = Res2Net(
+        depth=101, out_indices=(0, 1, 2, 3), deep_stem=False, avg_down=False)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model.conv1(imgs)
+    assert feat.shape == (1, 64, 112, 112)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == (1, 256, 56, 56)
+    assert feat[1].shape == (1, 512, 28, 28)
+    assert feat[2].shape == (1, 1024, 14, 14)
+    assert feat[3].shape == (1, 2048, 7, 7)
+
+    # Test Res2Net with the first stage frozen.
+    frozen_stages = 1
+    model = Res2Net(depth=50, frozen_stages=frozen_stages, deep_stem=False)
+    model.init_weights()
+    model.train()
+    assert check_norm_state([model.norm1], False)
+    for param in model.conv1.parameters():
+        assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        layer = getattr(model, f'layer{i}')
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
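+
+
+def test_bottle2neck():
+    # A small extra smoke test, not part of the original PR: it exercises
+    # the Bottle2neck unit directly, assuming it is importable from the
+    # res2net module (it is defined there but not re-exported).
+    from mmcls.models.backbones.res2net import Bottle2neck
+
+    # scales=1 degenerates to a plain ResNet bottleneck and is rejected.
+    with pytest.raises(AssertionError):
+        Bottle2neck(in_channels=64, out_channels=64, scales=1)
+
+    # A stride-1 block without downsampling preserves the input shape.
+    block = Bottle2neck(in_channels=256, out_channels=256, scales=4)
+    x = torch.randn(1, 256, 56, 56)
+    assert block(x).shape == (1, 256, 56, 56)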