mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Support ResNeSt backbone (#47)
* Support ResNeSt backbone * fixed avg_down * add docstring and test * update table * update docs and tests * fixed test * rename * refactor splits
This commit is contained in:
parent
df37f801b6
commit
c8b250df4a
@ -54,7 +54,8 @@ Results and models are available in the [model zoo](docs/model_zoo.md).
|
||||
Supported backbones:
|
||||
- [x] ResNet
|
||||
- [x] ResNeXt
|
||||
- [x] HRNet
|
||||
- [x] [HRNet](configs/hrnet/README.md)
|
||||
- [x] [ResNeSt](configs/resnest/README.md)
|
||||
|
||||
Supported methods:
|
||||
- [x] [FCN](configs/fcn)
|
||||
|
30
configs/resnest/README.md
Normal file
30
configs/resnest/README.md
Normal file
@ -0,0 +1,30 @@
|
||||
# ResNeSt: Split-Attention Networks
|
||||
|
||||
## Introduction
|
||||
|
||||
```
|
||||
@article{zhang2020resnest,
|
||||
title={ResNeSt: Split-Attention Networks},
|
||||
author={Zhang, Hang and Wu, Chongruo and Zhang, Zhongyue and Zhu, Yi and Zhang, Zhi and Lin, Haibin and Sun, Yue and He, Tong and Muller, Jonas and Manmatha, R. and Li, Mu and Smola, Alexander},
|
||||
journal={arXiv preprint arXiv:2004.08955},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
### Cityscapes
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|
||||
|------------|----------|-----------|--------:|---------:|----------------|------:|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| FCN | S-101-D8 | 512x1024 | 80000 | 11.4 | 2.39 | 77.56 | 78.98 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/fcn_s101-d8_512x1024_80k_cityscapes/fcn_s101-d8_512x1024_80k_cityscapes_20200807_140631-f8d155b3.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/fcn_s101-d8_512x1024_80k_cityscapes/fcn_s101-d8_512x1024_80k_cityscapes-20200807_140631.log.json) |
|
||||
| PSPNet | S-101-D8 | 512x1024 | 80000 | 11.8 | 2.52 | 78.57 | 79.19 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/pspnet_s101-d8_512x1024_80k_cityscapes/pspnet_s101-d8_512x1024_80k_cityscapes_20200807_140631-c75f3b99.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/pspnet_s101-d8_512x1024_80k_cityscapes/pspnet_s101-d8_512x1024_80k_cityscapes-20200807_140631.log.json) |
|
||||
| DeepLabV3 | S-101-D8 | 512x1024 | 80000 | 11.9 | 1.88 | 79.67 | 80.51 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3_s101-d8_512x1024_80k_cityscapes/deeplabv3_s101-d8_512x1024_80k_cityscapes_20200807_144429-b73c4270.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3_s101-d8_512x1024_80k_cityscapes/deeplabv3_s101-d8_512x1024_80k_cityscapes-20200807_144429.log.json) |
|
||||
| DeepLabV3+ | S-101-D8 | 512x1024 | 80000 | 13.2 | 2.36 | 79.62 | 80.27 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x1024_80k_cityscapes/deeplabv3plus_s101-d8_512x1024_80k_cityscapes_20200807_144429-1239eb43.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x1024_80k_cityscapes/deeplabv3plus_s101-d8_512x1024_80k_cityscapes-20200807_144429.log.json) |
|
||||
|
||||
### ADE20k
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|
||||
|------------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| FCN | S-101-D8 | 512x512 | 160000 | 14.2 | 12.86 | 45.62 | 46.16 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/fcn_s101-d8_512x512_160k_ade20k/fcn_s101-d8_512x512_160k_ade20k_20200807_145416-d3160329.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/fcn_s101-d8_512x512_160k_ade20k/fcn_s101-d8_512x512_160k_ade20k-20200807_145416.log.json) |
|
||||
| PSPNet | S-101-D8 | 512x512 | 160000 | 14.2 | 13.02 | 45.44 | 46.28 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/pspnet_s101-d8_512x512_160k_ade20k/pspnet_s101-d8_512x512_160k_ade20k_20200807_145416-a6daa92a.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/pspnet_s101-d8_512x512_160k_ade20k/pspnet_s101-d8_512x512_160k_ade20k-20200807_145416.log.json) |
|
||||
| DeepLabV3 | S-101-D8 | 512x512 | 160000 | 14.6 | 9.28 | 45.71 | 46.59 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3_s101-d8_512x512_160k_ade20k/deeplabv3_s101-d8_512x512_160k_ade20k_20200807_144503-17ecabe5.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3_s101-d8_512x512_160k_ade20k/deeplabv3_s101-d8_512x512_160k_ade20k-20200807_144503.log.json) |
|
||||
| DeepLabV3+ | S-101-D8 | 512x512 | 160000 | 16.2 | 11.96 | 46.47 | 47.27 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x512_160k_ade20k/deeplabv3plus_s101-d8_512x512_160k_ade20k_20200807_144503-27b26226.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x512_160k_ade20k/deeplabv3plus_s101-d8_512x512_160k_ade20k-20200807_144503.log.json) |
|
@ -0,0 +1,9 @@
|
||||
_base_ = '../deeplabv3/deeplabv3_r101-d8_512x1024_80k_cityscapes.py'
|
||||
model = dict(
|
||||
pretrained='open-mmlab://resnest101',
|
||||
backbone=dict(
|
||||
type='ResNeSt',
|
||||
stem_channels=128,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True))
|
9
configs/resnest/deeplabv3_s101-d8_512x512_160k_ade20k.py
Normal file
9
configs/resnest/deeplabv3_s101-d8_512x512_160k_ade20k.py
Normal file
@ -0,0 +1,9 @@
|
||||
_base_ = '../deeplabv3/deeplabv3_r101-d8_512x512_160k_ade20k.py'
|
||||
model = dict(
|
||||
pretrained='open-mmlab://resnest101',
|
||||
backbone=dict(
|
||||
type='ResNeSt',
|
||||
stem_channels=128,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True))
|
@ -0,0 +1,9 @@
|
||||
_base_ = '../deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py'
|
||||
model = dict(
|
||||
pretrained='open-mmlab://resnest101',
|
||||
backbone=dict(
|
||||
type='ResNeSt',
|
||||
stem_channels=128,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True))
|
@ -0,0 +1,9 @@
|
||||
_base_ = '../deeplabv3plus/deeplabv3plus_r101-d8_512x512_160k_ade20k.py'
|
||||
model = dict(
|
||||
pretrained='open-mmlab://resnest101',
|
||||
backbone=dict(
|
||||
type='ResNeSt',
|
||||
stem_channels=128,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True))
|
9
configs/resnest/fcn_s101-d8_512x1024_80k_cityscapes.py
Normal file
9
configs/resnest/fcn_s101-d8_512x1024_80k_cityscapes.py
Normal file
@ -0,0 +1,9 @@
|
||||
_base_ = '../fcn/fcn_r101-d8_512x1024_80k_cityscapes.py'
|
||||
model = dict(
|
||||
pretrained='open-mmlab://resnest101',
|
||||
backbone=dict(
|
||||
type='ResNeSt',
|
||||
stem_channels=128,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True))
|
9
configs/resnest/fcn_s101-d8_512x512_160k_ade20k.py
Normal file
9
configs/resnest/fcn_s101-d8_512x512_160k_ade20k.py
Normal file
@ -0,0 +1,9 @@
|
||||
_base_ = '../fcn/fcn_r101-d8_512x512_160k_ade20k.py'
|
||||
model = dict(
|
||||
pretrained='open-mmlab://resnest101',
|
||||
backbone=dict(
|
||||
type='ResNeSt',
|
||||
stem_channels=128,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True))
|
@ -0,0 +1,9 @@
|
||||
_base_ = '../pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py'
|
||||
model = dict(
|
||||
pretrained='open-mmlab://resnest101',
|
||||
backbone=dict(
|
||||
type='ResNeSt',
|
||||
stem_channels=128,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True))
|
9
configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py
Normal file
9
configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py
Normal file
@ -0,0 +1,9 @@
|
||||
_base_ = '../pspnet/pspnet_r101-d8_512x512_160k_ade20k.py'
|
||||
model = dict(
|
||||
pretrained='open-mmlab://resnest101',
|
||||
backbone=dict(
|
||||
type='ResNeSt',
|
||||
stem_channels=128,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True))
|
@ -81,6 +81,10 @@ Please refer to [ANN](https://github.com/open-mmlab/mmsegmentation/blob/master/c
|
||||
|
||||
Please refer to [OCRNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/ocrnet) for details.
|
||||
|
||||
### ResNeSt
|
||||
|
||||
Please refer to [ResNeSt](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/resnest) for details.
|
||||
|
||||
|
||||
### Mixed Precision (FP16) Training
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
from .hrnet import HRNet
|
||||
from .resnest import ResNeSt
|
||||
from .resnet import ResNet, ResNetV1c, ResNetV1d
|
||||
from .resnext import ResNeXt
|
||||
|
||||
__all__ = ['ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet']
|
||||
__all__ = ['ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'ResNeSt']
|
||||
|
314
mmseg/models/backbones/resnest.py
Normal file
314
mmseg/models/backbones/resnest.py
Normal file
@ -0,0 +1,314 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import ResLayer
|
||||
from .resnet import Bottleneck as _Bottleneck
|
||||
from .resnet import ResNetV1d
|
||||
|
||||
|
||||
class RSoftmax(nn.Module):
|
||||
"""Radix Softmax module in ``SplitAttentionConv2d``.
|
||||
|
||||
Args:
|
||||
radix (int): Radix of input.
|
||||
groups (int): Groups of input.
|
||||
"""
|
||||
|
||||
def __init__(self, radix, groups):
|
||||
super().__init__()
|
||||
self.radix = radix
|
||||
self.groups = groups
|
||||
|
||||
def forward(self, x):
|
||||
batch = x.size(0)
|
||||
if self.radix > 1:
|
||||
x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
|
||||
x = F.softmax(x, dim=1)
|
||||
x = x.reshape(batch, -1)
|
||||
else:
|
||||
x = torch.sigmoid(x)
|
||||
return x
|
||||
|
||||
|
||||
class SplitAttentionConv2d(nn.Module):
|
||||
"""Split-Attention Conv2d in ResNeSt.
|
||||
|
||||
Args:
|
||||
in_channels (int): Same as nn.Conv2d.
|
||||
out_channels (int): Same as nn.Conv2d.
|
||||
kernel_size (int | tuple[int]): Same as nn.Conv2d.
|
||||
stride (int | tuple[int]): Same as nn.Conv2d.
|
||||
padding (int | tuple[int]): Same as nn.Conv2d.
|
||||
dilation (int | tuple[int]): Same as nn.Conv2d.
|
||||
groups (int): Same as nn.Conv2d.
|
||||
radix (int): Radix of SpltAtConv2d. Default: 2
|
||||
reduction_factor (int): Reduction factor of inter_channels. Default: 4.
|
||||
conv_cfg (dict): Config dict for convolution layer. Default: None,
|
||||
which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||
dcn (dict): Config dict for DCN. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
dcn=None):
|
||||
super(SplitAttentionConv2d, self).__init__()
|
||||
inter_channels = max(in_channels * radix // reduction_factor, 32)
|
||||
self.radix = radix
|
||||
self.groups = groups
|
||||
self.channels = channels
|
||||
self.with_dcn = dcn is not None
|
||||
self.dcn = dcn
|
||||
fallback_on_stride = False
|
||||
if self.with_dcn:
|
||||
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
|
||||
if self.with_dcn and not fallback_on_stride:
|
||||
assert conv_cfg is None, 'conv_cfg must be None for DCN'
|
||||
conv_cfg = dcn
|
||||
self.conv = build_conv_layer(
|
||||
conv_cfg,
|
||||
in_channels,
|
||||
channels * radix,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups * radix,
|
||||
bias=False)
|
||||
self.norm0_name, norm0 = build_norm_layer(
|
||||
norm_cfg, channels * radix, postfix=0)
|
||||
self.add_module(self.norm0_name, norm0)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.fc1 = build_conv_layer(
|
||||
None, channels, inter_channels, 1, groups=self.groups)
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, inter_channels, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.fc2 = build_conv_layer(
|
||||
None, inter_channels, channels * radix, 1, groups=self.groups)
|
||||
self.rsoftmax = RSoftmax(radix, groups)
|
||||
|
||||
@property
|
||||
def norm0(self):
|
||||
"""nn.Module: the normalization layer named "norm0" """
|
||||
return getattr(self, self.norm0_name)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
"""nn.Module: the normalization layer named "norm1" """
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.norm0(x)
|
||||
x = self.relu(x)
|
||||
|
||||
batch, rchannel = x.shape[:2]
|
||||
batch = x.size(0)
|
||||
if self.radix > 1:
|
||||
splits = x.view(batch, self.radix, -1, *x.shape[2:])
|
||||
gap = splits.sum(dim=1)
|
||||
else:
|
||||
gap = x
|
||||
gap = F.adaptive_avg_pool2d(gap, 1)
|
||||
gap = self.fc1(gap)
|
||||
|
||||
gap = self.norm1(gap)
|
||||
gap = self.relu(gap)
|
||||
|
||||
atten = self.fc2(gap)
|
||||
atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
|
||||
|
||||
if self.radix > 1:
|
||||
attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
|
||||
out = torch.sum(attens * splits, dim=1)
|
||||
else:
|
||||
out = atten * x
|
||||
return out.contiguous()
|
||||
|
||||
|
||||
class Bottleneck(_Bottleneck):
|
||||
"""Bottleneck block for ResNeSt.
|
||||
|
||||
Args:
|
||||
inplane (int): Input planes of this block.
|
||||
planes (int): Middle planes of this block.
|
||||
groups (int): Groups of conv2.
|
||||
width_per_group (int): Width per group of conv2. 64x4d indicates
|
||||
``groups=64, width_per_group=4`` and 32x8d indicates
|
||||
``groups=32, width_per_group=8``.
|
||||
radix (int): Radix of SpltAtConv2d. Default: 2
|
||||
reduction_factor (int): Reduction factor of inter_channels in
|
||||
SplitAttentionConv2d. Default: 4.
|
||||
avg_down_stride (bool): Whether to use average pool for stride in
|
||||
Bottleneck. Default: True.
|
||||
kwargs (dict): Key word arguments for base class.
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
groups=1,
|
||||
base_width=4,
|
||||
base_channels=64,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True,
|
||||
**kwargs):
|
||||
"""Bottleneck block for ResNeSt."""
|
||||
super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
|
||||
|
||||
if groups == 1:
|
||||
width = self.planes
|
||||
else:
|
||||
width = math.floor(self.planes *
|
||||
(base_width / base_channels)) * groups
|
||||
|
||||
self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
self.norm_cfg, width, postfix=1)
|
||||
self.norm3_name, norm3 = build_norm_layer(
|
||||
self.norm_cfg, self.planes * self.expansion, postfix=3)
|
||||
|
||||
self.conv1 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
self.inplanes,
|
||||
width,
|
||||
kernel_size=1,
|
||||
stride=self.conv1_stride,
|
||||
bias=False)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.with_modulated_dcn = False
|
||||
self.conv2 = SplitAttentionConv2d(
|
||||
width,
|
||||
width,
|
||||
kernel_size=3,
|
||||
stride=1 if self.avg_down_stride else self.conv2_stride,
|
||||
padding=self.dilation,
|
||||
dilation=self.dilation,
|
||||
groups=groups,
|
||||
radix=radix,
|
||||
reduction_factor=reduction_factor,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dcn=self.dcn)
|
||||
delattr(self, self.norm2_name)
|
||||
|
||||
if self.avg_down_stride:
|
||||
self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
|
||||
|
||||
self.conv3 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
width,
|
||||
self.planes * self.expansion,
|
||||
kernel_size=1,
|
||||
bias=False)
|
||||
self.add_module(self.norm3_name, norm3)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.norm1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv1_plugin_names)
|
||||
|
||||
out = self.conv2(out)
|
||||
|
||||
if self.avg_down_stride:
|
||||
out = self.avd_layer(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv2_plugin_names)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.norm3(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv3_plugin_names)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
|
||||
return out
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNeSt(ResNetV1d):
|
||||
"""ResNeSt backbone.
|
||||
|
||||
Args:
|
||||
groups (int): Number of groups of Bottleneck. Default: 1
|
||||
base_width (int): Base width of Bottleneck. Default: 4
|
||||
radix (int): Radix of SpltAtConv2d. Default: 2
|
||||
reduction_factor (int): Reduction factor of inter_channels in
|
||||
SplitAttentionConv2d. Default: 4.
|
||||
avg_down_stride (bool): Whether to use average pool for stride in
|
||||
Bottleneck. Default: True.
|
||||
kwargs (dict): Keyword arguments for ResNet.
|
||||
"""
|
||||
|
||||
arch_settings = {
|
||||
50: (Bottleneck, (3, 4, 6, 3)),
|
||||
101: (Bottleneck, (3, 4, 23, 3)),
|
||||
152: (Bottleneck, (3, 8, 36, 3)),
|
||||
200: (Bottleneck, (3, 24, 36, 3))
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
groups=1,
|
||||
base_width=4,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True,
|
||||
**kwargs):
|
||||
self.groups = groups
|
||||
self.base_width = base_width
|
||||
self.radix = radix
|
||||
self.reduction_factor = reduction_factor
|
||||
self.avg_down_stride = avg_down_stride
|
||||
super(ResNeSt, self).__init__(**kwargs)
|
||||
|
||||
def make_res_layer(self, **kwargs):
|
||||
"""Pack all blocks in a stage into a ``ResLayer``."""
|
||||
return ResLayer(
|
||||
groups=self.groups,
|
||||
base_width=self.base_width,
|
||||
base_channels=self.base_channels,
|
||||
radix=self.radix,
|
||||
reduction_factor=self.reduction_factor,
|
||||
avg_down_stride=self.avg_down_stride,
|
||||
**kwargs)
|
@ -42,8 +42,7 @@ class ResLayer(nn.Sequential):
|
||||
if stride != 1 or inplanes != planes * block.expansion:
|
||||
downsample = []
|
||||
conv_stride = stride
|
||||
# check dilation for dilated ResNet
|
||||
if avg_down and (stride != 1 or dilation != 1):
|
||||
if avg_down:
|
||||
conv_stride = 1
|
||||
downsample.append(
|
||||
nn.AvgPool2d(
|
||||
|
@ -4,7 +4,8 @@ from mmcv.ops import DeformConv2dPack
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from torch.nn.modules import AvgPool2d, GroupNorm
|
||||
|
||||
from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt
|
||||
from mmseg.models.backbones import ResNeSt, ResNet, ResNetV1d, ResNeXt
|
||||
from mmseg.models.backbones.resnest import Bottleneck as BottleneckS
|
||||
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
|
||||
from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
|
||||
from mmseg.models.utils import ResLayer
|
||||
@ -664,3 +665,41 @@ def test_resnext_backbone():
|
||||
assert feat[1].shape == torch.Size([1, 512, 28, 28])
|
||||
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
|
||||
|
||||
|
||||
def test_resnest_bottleneck():
|
||||
with pytest.raises(AssertionError):
|
||||
# Style must be in ['pytorch', 'caffe']
|
||||
BottleneckS(64, 64, radix=2, reduction_factor=4, style='tensorflow')
|
||||
|
||||
# Test ResNeSt Bottleneck structure
|
||||
block = BottleneckS(
|
||||
64, 256, radix=2, reduction_factor=4, stride=2, style='pytorch')
|
||||
assert block.avd_layer.stride == 2
|
||||
assert block.conv2.channels == 256
|
||||
|
||||
# Test ResNeSt Bottleneck forward
|
||||
block = BottleneckS(64, 16, radix=2, reduction_factor=4)
|
||||
x = torch.randn(2, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([2, 64, 56, 56])
|
||||
|
||||
|
||||
def test_resnest_backbone():
|
||||
with pytest.raises(KeyError):
|
||||
# ResNeSt depth should be in [50, 101, 152, 200]
|
||||
ResNeSt(depth=18)
|
||||
|
||||
# Test ResNeSt with radix 2, reduction_factor 4
|
||||
model = ResNeSt(
|
||||
depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(2, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([2, 256, 56, 56])
|
||||
assert feat[1].shape == torch.Size([2, 512, 28, 28])
|
||||
assert feat[2].shape == torch.Size([2, 1024, 14, 14])
|
||||
assert feat[3].shape == torch.Size([2, 2048, 7, 7])
|
||||
|
Loading…
x
Reference in New Issue
Block a user