diff --git a/README.md b/README.md index 63352663d..157a57b36 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,8 @@ Results and models are available in the [model zoo](docs/model_zoo.md). Supported backbones: - [x] ResNet - [x] ResNeXt -- [x] HRNet +- [x] [HRNet](configs/hrnet/README.md) +- [x] [ResNeSt](configs/resnest/README.md) Supported methods: - [x] [FCN](configs/fcn) diff --git a/configs/resnest/README.md b/configs/resnest/README.md new file mode 100644 index 000000000..4c876214f --- /dev/null +++ b/configs/resnest/README.md @@ -0,0 +1,30 @@ +# ResNeSt: Split-Attention Networks + +## Introduction + +``` +@article{zhang2020resnest, +title={ResNeSt: Split-Attention Networks}, +author={Zhang, Hang and Wu, Chongruo and Zhang, Zhongyue and Zhu, Yi and Zhang, Zhi and Lin, Haibin and Sun, Yue and He, Tong and Muller, Jonas and Manmatha, R. and Li, Mu and Smola, Alexander}, +journal={arXiv preprint arXiv:2004.08955}, +year={2020} +} +``` + +## Results and models + +### Cityscapes +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download | +|------------|----------|-----------|--------:|---------:|----------------|------:|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| FCN | S-101-D8 | 512x1024 | 80000 | 11.4 | 2.39 | 77.56 | 78.98 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/fcn_s101-d8_512x1024_80k_cityscapes/fcn_s101-d8_512x1024_80k_cityscapes_20200807_140631-f8d155b3.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/fcn_s101-d8_512x1024_80k_cityscapes/fcn_s101-d8_512x1024_80k_cityscapes-20200807_140631.log.json) | +| PSPNet | S-101-D8 | 512x1024 | 80000 | 11.8 | 2.52 | 78.57 | 79.19 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/pspnet_s101-d8_512x1024_80k_cityscapes/pspnet_s101-d8_512x1024_80k_cityscapes_20200807_140631-c75f3b99.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/pspnet_s101-d8_512x1024_80k_cityscapes/pspnet_s101-d8_512x1024_80k_cityscapes-20200807_140631.log.json) | +| DeepLabV3 | S-101-D8 | 512x1024 | 80000 | 11.9 | 1.88 | 79.67 | 80.51 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3_s101-d8_512x1024_80k_cityscapes/deeplabv3_s101-d8_512x1024_80k_cityscapes_20200807_144429-b73c4270.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3_s101-d8_512x1024_80k_cityscapes/deeplabv3_s101-d8_512x1024_80k_cityscapes-20200807_144429.log.json) | +| DeepLabV3+ | S-101-D8 | 512x1024 | 80000 | 13.2 | 2.36 | 79.62 | 80.27 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x1024_80k_cityscapes/deeplabv3plus_s101-d8_512x1024_80k_cityscapes_20200807_144429-1239eb43.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x1024_80k_cityscapes/deeplabv3plus_s101-d8_512x1024_80k_cityscapes-20200807_144429.log.json) | + +### ADE20k +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download | +|------------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| FCN | S-101-D8 | 512x512 | 160000 | 14.2 | 12.86 | 45.62 | 46.16 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/fcn_s101-d8_512x512_160k_ade20k/fcn_s101-d8_512x512_160k_ade20k_20200807_145416-d3160329.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/fcn_s101-d8_512x512_160k_ade20k/fcn_s101-d8_512x512_160k_ade20k-20200807_145416.log.json) | +| PSPNet | S-101-D8 | 512x512 | 160000 | 14.2 | 13.02 | 45.44 | 46.28 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/pspnet_s101-d8_512x512_160k_ade20k/pspnet_s101-d8_512x512_160k_ade20k_20200807_145416-a6daa92a.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/pspnet_s101-d8_512x512_160k_ade20k/pspnet_s101-d8_512x512_160k_ade20k-20200807_145416.log.json) | +| DeepLabV3 | S-101-D8 | 512x512 | 160000 | 14.6 | 9.28 | 45.71 | 46.59 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3_s101-d8_512x512_160k_ade20k/deeplabv3_s101-d8_512x512_160k_ade20k_20200807_144503-17ecabe5.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3_s101-d8_512x512_160k_ade20k/deeplabv3_s101-d8_512x512_160k_ade20k-20200807_144503.log.json) | +| DeepLabV3+ | S-101-D8 | 512x512 | 160000 | 16.2 | 11.96 | 46.47 | 47.27 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x512_160k_ade20k/deeplabv3plus_s101-d8_512x512_160k_ade20k_20200807_144503-27b26226.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x512_160k_ade20k/deeplabv3plus_s101-d8_512x512_160k_ade20k-20200807_144503.log.json) | diff --git a/configs/resnest/deeplabv3_s101-d8_512x1024_80k_cityscapes.py b/configs/resnest/deeplabv3_s101-d8_512x1024_80k_cityscapes.py new file mode 100644 index 000000000..f98398690 --- /dev/null +++ b/configs/resnest/deeplabv3_s101-d8_512x1024_80k_cityscapes.py @@ -0,0 +1,9 @@ +_base_ = '../deeplabv3/deeplabv3_r101-d8_512x1024_80k_cityscapes.py' +model = dict( + pretrained='open-mmlab://resnest101', + backbone=dict( + type='ResNeSt', + stem_channels=128, + radix=2, + reduction_factor=4, + avg_down_stride=True)) diff --git a/configs/resnest/deeplabv3_s101-d8_512x512_160k_ade20k.py b/configs/resnest/deeplabv3_s101-d8_512x512_160k_ade20k.py new file mode 100644 index 000000000..e3924ad67 --- /dev/null +++ b/configs/resnest/deeplabv3_s101-d8_512x512_160k_ade20k.py @@ -0,0 +1,9 @@ +_base_ = '../deeplabv3/deeplabv3_r101-d8_512x512_160k_ade20k.py' +model = dict( + pretrained='open-mmlab://resnest101', + backbone=dict( + type='ResNeSt', + stem_channels=128, + radix=2, + reduction_factor=4, + avg_down_stride=True)) diff --git a/configs/resnest/deeplabv3plus_s101-d8_512x1024_80k_cityscapes.py b/configs/resnest/deeplabv3plus_s101-d8_512x1024_80k_cityscapes.py new file mode 100644 index 000000000..69bef7238 --- /dev/null +++ b/configs/resnest/deeplabv3plus_s101-d8_512x1024_80k_cityscapes.py @@ -0,0 +1,9 @@ +_base_ = '../deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py' +model = dict( + pretrained='open-mmlab://resnest101', + backbone=dict( + type='ResNeSt', + stem_channels=128, + radix=2, + reduction_factor=4, + avg_down_stride=True)) diff --git a/configs/resnest/deeplabv3plus_s101-d8_512x512_160k_ade20k.py b/configs/resnest/deeplabv3plus_s101-d8_512x512_160k_ade20k.py new file mode 100644 index 000000000..d51bccb96 --- /dev/null +++ b/configs/resnest/deeplabv3plus_s101-d8_512x512_160k_ade20k.py @@ -0,0 +1,9 @@ +_base_ = '../deeplabv3plus/deeplabv3plus_r101-d8_512x512_160k_ade20k.py' +model = dict( + pretrained='open-mmlab://resnest101', + backbone=dict( + type='ResNeSt', + stem_channels=128, + radix=2, + reduction_factor=4, + avg_down_stride=True)) diff --git a/configs/resnest/fcn_s101-d8_512x1024_80k_cityscapes.py b/configs/resnest/fcn_s101-d8_512x1024_80k_cityscapes.py new file mode 100644 index 000000000..33fa0252d --- /dev/null +++ b/configs/resnest/fcn_s101-d8_512x1024_80k_cityscapes.py @@ -0,0 +1,9 @@ +_base_ = '../fcn/fcn_r101-d8_512x1024_80k_cityscapes.py' +model = dict( + pretrained='open-mmlab://resnest101', + backbone=dict( + type='ResNeSt', + stem_channels=128, + radix=2, + reduction_factor=4, + avg_down_stride=True)) diff --git a/configs/resnest/fcn_s101-d8_512x512_160k_ade20k.py b/configs/resnest/fcn_s101-d8_512x512_160k_ade20k.py new file mode 100644 index 000000000..dcee8c280 --- /dev/null +++ b/configs/resnest/fcn_s101-d8_512x512_160k_ade20k.py @@ -0,0 +1,9 @@ +_base_ = '../fcn/fcn_r101-d8_512x512_160k_ade20k.py' +model = dict( + pretrained='open-mmlab://resnest101', + backbone=dict( + type='ResNeSt', + stem_channels=128, + radix=2, + reduction_factor=4, + avg_down_stride=True)) diff --git a/configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py b/configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py new file mode 100644 index 000000000..9737849cb --- /dev/null +++ b/configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py @@ -0,0 +1,9 @@ +_base_ = '../pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py' +model = dict( + pretrained='open-mmlab://resnest101', + backbone=dict( + type='ResNeSt', + stem_channels=128, + radix=2, + reduction_factor=4, + avg_down_stride=True)) diff --git a/configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py b/configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py new file mode 100644 index 000000000..6a622eae9 --- /dev/null +++ b/configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py @@ -0,0 +1,9 @@ +_base_ = '../pspnet/pspnet_r101-d8_512x512_160k_ade20k.py' +model = dict( + pretrained='open-mmlab://resnest101', + backbone=dict( + type='ResNeSt', + stem_channels=128, + radix=2, + reduction_factor=4, + avg_down_stride=True)) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index ddca1f509..aa5c4eb62 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -81,6 +81,10 @@ Please refer to [ANN](https://github.com/open-mmlab/mmsegmentation/blob/master/c Please refer to [OCRNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/ocrnet) for details. +### ResNeSt + +Please refer to [ResNeSt](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/resnest) for details. + ### Mixed Precision (FP16) Training diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py index 367b398ce..35924248d 100644 --- a/mmseg/models/backbones/__init__.py +++ b/mmseg/models/backbones/__init__.py @@ -1,5 +1,6 @@ from .hrnet import HRNet +from .resnest import ResNeSt from .resnet import ResNet, ResNetV1c, ResNetV1d from .resnext import ResNeXt -__all__ = ['ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet'] +__all__ = ['ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'ResNeSt'] diff --git a/mmseg/models/backbones/resnest.py b/mmseg/models/backbones/resnest.py new file mode 100644 index 000000000..8931decb8 --- /dev/null +++ b/mmseg/models/backbones/resnest.py @@ -0,0 +1,314 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer + +from ..builder import BACKBONES +from ..utils import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNetV1d + + +class RSoftmax(nn.Module): + """Radix Softmax module in ``SplitAttentionConv2d``. + + Args: + radix (int): Radix of input. + groups (int): Groups of input. + """ + + def __init__(self, radix, groups): + super().__init__() + self.radix = radix + self.groups = groups + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttentionConv2d(nn.Module): + """Split-Attention Conv2d in ResNeSt. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int | tuple[int]): Same as nn.Conv2d. + stride (int | tuple[int]): Same as nn.Conv2d. + padding (int | tuple[int]): Same as nn.Conv2d. + dilation (int | tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels. Default: 4. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + dcn (dict): Config dict for DCN. Default: None. + """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None): + super(SplitAttentionConv2d, self).__init__() + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.groups = groups + self.channels = channels + self.with_dcn = dcn is not None + self.dcn = dcn + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if self.with_dcn and not fallback_on_stride: + assert conv_cfg is None, 'conv_cfg must be None for DCN' + conv_cfg = dcn + self.conv = build_conv_layer( + conv_cfg, + in_channels, + channels * radix, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups * radix, + bias=False) + self.norm0_name, norm0 = build_norm_layer( + norm_cfg, channels * radix, postfix=0) + self.add_module(self.norm0_name, norm0) + self.relu = nn.ReLU(inplace=True) + self.fc1 = build_conv_layer( + None, channels, inter_channels, 1, groups=self.groups) + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, inter_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.fc2 = build_conv_layer( + None, inter_channels, channels * radix, 1, groups=self.groups) + self.rsoftmax = RSoftmax(radix, groups) + + @property + def norm0(self): + """nn.Module: the normalization layer named "norm0" """ + return getattr(self, self.norm0_name) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def forward(self, x): + x = self.conv(x) + x = self.norm0(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + batch = x.size(0) + if self.radix > 1: + splits = x.view(batch, self.radix, -1, *x.shape[2:]) + gap = splits.sum(dim=1) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + gap = self.norm1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) + out = torch.sum(attens * splits, dim=1) + else: + out = atten * x + return out.contiguous() + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeSt. + + Args: + inplane (int): Input planes of this block. + planes (int): Middle planes of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Key word arguments for base class. + """ + expansion = 4 + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + """Bottleneck block for ResNeSt.""" + super(Bottleneck, self).__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.with_modulated_dcn = False + self.conv2 = SplitAttentionConv2d( + width, + width, + kernel_size=3, + stride=1 if self.avg_down_stride else self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + radix=radix, + reduction_factor=reduction_factor, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dcn=self.dcn) + delattr(self, self.norm2_name) + + if self.avg_down_stride: + self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) + + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + + if self.avg_down_stride: + out = self.avd_layer(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@BACKBONES.register_module() +class ResNeSt(ResNetV1d): + """ResNeSt backbone. + + Args: + groups (int): Number of groups of Bottleneck. Default: 1 + base_width (int): Base width of Bottleneck. Default: 4 + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Keyword arguments for ResNet. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)) + } + + def __init__(self, + groups=1, + base_width=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.base_width = base_width + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super(ResNeSt, self).__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/mmseg/models/utils/res_layer.py b/mmseg/models/utils/res_layer.py index 9ef51b95b..2585ab551 100644 --- a/mmseg/models/utils/res_layer.py +++ b/mmseg/models/utils/res_layer.py @@ -42,8 +42,7 @@ class ResLayer(nn.Sequential): if stride != 1 or inplanes != planes * block.expansion: downsample = [] conv_stride = stride - # check dilation for dilated ResNet - if avg_down and (stride != 1 or dilation != 1): + if avg_down: conv_stride = 1 downsample.append( nn.AvgPool2d( diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py index 00ae43d00..ba6cdaa19 100644 --- a/tests/test_models/test_backbone.py +++ b/tests/test_models/test_backbone.py @@ -4,7 +4,8 @@ from mmcv.ops import DeformConv2dPack from mmcv.utils.parrots_wrapper import _BatchNorm from torch.nn.modules import AvgPool2d, GroupNorm -from mmseg.models.backbones import ResNet, ResNetV1d, ResNeXt +from mmseg.models.backbones import ResNeSt, ResNet, ResNetV1d, ResNeXt +from mmseg.models.backbones.resnest import Bottleneck as BottleneckS from mmseg.models.backbones.resnet import BasicBlock, Bottleneck from mmseg.models.backbones.resnext import Bottleneck as BottleneckX from mmseg.models.utils import ResLayer @@ -664,3 +665,41 @@ def test_resnext_backbone(): assert feat[1].shape == torch.Size([1, 512, 28, 28]) assert feat[2].shape == torch.Size([1, 1024, 14, 14]) assert feat[3].shape == torch.Size([1, 2048, 7, 7]) + + +def test_resnest_bottleneck(): + with pytest.raises(AssertionError): + # Style must be in ['pytorch', 'caffe'] + BottleneckS(64, 64, radix=2, reduction_factor=4, style='tensorflow') + + # Test ResNeSt Bottleneck structure + block = BottleneckS( + 64, 256, radix=2, reduction_factor=4, stride=2, style='pytorch') + assert block.avd_layer.stride == 2 + assert block.conv2.channels == 256 + + # Test ResNeSt Bottleneck forward + block = BottleneckS(64, 16, radix=2, reduction_factor=4) + x = torch.randn(2, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([2, 64, 56, 56]) + + +def test_resnest_backbone(): + with pytest.raises(KeyError): + # ResNeSt depth should be in [50, 101, 152, 200] + ResNeSt(depth=18) + + # Test ResNeSt with radix 2, reduction_factor 4 + model = ResNeSt( + depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3)) + model.init_weights() + model.train() + + imgs = torch.randn(2, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([2, 256, 56, 56]) + assert feat[1].shape == torch.Size([2, 512, 28, 28]) + assert feat[2].shape == torch.Size([2, 1024, 14, 14]) + assert feat[3].shape == torch.Size([2, 2048, 7, 7])