mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Fast-SCNN implemented (#58)
* init commit: fast_scnn * 247917iters * 4x8_80k * configs placed in configs_unify. 4x8_80k exp.running. * mmseg/utils/collect_env.py modified to support Windows * study on lr * bug in configs_unify/***/cityscapes.py fixed. * lr0.08_100k * lr_power changed to 1.2 * log_config by_epoch set to False. * lr1.2 * doc strings added * add fast_scnn backbone test * 80k 0.08,0.12 * add 450k * fast_scnn test: fix BN bug. * Add different config files into configs/ * .gitignore recovered. * configs_unify del * .gitignore recovered. * delete sub-optimal config files of fast-scnn * Code style improved. * add docstrings to component modules of fast-scnn * relevant files modified according to Jerry's instructions * relevant files modified according to Jerry's instructions * lint problems fixed. * fast_scnn config extremely simplified. * InvertedResidual * fixed padding problems * add unit test for inverted_residual * add unit test for inverted_residual: debug 0 * add unit test for inverted_residual: debug 1 * add unit test for inverted_residual: debug 2 * add unit test for inverted_residual: debug 3 * add unit test for sep_fcn_head: debug 0 * add unit test for sep_fcn_head: debug 1 * add unit test for sep_fcn_head: debug 2 * add unit test for sep_fcn_head: debug 3 * add unit test for sep_fcn_head: debug 4 * add unit test for sep_fcn_head: debug 5 * FastSCNN type(dwchannels) changed to tuple. * t changed to expand_ratio. * Spaces fixed. * Update mmseg/models/backbones/fast_scnn.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> * Update mmseg/models/decode_heads/sep_fcn_head.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> * Update mmseg/models/decode_heads/sep_fcn_head.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> * Docstrings fixed. * Docstrings fixed. * Inverted Residual kept coherent with mmcl. * Inverted Residual kept coherent with mmcl. Debug 0 * _make_layer parameters renamed. * final commit * Arg scale_factor deleted. * Expand_ratio docstrings updated. * final commit * Readme for Fast-SCNN added. * model-zoo.md modified. * fast_scnn README updated. * Move InvertedResidual module into mmseg/utils. * test_inverted_residual module corrected. * test_inverted_residual.py moved. * encoder_decoder modified to avoid bugs when running PSPNet. getting_started.md bug fixed. * Revert "encoder_decoder modified to avoid bugs when running PSPNet. " This reverts commit dd0aadfb Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
This commit is contained in:
parent
c8b250df4a
commit
f6b9da55f3
58
configs/_base_/models/fast_scnn.py
Normal file
58
configs/_base_/models/fast_scnn.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
# model settings
|
||||||
|
norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)
|
||||||
|
model = dict(
|
||||||
|
type='EncoderDecoder',
|
||||||
|
backbone=dict(
|
||||||
|
type='FastSCNN',
|
||||||
|
downsample_dw_channels=(32, 48),
|
||||||
|
global_in_channels=64,
|
||||||
|
global_block_channels=(64, 96, 128),
|
||||||
|
global_block_strides=(2, 2, 1),
|
||||||
|
global_out_channels=128,
|
||||||
|
higher_in_channels=64,
|
||||||
|
lower_in_channels=128,
|
||||||
|
fusion_out_channels=128,
|
||||||
|
out_indices=(0, 1, 2),
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
align_corners=False),
|
||||||
|
decode_head=dict(
|
||||||
|
type='DepthwiseSeparableFCNHead',
|
||||||
|
in_channels=128,
|
||||||
|
channels=128,
|
||||||
|
concat_input=False,
|
||||||
|
num_classes=19,
|
||||||
|
in_index=-1,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
align_corners=False,
|
||||||
|
loss_decode=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.)),
|
||||||
|
auxiliary_head=[
|
||||||
|
dict(
|
||||||
|
type='FCNHead',
|
||||||
|
in_channels=128,
|
||||||
|
channels=32,
|
||||||
|
num_convs=1,
|
||||||
|
num_classes=19,
|
||||||
|
in_index=-2,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
concat_input=False,
|
||||||
|
align_corners=False,
|
||||||
|
loss_decode=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||||
|
dict(
|
||||||
|
type='FCNHead',
|
||||||
|
in_channels=64,
|
||||||
|
channels=32,
|
||||||
|
num_convs=1,
|
||||||
|
num_classes=19,
|
||||||
|
in_index=-3,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
concat_input=False,
|
||||||
|
align_corners=False,
|
||||||
|
loss_decode=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||||
|
])
|
||||||
|
|
||||||
|
# model training and testing settings
|
||||||
|
train_cfg = dict()
|
||||||
|
test_cfg = dict(mode='whole')
|
18
configs/fastscnn/README.md
Normal file
18
configs/fastscnn/README.md
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Fast-SCNN for Semantic Segmentation
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
```
|
||||||
|
@article{poudel2019fast,
|
||||||
|
title={Fast-scnn: Fast semantic segmentation network},
|
||||||
|
author={Poudel, Rudra PK and Liwicki, Stephan and Cipolla, Roberto},
|
||||||
|
journal={arXiv preprint arXiv:1902.04502},
|
||||||
|
year={2019}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Results and models
|
||||||
|
|
||||||
|
### Cityscapes
|
||||||
|
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|
||||||
|
|------------|-----------|-----------|--------:|----------|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
| Fast-SCNN | Fast-SCNN | 512x1024 | 80000 | 8.4 | 63.61 | 69.06 | - | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_4x8_80k_lr0.12_cityscapes-cae6c46a.pth) | [log](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_4x8_80k_lr0.12_cityscapes-20200807_165744.log.json) |
|
10
configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py
Normal file
10
configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
_base_ = [
|
||||||
|
'../_base_/models/fast_scnn.py', '../_base_/datasets/cityscapes.py',
|
||||||
|
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
# Re-config the data sampler.
|
||||||
|
data = dict(samples_per_gpu=8, workers_per_gpu=4)
|
||||||
|
|
||||||
|
# Re-config the optimizer.
|
||||||
|
optimizer = dict(type='SGD', lr=0.12, momentum=0.9, weight_decay=4e-5)
|
@ -338,7 +338,7 @@ The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pt
|
|||||||
We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model.
|
We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output_file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
|
python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
|
||||||
```
|
```
|
||||||
|
|
||||||
**Note**: This tool is still experimental. Some customized operators are not supported for now.
|
**Note**: This tool is still experimental. Some customized operators are not supported for now.
|
||||||
|
@ -81,11 +81,14 @@ Please refer to [ANN](https://github.com/open-mmlab/mmsegmentation/blob/master/c
|
|||||||
|
|
||||||
Please refer to [OCRNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/ocrnet) for details.
|
Please refer to [OCRNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/ocrnet) for details.
|
||||||
|
|
||||||
|
### Fast-SCNN
|
||||||
|
|
||||||
|
Please refer to [Fast-SCNN](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fastscnn) for details.
|
||||||
|
|
||||||
### ResNeSt
|
### ResNeSt
|
||||||
|
|
||||||
Please refer to [ResNeSt](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/resnest) for details.
|
Please refer to [ResNeSt](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/resnest) for details.
|
||||||
|
|
||||||
|
|
||||||
### Mixed Precision (FP16) Training
|
### Mixed Precision (FP16) Training
|
||||||
|
|
||||||
Please refer [Mixed Precision (FP16) Training](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fp16/README.md) for details.
|
Please refer [Mixed Precision (FP16) Training](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fp16/README.md) for details.
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
|
from .fast_scnn import FastSCNN
|
||||||
from .hrnet import HRNet
|
from .hrnet import HRNet
|
||||||
from .resnest import ResNeSt
|
from .resnest import ResNeSt
|
||||||
from .resnet import ResNet, ResNetV1c, ResNetV1d
|
from .resnet import ResNet, ResNetV1c, ResNetV1d
|
||||||
from .resnext import ResNeXt
|
from .resnext import ResNeXt
|
||||||
|
|
||||||
__all__ = ['ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'ResNeSt']
|
__all__ = [
|
||||||
|
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
|
||||||
|
'ResNeSt'
|
||||||
|
]
|
||||||
|
385
mmseg/models/backbones/fast_scnn.py
Normal file
385
mmseg/models/backbones/fast_scnn.py
Normal file
@ -0,0 +1,385 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmcv.cnn import ConvModule, constant_init, kaiming_init
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
from mmseg.models.decode_heads.psp_head import PPM
|
||||||
|
from mmseg.ops import DepthwiseSeparableConvModule, resize
|
||||||
|
from mmseg.utils import InvertedResidual
|
||||||
|
from ..builder import BACKBONES
|
||||||
|
|
||||||
|
|
||||||
|
class LearningToDownsample(nn.Module):
|
||||||
|
"""Learning to downsample module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input channels.
|
||||||
|
dw_channels (tuple[int]): Number of output channels of the first and
|
||||||
|
the second depthwise conv (dwconv) layers.
|
||||||
|
out_channels (int): Number of output channels of the whole
|
||||||
|
'learning to downsample' module.
|
||||||
|
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||||
|
norm_cfg (dict | None): Config of norm layers. Default:
|
||||||
|
dict(type='BN')
|
||||||
|
act_cfg (dict): Config of activation layers. Default:
|
||||||
|
dict(type='ReLU')
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
dw_channels,
|
||||||
|
out_channels,
|
||||||
|
conv_cfg=None,
|
||||||
|
norm_cfg=dict(type='BN'),
|
||||||
|
act_cfg=dict(type='ReLU')):
|
||||||
|
super(LearningToDownsample, self).__init__()
|
||||||
|
self.conv_cfg = conv_cfg
|
||||||
|
self.norm_cfg = norm_cfg
|
||||||
|
self.act_cfg = act_cfg
|
||||||
|
dw_channels1 = dw_channels[0]
|
||||||
|
dw_channels2 = dw_channels[1]
|
||||||
|
|
||||||
|
self.conv = ConvModule(
|
||||||
|
in_channels,
|
||||||
|
dw_channels1,
|
||||||
|
3,
|
||||||
|
stride=2,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg)
|
||||||
|
self.dsconv1 = DepthwiseSeparableConvModule(
|
||||||
|
dw_channels1,
|
||||||
|
dw_channels2,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
norm_cfg=self.norm_cfg)
|
||||||
|
self.dsconv2 = DepthwiseSeparableConvModule(
|
||||||
|
dw_channels2,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
norm_cfg=self.norm_cfg)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.dsconv1(x)
|
||||||
|
x = self.dsconv2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalFeatureExtractor(nn.Module):
|
||||||
|
"""Global feature extractor module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input channels of the GFE module.
|
||||||
|
Default: 64
|
||||||
|
block_channels (tuple[int]): Tuple of ints. Each int specifies the
|
||||||
|
number of output channels of each Inverted Residual module.
|
||||||
|
Default: (64, 96, 128)
|
||||||
|
out_channels(int): Number of output channels of the GFE module.
|
||||||
|
Default: 128
|
||||||
|
expand_ratio (int): Adjusts number of channels of the hidden layer
|
||||||
|
in InvertedResidual by this amount.
|
||||||
|
Default: 6
|
||||||
|
num_blocks (tuple[int]): Tuple of ints. Each int specifies the
|
||||||
|
number of times each Inverted Residual module is repeated.
|
||||||
|
The repeated Inverted Residual modules are called a 'group'.
|
||||||
|
Default: (3, 3, 3)
|
||||||
|
strides (tuple[int]): Tuple of ints. Each int specifies
|
||||||
|
the downsampling factor of each 'group'.
|
||||||
|
Default: (2, 2, 1)
|
||||||
|
pool_scales (tuple[int]): Tuple of ints. Each int specifies
|
||||||
|
the parameter required in 'global average pooling' within PPM.
|
||||||
|
Default: (1, 2, 3, 6)
|
||||||
|
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||||
|
norm_cfg (dict | None): Config of norm layers. Default:
|
||||||
|
dict(type='BN')
|
||||||
|
act_cfg (dict): Config of activation layers. Default:
|
||||||
|
dict(type='ReLU')
|
||||||
|
align_corners (bool): align_corners argument of F.interpolate.
|
||||||
|
Default: False
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels=64,
|
||||||
|
block_channels=(64, 96, 128),
|
||||||
|
out_channels=128,
|
||||||
|
expand_ratio=6,
|
||||||
|
num_blocks=(3, 3, 3),
|
||||||
|
strides=(2, 2, 1),
|
||||||
|
pool_scales=(1, 2, 3, 6),
|
||||||
|
conv_cfg=None,
|
||||||
|
norm_cfg=dict(type='BN'),
|
||||||
|
act_cfg=dict(type='ReLU'),
|
||||||
|
align_corners=False):
|
||||||
|
super(GlobalFeatureExtractor, self).__init__()
|
||||||
|
self.conv_cfg = conv_cfg
|
||||||
|
self.norm_cfg = norm_cfg
|
||||||
|
self.act_cfg = act_cfg
|
||||||
|
assert len(block_channels) == len(num_blocks) == 3
|
||||||
|
self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
|
||||||
|
num_blocks[0], strides[0],
|
||||||
|
expand_ratio)
|
||||||
|
self.bottleneck2 = self._make_layer(block_channels[0],
|
||||||
|
block_channels[1], num_blocks[1],
|
||||||
|
strides[1], expand_ratio)
|
||||||
|
self.bottleneck3 = self._make_layer(block_channels[1],
|
||||||
|
block_channels[2], num_blocks[2],
|
||||||
|
strides[2], expand_ratio)
|
||||||
|
self.ppm = PPM(
|
||||||
|
pool_scales,
|
||||||
|
block_channels[2],
|
||||||
|
block_channels[2] // 4,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg,
|
||||||
|
align_corners=align_corners)
|
||||||
|
self.out = ConvModule(
|
||||||
|
block_channels[2] * 2,
|
||||||
|
out_channels,
|
||||||
|
1,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg)
|
||||||
|
|
||||||
|
def _make_layer(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
blocks,
|
||||||
|
stride=1,
|
||||||
|
expand_ratio=6):
|
||||||
|
layers = [
|
||||||
|
InvertedResidual(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
stride,
|
||||||
|
expand_ratio,
|
||||||
|
norm_cfg=self.norm_cfg)
|
||||||
|
]
|
||||||
|
for i in range(1, blocks):
|
||||||
|
layers.append(
|
||||||
|
InvertedResidual(
|
||||||
|
out_channels,
|
||||||
|
out_channels,
|
||||||
|
1,
|
||||||
|
expand_ratio,
|
||||||
|
norm_cfg=self.norm_cfg))
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.bottleneck1(x)
|
||||||
|
x = self.bottleneck2(x)
|
||||||
|
x = self.bottleneck3(x)
|
||||||
|
x = torch.cat([x, *self.ppm(x)], dim=1)
|
||||||
|
x = self.out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureFusionModule(nn.Module):
|
||||||
|
"""Feature fusion module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
higher_in_channels (int): Number of input channels of the
|
||||||
|
higher-resolution branch.
|
||||||
|
lower_in_channels (int): Number of input channels of the
|
||||||
|
lower-resolution branch.
|
||||||
|
out_channels (int): Number of output channels.
|
||||||
|
scale_factor (int): Scale factor applied to the lower-res input.
|
||||||
|
Should be coherent with the downsampling factor determined
|
||||||
|
by the GFE module.
|
||||||
|
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||||
|
norm_cfg (dict | None): Config of norm layers. Default:
|
||||||
|
dict(type='BN')
|
||||||
|
act_cfg (dict): Config of activation layers. Default:
|
||||||
|
dict(type='ReLU')
|
||||||
|
align_corners (bool): align_corners argument of F.interpolate.
|
||||||
|
Default: False
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
higher_in_channels,
|
||||||
|
lower_in_channels,
|
||||||
|
out_channels,
|
||||||
|
scale_factor,
|
||||||
|
conv_cfg=None,
|
||||||
|
norm_cfg=dict(type='BN'),
|
||||||
|
act_cfg=dict(type='ReLU'),
|
||||||
|
align_corners=False):
|
||||||
|
super(FeatureFusionModule, self).__init__()
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
self.conv_cfg = conv_cfg
|
||||||
|
self.norm_cfg = norm_cfg
|
||||||
|
self.act_cfg = act_cfg
|
||||||
|
self.align_corners = align_corners
|
||||||
|
self.dwconv = ConvModule(
|
||||||
|
lower_in_channels,
|
||||||
|
out_channels,
|
||||||
|
1,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg)
|
||||||
|
self.conv_lower_res = ConvModule(
|
||||||
|
out_channels,
|
||||||
|
out_channels,
|
||||||
|
1,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=None)
|
||||||
|
self.conv_higher_res = ConvModule(
|
||||||
|
higher_in_channels,
|
||||||
|
out_channels,
|
||||||
|
1,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=None)
|
||||||
|
self.relu = nn.ReLU(True)
|
||||||
|
|
||||||
|
def forward(self, higher_res_feature, lower_res_feature):
|
||||||
|
lower_res_feature = resize(
|
||||||
|
lower_res_feature,
|
||||||
|
scale_factor=self.scale_factor,
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=self.align_corners)
|
||||||
|
lower_res_feature = self.dwconv(lower_res_feature)
|
||||||
|
lower_res_feature = self.conv_lower_res(lower_res_feature)
|
||||||
|
|
||||||
|
higher_res_feature = self.conv_higher_res(higher_res_feature)
|
||||||
|
out = higher_res_feature + lower_res_feature
|
||||||
|
return self.relu(out)
|
||||||
|
|
||||||
|
|
||||||
|
@BACKBONES.register_module()
|
||||||
|
class FastSCNN(nn.Module):
|
||||||
|
"""Fast-SCNN Backbone.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input image channels. Default: 3.
|
||||||
|
downsample_dw_channels (tuple[int]): Number of output channels after
|
||||||
|
the first conv layer & the second conv layer in
|
||||||
|
Learning-To-Downsample (LTD) module.
|
||||||
|
Default: (32, 48).
|
||||||
|
global_in_channels (int): Number of input channels of
|
||||||
|
Global Feature Extractor(GFE).
|
||||||
|
Equal to number of output channels of LTD.
|
||||||
|
Default: 64.
|
||||||
|
global_block_channels (tuple[int]): Tuple of integers that describe
|
||||||
|
the output channels for each of the MobileNet-v2 bottleneck
|
||||||
|
residual blocks in GFE.
|
||||||
|
Default: (64, 96, 128).
|
||||||
|
global_block_strides (tuple[int]): Tuple of integers
|
||||||
|
that describe the strides (downsampling factors) for each of the
|
||||||
|
MobileNet-v2 bottleneck residual blocks in GFE.
|
||||||
|
Default: (2, 2, 1).
|
||||||
|
global_out_channels (int): Number of output channels of GFE.
|
||||||
|
Default: 128.
|
||||||
|
higher_in_channels (int): Number of input channels of the higher
|
||||||
|
resolution branch in FFM.
|
||||||
|
Equal to global_in_channels.
|
||||||
|
Default: 64.
|
||||||
|
lower_in_channels (int): Number of input channels of the lower
|
||||||
|
resolution branch in FFM.
|
||||||
|
Equal to global_out_channels.
|
||||||
|
Default: 128.
|
||||||
|
fusion_out_channels (int): Number of output channels of FFM.
|
||||||
|
Default: 128.
|
||||||
|
out_indices (tuple): Tuple of indices of list
|
||||||
|
[higher_res_features, lower_res_features, fusion_output].
|
||||||
|
Often set to (0,1,2) to enable aux. heads.
|
||||||
|
Default: (0, 1, 2).
|
||||||
|
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||||
|
norm_cfg (dict | None): Config of norm layers. Default:
|
||||||
|
dict(type='BN')
|
||||||
|
act_cfg (dict): Config of activation layers. Default:
|
||||||
|
dict(type='ReLU')
|
||||||
|
align_corners (bool): align_corners argument of F.interpolate.
|
||||||
|
Default: False
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels=3,
|
||||||
|
downsample_dw_channels=(32, 48),
|
||||||
|
global_in_channels=64,
|
||||||
|
global_block_channels=(64, 96, 128),
|
||||||
|
global_block_strides=(2, 2, 1),
|
||||||
|
global_out_channels=128,
|
||||||
|
higher_in_channels=64,
|
||||||
|
lower_in_channels=128,
|
||||||
|
fusion_out_channels=128,
|
||||||
|
out_indices=(0, 1, 2),
|
||||||
|
conv_cfg=None,
|
||||||
|
norm_cfg=dict(type='BN'),
|
||||||
|
act_cfg=dict(type='ReLU'),
|
||||||
|
align_corners=False):
|
||||||
|
|
||||||
|
super(FastSCNN, self).__init__()
|
||||||
|
if global_in_channels != higher_in_channels:
|
||||||
|
raise AssertionError('Global Input Channels must be the same \
|
||||||
|
with Higher Input Channels!')
|
||||||
|
elif global_out_channels != lower_in_channels:
|
||||||
|
raise AssertionError('Global Output Channels must be the same \
|
||||||
|
with Lower Input Channels!')
|
||||||
|
|
||||||
|
# Calculate scale factor used in FFM.
|
||||||
|
self.scale_factor = 1
|
||||||
|
for factor in global_block_strides:
|
||||||
|
self.scale_factor *= factor
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.downsample_dw_channels1 = downsample_dw_channels[0]
|
||||||
|
self.downsample_dw_channels2 = downsample_dw_channels[1]
|
||||||
|
self.global_in_channels = global_in_channels
|
||||||
|
self.global_block_channels = global_block_channels
|
||||||
|
self.global_block_strides = global_block_strides
|
||||||
|
self.global_out_channels = global_out_channels
|
||||||
|
self.higher_in_channels = higher_in_channels
|
||||||
|
self.lower_in_channels = lower_in_channels
|
||||||
|
self.fusion_out_channels = fusion_out_channels
|
||||||
|
self.out_indices = out_indices
|
||||||
|
self.conv_cfg = conv_cfg
|
||||||
|
self.norm_cfg = norm_cfg
|
||||||
|
self.act_cfg = act_cfg
|
||||||
|
self.align_corners = align_corners
|
||||||
|
self.learning_to_downsample = LearningToDownsample(
|
||||||
|
in_channels,
|
||||||
|
downsample_dw_channels,
|
||||||
|
global_in_channels,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg)
|
||||||
|
self.global_feature_extractor = GlobalFeatureExtractor(
|
||||||
|
global_in_channels,
|
||||||
|
global_block_channels,
|
||||||
|
global_out_channels,
|
||||||
|
strides=self.global_block_strides,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg,
|
||||||
|
align_corners=self.align_corners)
|
||||||
|
self.feature_fusion = FeatureFusionModule(
|
||||||
|
higher_in_channels,
|
||||||
|
lower_in_channels,
|
||||||
|
fusion_out_channels,
|
||||||
|
scale_factor=self.scale_factor,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg,
|
||||||
|
align_corners=self.align_corners)
|
||||||
|
|
||||||
|
def init_weights(self, pretrained=None):
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
kaiming_init(m)
|
||||||
|
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||||
|
constant_init(m, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
higher_res_features = self.learning_to_downsample(x)
|
||||||
|
lower_res_features = self.global_feature_extractor(higher_res_features)
|
||||||
|
fusion_output = self.feature_fusion(higher_res_features,
|
||||||
|
lower_res_features)
|
||||||
|
|
||||||
|
outs = [higher_res_features, lower_res_features, fusion_output]
|
||||||
|
outs = [outs[i] for i in self.out_indices]
|
||||||
|
return tuple(outs)
|
@ -10,10 +10,11 @@ from .ocr_head import OCRHead
|
|||||||
from .psa_head import PSAHead
|
from .psa_head import PSAHead
|
||||||
from .psp_head import PSPHead
|
from .psp_head import PSPHead
|
||||||
from .sep_aspp_head import DepthwiseSeparableASPPHead
|
from .sep_aspp_head import DepthwiseSeparableASPPHead
|
||||||
|
from .sep_fcn_head import DepthwiseSeparableFCNHead
|
||||||
from .uper_head import UPerHead
|
from .uper_head import UPerHead
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
|
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
|
||||||
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
|
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
|
||||||
'EncHead'
|
'EncHead', 'DepthwiseSeparableFCNHead'
|
||||||
]
|
]
|
||||||
|
@ -27,6 +27,7 @@ class FCNHead(BaseDecodeHead):
|
|||||||
assert num_convs > 0
|
assert num_convs > 0
|
||||||
self.num_convs = num_convs
|
self.num_convs = num_convs
|
||||||
self.concat_input = concat_input
|
self.concat_input = concat_input
|
||||||
|
self.kernel_size = kernel_size
|
||||||
super(FCNHead, self).__init__(**kwargs)
|
super(FCNHead, self).__init__(**kwargs)
|
||||||
convs = []
|
convs = []
|
||||||
convs.append(
|
convs.append(
|
||||||
|
50
mmseg/models/decode_heads/sep_fcn_head.py
Normal file
50
mmseg/models/decode_heads/sep_fcn_head.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from mmseg.ops import DepthwiseSeparableConvModule
|
||||||
|
from ..builder import HEADS
|
||||||
|
from .fcn_head import FCNHead
|
||||||
|
|
||||||
|
|
||||||
|
@HEADS.register_module()
|
||||||
|
class DepthwiseSeparableFCNHead(FCNHead):
|
||||||
|
"""Depthwise-Separable Fully Convolutional Network for Semantic
|
||||||
|
Segmentation.
|
||||||
|
|
||||||
|
This head is implemented according to Fast-SCNN paper.
|
||||||
|
Args:
|
||||||
|
in_channels(int): Number of output channels of FFM.
|
||||||
|
channels(int): Number of middle-stage channels in the decode head.
|
||||||
|
concat_input(bool): Whether to concatenate original decode input into
|
||||||
|
the result of several consecutive convolution layers.
|
||||||
|
Default: True.
|
||||||
|
num_classes(int): Used to determine the dimension of
|
||||||
|
final prediction tensor.
|
||||||
|
in_index(int): Correspond with 'out_indices' in FastSCNN backbone.
|
||||||
|
norm_cfg (dict | None): Config of norm layers.
|
||||||
|
align_corners (bool): align_corners argument of F.interpolate.
|
||||||
|
Default: False.
|
||||||
|
loss_decode(dict): Config of loss type and some
|
||||||
|
relevant additional options.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(DepthwiseSeparableFCNHead, self).__init__(**kwargs)
|
||||||
|
self.convs[0] = DepthwiseSeparableConvModule(
|
||||||
|
self.in_channels,
|
||||||
|
self.channels,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
padding=self.kernel_size // 2,
|
||||||
|
norm_cfg=self.norm_cfg)
|
||||||
|
for i in range(1, self.num_convs):
|
||||||
|
self.convs[i] = DepthwiseSeparableConvModule(
|
||||||
|
self.channels,
|
||||||
|
self.channels,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
padding=self.kernel_size // 2,
|
||||||
|
norm_cfg=self.norm_cfg)
|
||||||
|
|
||||||
|
if self.concat_input:
|
||||||
|
self.conv_cat = DepthwiseSeparableConvModule(
|
||||||
|
self.in_channels + self.channels,
|
||||||
|
self.channels,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
padding=self.kernel_size // 2,
|
||||||
|
norm_cfg=self.norm_cfg)
|
@ -1,7 +1,5 @@
|
|||||||
from .collect_env import collect_env
|
from .collect_env import collect_env
|
||||||
|
from .inverted_residual_module import InvertedResidual
|
||||||
from .logger import get_root_logger
|
from .logger import get_root_logger
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ['get_root_logger', 'collect_env', 'InvertedResidual']
|
||||||
'get_root_logger',
|
|
||||||
'collect_env',
|
|
||||||
]
|
|
||||||
|
73
mmseg/utils/inverted_residual_module.py
Normal file
73
mmseg/utils/inverted_residual_module.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
from mmcv.cnn import ConvModule, build_norm_layer
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class InvertedResidual(nn.Module):
|
||||||
|
"""Inverted residual module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): The input channels of the InvertedResidual block.
|
||||||
|
out_channels (int): The output channels of the InvertedResidual block.
|
||||||
|
stride (int): Stride of the middle (first) 3x3 convolution.
|
||||||
|
expand_ratio (int): adjusts number of channels of the hidden layer
|
||||||
|
in InvertedResidual by this amount.
|
||||||
|
conv_cfg (dict): Config dict for convolution layer.
|
||||||
|
Default: None, which means using conv2d.
|
||||||
|
norm_cfg (dict): Config dict for normalization layer.
|
||||||
|
Default: dict(type='BN').
|
||||||
|
act_cfg (dict): Config dict for activation layer.
|
||||||
|
Default: dict(type='ReLU6').
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
stride,
|
||||||
|
expand_ratio,
|
||||||
|
dilation=1,
|
||||||
|
conv_cfg=None,
|
||||||
|
norm_cfg=dict(type='BN'),
|
||||||
|
act_cfg=dict(type='ReLU6')):
|
||||||
|
super(InvertedResidual, self).__init__()
|
||||||
|
self.stride = stride
|
||||||
|
assert stride in [1, 2]
|
||||||
|
|
||||||
|
hidden_dim = int(round(in_channels * expand_ratio))
|
||||||
|
self.use_res_connect = self.stride == 1 \
|
||||||
|
and in_channels == out_channels
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
if expand_ratio != 1:
|
||||||
|
# pw
|
||||||
|
layers.append(
|
||||||
|
ConvModule(
|
||||||
|
in_channels,
|
||||||
|
hidden_dim,
|
||||||
|
kernel_size=1,
|
||||||
|
conv_cfg=conv_cfg,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
act_cfg=act_cfg))
|
||||||
|
layers.extend([
|
||||||
|
# dw
|
||||||
|
ConvModule(
|
||||||
|
hidden_dim,
|
||||||
|
hidden_dim,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=dilation,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=hidden_dim,
|
||||||
|
conv_cfg=conv_cfg,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
act_cfg=act_cfg),
|
||||||
|
# pw-linear
|
||||||
|
nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias=False),
|
||||||
|
build_norm_layer(norm_cfg, out_channels)[1],
|
||||||
|
])
|
||||||
|
self.conv = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.use_res_connect:
|
||||||
|
return x + self.conv(x)
|
||||||
|
else:
|
||||||
|
return self.conv(x)
|
@ -4,7 +4,8 @@ from mmcv.ops import DeformConv2dPack
|
|||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
from torch.nn.modules import AvgPool2d, GroupNorm
|
from torch.nn.modules import AvgPool2d, GroupNorm
|
||||||
|
|
||||||
from mmseg.models.backbones import ResNeSt, ResNet, ResNetV1d, ResNeXt
|
from mmseg.models.backbones import (FastSCNN, ResNeSt, ResNet, ResNetV1d,
|
||||||
|
ResNeXt)
|
||||||
from mmseg.models.backbones.resnest import Bottleneck as BottleneckS
|
from mmseg.models.backbones.resnest import Bottleneck as BottleneckS
|
||||||
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
|
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
|
||||||
from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
|
from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
|
||||||
@ -48,7 +49,6 @@ def check_norm_state(modules, train_state):
|
|||||||
|
|
||||||
|
|
||||||
def test_resnet_basic_block():
|
def test_resnet_basic_block():
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# Not implemented yet.
|
# Not implemented yet.
|
||||||
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
||||||
@ -98,7 +98,6 @@ def test_resnet_basic_block():
|
|||||||
|
|
||||||
|
|
||||||
def test_resnet_bottleneck():
|
def test_resnet_bottleneck():
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# Style must be in ['pytorch', 'caffe']
|
# Style must be in ['pytorch', 'caffe']
|
||||||
Bottleneck(64, 64, style='tensorflow')
|
Bottleneck(64, 64, style='tensorflow')
|
||||||
@ -667,6 +666,33 @@ def test_resnext_backbone():
|
|||||||
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
|
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
|
||||||
|
|
||||||
|
|
||||||
|
def test_fastscnn_backbone():
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
# Fast-SCNN channel constraints.
|
||||||
|
FastSCNN(
|
||||||
|
3, (32, 48),
|
||||||
|
64, (64, 96, 128), (2, 2, 1),
|
||||||
|
global_out_channels=127,
|
||||||
|
higher_in_channels=64,
|
||||||
|
lower_in_channels=128)
|
||||||
|
|
||||||
|
# Test FastSCNN Standard Forward
|
||||||
|
model = FastSCNN()
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
batch_size = 4
|
||||||
|
imgs = torch.randn(batch_size, 3, 512, 1024)
|
||||||
|
feat = model(imgs)
|
||||||
|
|
||||||
|
assert len(feat) == 3
|
||||||
|
# higher-res
|
||||||
|
assert feat[0].shape == torch.Size([batch_size, 64, 64, 128])
|
||||||
|
# lower-res
|
||||||
|
assert feat[1].shape == torch.Size([batch_size, 128, 16, 32])
|
||||||
|
# FFM output
|
||||||
|
assert feat[2].shape == torch.Size([batch_size, 128, 64, 128])
|
||||||
|
|
||||||
|
|
||||||
def test_resnest_bottleneck():
|
def test_resnest_bottleneck():
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# Style must be in ['pytorch', 'caffe']
|
# Style must be in ['pytorch', 'caffe']
|
||||||
|
@ -6,7 +6,8 @@ from mmcv.cnn import ConvModule
|
|||||||
from mmcv.utils.parrots_wrapper import SyncBatchNorm
|
from mmcv.utils.parrots_wrapper import SyncBatchNorm
|
||||||
|
|
||||||
from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
|
from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
|
||||||
DepthwiseSeparableASPPHead, EncHead,
|
DepthwiseSeparableASPPHead,
|
||||||
|
DepthwiseSeparableFCNHead, EncHead,
|
||||||
FCNHead, GCHead, NLHead, OCRHead,
|
FCNHead, GCHead, NLHead, OCRHead,
|
||||||
PSAHead, PSPHead, UPerHead)
|
PSAHead, PSPHead, UPerHead)
|
||||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||||
@ -539,3 +540,37 @@ def test_dw_aspp_head():
|
|||||||
assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
|
assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
|
||||||
outputs = head(inputs)
|
outputs = head(inputs)
|
||||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sep_fcn_head():
|
||||||
|
# test sep_fcn_head with concat_input=False
|
||||||
|
head = DepthwiseSeparableFCNHead(
|
||||||
|
in_channels=128,
|
||||||
|
channels=128,
|
||||||
|
concat_input=False,
|
||||||
|
num_classes=19,
|
||||||
|
in_index=-1,
|
||||||
|
norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01))
|
||||||
|
x = [torch.rand(2, 128, 32, 32)]
|
||||||
|
output = head(x)
|
||||||
|
assert output.shape == (2, head.num_classes, 32, 32)
|
||||||
|
assert not head.concat_input
|
||||||
|
from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule
|
||||||
|
assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
|
||||||
|
assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
|
||||||
|
assert head.conv_seg.kernel_size == (1, 1)
|
||||||
|
|
||||||
|
head = DepthwiseSeparableFCNHead(
|
||||||
|
in_channels=64,
|
||||||
|
channels=64,
|
||||||
|
concat_input=True,
|
||||||
|
num_classes=19,
|
||||||
|
in_index=-1,
|
||||||
|
norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01))
|
||||||
|
x = [torch.rand(3, 64, 32, 32)]
|
||||||
|
output = head(x)
|
||||||
|
assert output.shape == (3, head.num_classes, 32, 32)
|
||||||
|
assert head.concat_input
|
||||||
|
from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule
|
||||||
|
assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
|
||||||
|
assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
|
||||||
|
40
tests/test_utils/test_inverted_residual_module.py
Normal file
40
tests/test_utils/test_inverted_residual_module.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmseg.utils import InvertedResidual
|
||||||
|
|
||||||
|
|
||||||
|
def test_inv_residual():
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
# test stride assertion.
|
||||||
|
InvertedResidual(32, 32, 3, 4)
|
||||||
|
|
||||||
|
# test default config with res connection.
|
||||||
|
# set expand_ratio = 4, stride = 1 and inp=oup.
|
||||||
|
inv_module = InvertedResidual(32, 32, 1, 4)
|
||||||
|
assert inv_module.use_res_connect
|
||||||
|
assert inv_module.conv[0].kernel_size == (1, 1)
|
||||||
|
assert inv_module.conv[0].padding == 0
|
||||||
|
assert inv_module.conv[1].kernel_size == (3, 3)
|
||||||
|
assert inv_module.conv[1].padding == 1
|
||||||
|
assert inv_module.conv[0].with_norm
|
||||||
|
assert inv_module.conv[1].with_norm
|
||||||
|
x = torch.rand(1, 32, 64, 64)
|
||||||
|
output = inv_module(x)
|
||||||
|
assert output.shape == (1, 32, 64, 64)
|
||||||
|
|
||||||
|
# test inv_residual module without res connection.
|
||||||
|
# set expand_ratio = 4, stride = 2.
|
||||||
|
inv_module = InvertedResidual(32, 32, 2, 4)
|
||||||
|
assert not inv_module.use_res_connect
|
||||||
|
assert inv_module.conv[0].kernel_size == (1, 1)
|
||||||
|
x = torch.rand(1, 32, 64, 64)
|
||||||
|
output = inv_module(x)
|
||||||
|
assert output.shape == (1, 32, 32, 32)
|
||||||
|
|
||||||
|
# test expand_ratio == 1
|
||||||
|
inv_module = InvertedResidual(32, 32, 1, 1)
|
||||||
|
assert inv_module.conv[0].kernel_size == (3, 3)
|
||||||
|
x = torch.rand(1, 32, 64, 64)
|
||||||
|
output = inv_module(x)
|
||||||
|
assert output.shape == (1, 32, 64, 64)
|
Loading…
x
Reference in New Issue
Block a user