[New model] Support MobileNetV3 (#268)
* delete markdownlint * Support MobileNetV3 * fix import * add mobilenetv3 head and configs * Modify MobileNetV3 to semantic segmentation version * modify mobilenetv3 configs * add std configs * fix Conv2dAdaptivePadding bug * add configs * add unitest and fix bugs * fix lraspp unitest bugs * restore * fix unitest * add MobileNetV3 docstring * add mmcv * add mmcv * fix syntax bug * fix unitest bug * fix unitest bug * fix unitest bugs * fix docstring * add configs * restore * delete unnecessary assert * modify unitest * delete benchmarkpull/319/head
parent
5dacca3ea8
commit
7fdb4002fa
|
@ -60,6 +60,7 @@ Supported backbones:
|
|||
- [x] [HRNet](configs/hrnet/README.md)
|
||||
- [x] [ResNeSt](configs/resnest/README.md)
|
||||
- [x] [MobileNetV2](configs/mobilenet_v2/README.md)
|
||||
- [x] [MobileNetV3](configs/mobilenet_v3/README.md)
|
||||
|
||||
Supported methods:
|
||||
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
# model settings
# Normalization config shared by the backbone and the decode head below.
norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='MobileNetV3',
        arch='large',
        # Indices of the backbone layers whose outputs are returned.
        out_indices=(1, 3, 16),
        norm_cfg=norm_cfg),
    decode_head=dict(
        type='LRASPPHead',
        # One channel count per selected backbone output, in the same order
        # as `out_indices` above.
        in_channels=(16, 24, 960),
        in_index=(0, 1, 2),
        channels=128,
        # LRASPPHead raises ValueError for any other input_transform.
        input_transform='multiple_select',
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU'),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
# model training and testing settings
train_cfg = dict()
test_cfg = dict(mode='whole')
|
|
@ -0,0 +1,26 @@
|
|||
# Searching for MobileNetV3
|
||||
|
||||
## Introduction
|
||||
|
||||
```latex
|
||||
@inproceedings{Howard_2019_ICCV,
|
||||
title={Searching for MobileNetV3},
|
||||
author={Howard, Andrew and Sandler, Mark and Chu, Grace and Chen, Liang-Chieh and Chen, Bo and Tan, Mingxing and Wang, Weijun and Zhu, Yukun and Pang, Ruoming and Vasudevan, Vijay and Le, Quoc V. and Adam, Hartwig},
|
||||
booktitle={The IEEE International Conference on Computer Vision (ICCV)},
|
||||
pages={1314-1324},
|
||||
month={October},
|
||||
year={2019},
|
||||
doi={10.1109/ICCV.2019.00140}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
### Cityscapes
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|
||||
|------------|----------|-----------|--------:|---------:|----------------|------:|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| LRASPP | M-V3-D8 | 512x1024 | 320000 | 8.9 | 15.22 | 69.54 | 70.89 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_512x1024_320k_cityscapes/lraspp_m-v3-d8_512x1024_320k_cityscapes_20201224_220337-cfe8fb07.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_512x1024_320k_cityscapes/lraspp_m-v3-d8_512x1024_320k_cityscapes-20201224_220337.log.json) |
|
||||
| LRASPP | M-V3-D8 (scratch) | 512x1024 | 320000 | 8.9 | 14.77 | 67.87 | 69.78 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes_20201224_220337-9f29cd72.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes-20201224_220337.log.json) |
|
||||
| LRASPP | M-V3s-D8 | 512x1024 | 320000 | 5.3 | 23.64 | 64.11 | 66.42 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_512x1024_320k_cityscapes/lraspp_m-v3s-d8_512x1024_320k_cityscapes_20201224_223935-61565b34.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_512x1024_320k_cityscapes/lraspp_m-v3s-d8_512x1024_320k_cityscapes-20201224_223935.log.json) |
|
||||
| LRASPP | M-V3s-D8 (scratch) | 512x1024 | 320000 | 5.3 | 24.50 | 62.74 | 65.01 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes_20201224_223935-03daeabb.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes-20201224_223935.log.json) |
|
|
@ -0,0 +1,11 @@
|
|||
# LRASPP + MobileNetV3 (arch set by the base model file) on Cityscapes,
# starting from a pretrained backbone checkpoint.
_base_ = [
    '../_base_/models/lraspp_m-v3-d8.py', '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]

# Backbone weights are loaded from this checkpoint identifier.
model = dict(pretrained='open-mmlab://contrib/mobilenet_v3_large')

# Re-config the data sampler.
data = dict(samples_per_gpu=4, workers_per_gpu=4)

# The base schedule file is named for 160k iterations; this overrides the
# iteration budget to 320k.
runner = dict(type='IterBasedRunner', max_iters=320000)
|
|
@ -0,0 +1,9 @@
|
|||
# Same base files as the pretrained LRASPP/MobileNetV3 config, but no
# `pretrained` key is set in this file.
_base_ = [
    '../_base_/models/lraspp_m-v3-d8.py', '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]

# Re-config the data sampler.
data = dict(samples_per_gpu=4, workers_per_gpu=4)

# The base schedule file is named for 160k iterations; this overrides the
# iteration budget to 320k.
runner = dict(type='IterBasedRunner', max_iters=320000)
|
|
@ -0,0 +1,23 @@
|
|||
# Small-arch variant of the config named in `_base_`: switches the backbone to
# arch='small', points `pretrained` at the small checkpoint and adjusts the
# selected layers / head input channels accordingly.
_base_ = './lraspp_m-v3-d8_512x1024_320k_cityscapes.py'
# Normalization config shared by the backbone and the decode head below.
norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://contrib/mobilenet_v3_small',
    backbone=dict(
        type='MobileNetV3',
        arch='small',
        # Indices of the backbone layers whose outputs are returned.
        out_indices=(0, 1, 12),
        norm_cfg=norm_cfg),
    decode_head=dict(
        type='LRASPPHead',
        # One channel count per selected backbone output, in the same order
        # as `out_indices` above.
        in_channels=(16, 16, 576),
        in_index=(0, 1, 2),
        channels=128,
        # LRASPPHead raises ValueError for any other input_transform.
        input_transform='multiple_select',
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU'),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
|
|
@ -0,0 +1,22 @@
|
|||
# Small-arch variant of the scratch config named in `_base_`: switches the
# backbone to arch='small' and adjusts the selected layers / head input
# channels accordingly. No `pretrained` key is set in this file.
_base_ = './lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes.py'
# Normalization config shared by the backbone and the decode head below.
norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='MobileNetV3',
        arch='small',
        # Indices of the backbone layers whose outputs are returned.
        out_indices=(0, 1, 12),
        norm_cfg=norm_cfg),
    decode_head=dict(
        type='LRASPPHead',
        # One channel count per selected backbone output, in the same order
        # as `out_indices` above.
        in_channels=(16, 16, 576),
        in_index=(0, 1, 2),
        channels=128,
        # LRASPPHead raises ValueError for any other input_transform.
        input_transform='multiple_select',
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU'),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
|
|
@ -111,6 +111,10 @@ Please refer to [PointRend](https://github.com/open-mmlab/mmsegmentation/blob/ma
|
|||
|
||||
Please refer to [MobileNetV2](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/mobilenet_v2) for details.
|
||||
|
||||
### MobileNetV3
|
||||
|
||||
Please refer to [MobileNetV3](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/mobilenet_v3) for details.
|
||||
|
||||
### EMANet
|
||||
|
||||
Please refer to [EMANet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/emanet) for details.
|
||||
|
|
|
@ -2,6 +2,7 @@ from .cgnet import CGNet
|
|||
from .fast_scnn import FastSCNN
|
||||
from .hrnet import HRNet
|
||||
from .mobilenet_v2 import MobileNetV2
|
||||
from .mobilenet_v3 import MobileNetV3
|
||||
from .resnest import ResNeSt
|
||||
from .resnet import ResNet, ResNetV1c, ResNetV1d
|
||||
from .resnext import ResNeXt
|
||||
|
@ -9,5 +10,5 @@ from .unet import UNet
|
|||
|
||||
__all__ = [
|
||||
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
|
||||
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet'
|
||||
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,255 @@
|
|||
import logging
|
||||
|
||||
import mmcv
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, constant_init, kaiming_init
|
||||
from mmcv.cnn.bricks import Conv2dAdaptivePadding
|
||||
from mmcv.runner import load_checkpoint
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import InvertedResidualV3 as InvertedResidual
|
||||
|
||||
|
||||
@BACKBONES.register_module()
class MobileNetV3(nn.Module):
    """MobileNetV3 backbone.

    This backbone is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    Args:
        arch (str): Architecture of MobileNetV3, from {'small', 'large'}.
            Default: 'small'.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        out_indices (tuple[int]): Output from which layer. Every index must
            lie in ``range(0, len(arch_settings[arch]) + 2)``.
            Default: (0, 1, 12).
        frozen_stages (int): Stages to be frozen (all param fixed). Layers
            ``layer0`` .. ``layer{frozen_stages}`` are put in eval mode and
            their parameters stop requiring gradients.
            Default: -1, which means not freezing any parameters.
        reduction_factor (int): Positive integer by which the mid and out
            channels of the trailing blocks are floor-divided (block index
            >= 12 for 'large', >= 8 for 'small'). Default: 1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """
    # Parameters to build each block:
    #     [kernel size, mid channels, out channels, with_se, act type, stride]
    arch_settings = {
        'small': [[3, 16, 16, True, 'ReLU', 2],  # block0 layer1 os=4
                  [3, 72, 24, False, 'ReLU', 2],  # block1 layer2 os=8
                  [3, 88, 24, False, 'ReLU', 1],
                  [5, 96, 40, True, 'HSwish', 2],  # block2 layer4 os=16
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 120, 48, True, 'HSwish', 1],  # block3 layer7 os=16
                  [5, 144, 48, True, 'HSwish', 1],
                  [5, 288, 96, True, 'HSwish', 2],  # block4 layer9 os=32
                  [5, 576, 96, True, 'HSwish', 1],
                  [5, 576, 96, True, 'HSwish', 1]],
        'large': [[3, 16, 16, False, 'ReLU', 1],  # block0 layer1 os=2
                  [3, 64, 24, False, 'ReLU', 2],  # block1 layer2 os=4
                  [3, 72, 24, False, 'ReLU', 1],
                  [5, 72, 40, True, 'ReLU', 2],  # block2 layer4 os=8
                  [5, 120, 40, True, 'ReLU', 1],
                  [5, 120, 40, True, 'ReLU', 1],
                  [3, 240, 80, False, 'HSwish', 2],  # block3 layer7 os=16
                  [3, 200, 80, False, 'HSwish', 1],
                  [3, 184, 80, False, 'HSwish', 1],
                  [3, 184, 80, False, 'HSwish', 1],
                  [3, 480, 112, True, 'HSwish', 1],  # block4 layer11 os=16
                  [3, 672, 112, True, 'HSwish', 1],
                  [5, 672, 160, True, 'HSwish', 2],  # block5 layer13 os=32
                  [5, 960, 160, True, 'HSwish', 1],
                  [5, 960, 160, True, 'HSwish', 1]]
    }  # yapf: disable

    def __init__(self,
                 arch='small',
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 out_indices=(0, 1, 12),
                 frozen_stages=-1,
                 reduction_factor=1,
                 norm_eval=False,
                 with_cp=False):
        super(MobileNetV3, self).__init__()
        assert arch in self.arch_settings
        assert isinstance(reduction_factor, int) and reduction_factor > 0
        assert mmcv.is_tuple_of(out_indices, int)
        # Valid layer indices: the stem (layer0), one layer per block in
        # arch_settings, and the final 1x1 conv layer.
        for index in out_indices:
            if index not in range(0, len(self.arch_settings[arch]) + 2):
                raise ValueError(
                    'the item in out_indices must in '
                    f'range(0, {len(self.arch_settings[arch])+2}). '
                    f'But received {index}')

        if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
            raise ValueError('frozen_stages must be in range(-1, '
                             f'{len(self.arch_settings[arch])+2}). '
                             f'But received {frozen_stages}')
        self.arch = arch
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.reduction_factor = reduction_factor
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.layers = self._make_layer()

    def _make_layer(self):
        """Build all layers, register them as ``layer{i}`` and return names.

        Returns:
            list[str]: Names of the registered sub-modules in forward order.
        """
        layers = []

        # build the first layer (layer0)
        in_channels = 16
        layer = ConvModule(
            in_channels=3,
            out_channels=in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=dict(type='Conv2dAdaptivePadding'),
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='HSwish'))
        self.add_module('layer0', layer)
        layers.append('layer0')

        layer_setting = self.arch_settings[self.arch]
        for i, params in enumerate(layer_setting):
            (kernel_size, mid_channels, out_channels, with_se, act,
             stride) = params

            # Only the trailing blocks are thinned by reduction_factor.
            if (self.arch == 'large' and i >= 12) or \
                    (self.arch == 'small' and i >= 8):
                mid_channels = mid_channels // self.reduction_factor
                out_channels = out_channels // self.reduction_factor

            if with_se:
                se_cfg = dict(
                    channels=mid_channels,
                    ratio=4,
                    act_cfg=(dict(type='ReLU'),
                             dict(type='HSigmoid', bias=3.0, divisor=6.0)))
            else:
                se_cfg = None

            layer = InvertedResidual(
                in_channels=in_channels,
                out_channels=out_channels,
                mid_channels=mid_channels,
                kernel_size=kernel_size,
                stride=stride,
                se_cfg=se_cfg,
                with_expand_conv=(in_channels != mid_channels),
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type=act),
                with_cp=self.with_cp)
            in_channels = out_channels
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, layer)
            layers.append(layer_name)

        # build the last layer
        # block5 layer12 os=32 for small model
        # block6 layer16 os=32 for large model
        layer = ConvModule(
            in_channels=in_channels,
            out_channels=576 if self.arch == 'small' else 960,
            kernel_size=1,
            stride=1,
            dilation=4,
            padding=0,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='HSwish'))
        layer_name = f'layer{len(layer_setting) + 1}'
        self.add_module(layer_name, layer)
        layers.append(layer_name)

        # next, convert backbone MobileNetV3 to a semantic segmentation version
        if self.arch == 'small':
            self._convert_to_dilated(
                layers, strided=('layer4', 'layer9'), start=4, switch=9)
        else:
            self._convert_to_dilated(
                layers, strided=('layer7', 'layer13'), start=7, switch=13)

        return layers

    def _convert_to_dilated(self, layers, strided, start, switch):
        """Replace late-stage striding with dilation, in place.

        Args:
            layers (list[str]): Names of all registered layers.
            strided (tuple[str]): Layers whose depthwise conv stride is reset
                to 1.
            start (int): Index of the first layer that gets dilated.
            switch (int): Layers with index < switch use dilation 2, the
                remaining ones use dilation 4.
        """
        for name in strided:
            getattr(self, name).depthwise_conv.conv.stride = (1, 1)

        for i in range(start, len(layers)):
            layer = getattr(self, layers[i])
            if isinstance(layer, InvertedResidual):
                modified_module = layer.depthwise_conv.conv
            else:
                modified_module = layer.conv

            rate = 2 if i < switch else 4
            modified_module.dilation = (rate, rate)

            # Conv2dAdaptivePadding instances are left with their original
            # padding attribute; plain convs need it scaled to the dilation.
            if not isinstance(modified_module, Conv2dAdaptivePadding):
                # Adjust padding
                pad = rate * ((modified_module.kernel_size[0] - 1) // 2)
                modified_module.padding = (pad, pad)

    def init_weights(self, pretrained=None):
        """Load a checkpoint or initialise conv / norm layers.

        Args:
            pretrained (str, optional): Checkpoint identifier passed to
                ``load_checkpoint``. If None, conv layers get Kaiming init
                and batch-norm layers are set to constant 1.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, _BatchNorm):
                    # _BatchNorm (as in train()) also matches SyncBN, which
                    # is not an nn.BatchNorm2d subclass.
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run all layers in order and collect the selected outputs.

        Returns:
            list: Outputs of the layers whose index is in ``out_indices``.
        """
        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)
        return outs

    def _freeze_stages(self):
        """Put layer0..layer{frozen_stages} in eval mode without grads."""
        for i in range(self.frozen_stages + 1):
            layer = getattr(self, f'layer{i}')
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Switch train/eval mode, then re-apply freezing and norm_eval."""
        super(MobileNetV3, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
|
|
@ -10,6 +10,7 @@ from .enc_head import EncHead
|
|||
from .fcn_head import FCNHead
|
||||
from .fpn_head import FPNHead
|
||||
from .gc_head import GCHead
|
||||
from .lraspp_head import LRASPPHead
|
||||
from .nl_head import NLHead
|
||||
from .ocr_head import OCRHead
|
||||
from .point_head import PointHead
|
||||
|
@ -23,5 +24,5 @@ __all__ = [
|
|||
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
|
||||
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
|
||||
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
|
||||
'PointHead', 'APCHead', 'DMHead'
|
||||
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv import is_tuple_of
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import HEADS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@HEADS.register_module()
class LRASPPHead(BaseDecodeHead):
    """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.

    This head is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    Args:
        branch_channels (tuple[int]): The number of output channels of each
            low-level branch. Its length must be ``len(in_channels) - 1``:
            every input except the last one feeds a branch.
            Default: (32, 64).
    """

    def __init__(self, branch_channels=(32, 64), **kwargs):
        super(LRASPPHead, self).__init__(**kwargs)
        if self.input_transform != 'multiple_select':
            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
                             f'must be \'multiple_select\'. But received '
                             f'\'{self.input_transform}\'')
        assert is_tuple_of(branch_channels, int)
        assert len(branch_channels) == len(self.in_channels) - 1
        self.branch_channels = branch_channels

        # convs[i] projects low-level input i; conv_ups[i] fuses it with the
        # upsampled high-level feature.
        self.convs = nn.Sequential()
        self.conv_ups = nn.Sequential()
        for i, branch_channel in enumerate(branch_channels):
            self.convs.add_module(
                f'conv{i}',
                nn.Conv2d(self.in_channels[i], branch_channel, 1, bias=False))
            self.conv_ups.add_module(
                f'conv_up{i}',
                ConvModule(
                    self.channels + branch_channel,
                    self.channels,
                    1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=False))

        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)

        self.aspp_conv = ConvModule(
            self.in_channels[-1],
            self.channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            bias=False)
        # image_pool is applied to the last input in forward(), so it must be
        # sized from in_channels[-1] (like aspp_conv), not a fixed index.
        self.image_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
            ConvModule(
                self.in_channels[-1],
                self.channels,
                1,
                act_cfg=dict(type='Sigmoid'),
                bias=False))

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        x = inputs[-1]

        # Gate the projected feature with the pooled (sigmoid) branch, which
        # is resized back to the feature's spatial size.
        x = self.aspp_conv(x) * resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.conv_up_input(x)

        # Fuse the low-level branches from the last one back to the first.
        for i in reversed(range(len(self.branch_channels))):
            x = resize(
                x,
                size=inputs[i].size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](inputs[i])], 1)
            x = self.conv_ups[i](x)

        return self.cls_seg(x)
|
|
@ -1,4 +1,4 @@
|
|||
from .inverted_residual import InvertedResidual
|
||||
from .inverted_residual import InvertedResidual, InvertedResidualV3
|
||||
from .make_divisible import make_divisible
|
||||
from .res_layer import ResLayer
|
||||
from .self_attention_block import SelfAttentionBlock
|
||||
|
@ -6,5 +6,5 @@ from .up_conv_block import UpConvBlock
|
|||
|
||||
__all__ = [
|
||||
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
|
||||
'UpConvBlock'
|
||||
'UpConvBlock', 'InvertedResidualV3'
|
||||
]
|
||||
|
|
|
@ -2,6 +2,8 @@ from mmcv.cnn import ConvModule
|
|||
from torch import nn as nn
|
||||
from torch.utils import checkpoint as cp
|
||||
|
||||
from .se_layer import SELayer
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
"""InvertedResidual block for MobileNetV2.
|
||||
|
@ -23,7 +25,7 @@ class InvertedResidual(nn.Module):
|
|||
memory while slowing down the training speed. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor
|
||||
Tensor: The output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -90,3 +92,117 @@ class InvertedResidual(nn.Module):
|
|||
out = _inner_forward(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class InvertedResidualV3(nn.Module):
    """Inverted Residual Block for MobileNetV3.

    The block chains an optional 1x1 expand conv, a depthwise conv, an
    optional SE layer and a 1x1 linear conv without activation. The input is
    added back when ``stride == 1`` and ``in_channels == out_channels``.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        mid_channels (int): The input channels of the depthwise convolution.
        kernel_size (int): The kernel size of the depthwise convolution.
            Default: 3.
        stride (int): The stride of the depthwise convolution, 1 or 2.
            Default: 1.
        se_cfg (dict): Config dict for se layer. Default: None, which means no
            se layer.
        with_expand_conv (bool): Use expand conv or not. If set False,
            mid_channels must be the same with in_channels. Default: True.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 stride=1,
                 se_cfg=None,
                 with_expand_conv=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super(InvertedResidualV3, self).__init__()
        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
        assert stride in [1, 2]
        self.with_cp = with_cp
        self.with_se = se_cfg is not None
        self.with_expand_conv = with_expand_conv

        if self.with_se:
            assert isinstance(se_cfg, dict)

        if self.with_expand_conv:
            # 1x1 conv lifting the input to mid_channels.
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            # Without expansion the depthwise conv consumes the input as is.
            assert mid_channels == in_channels

        # Strided depthwise convs always use the adaptive-padding conv type;
        # otherwise the caller-provided conv_cfg is used.
        if stride == 2:
            depthwise_conv_cfg = dict(type='Conv2dAdaptivePadding')
        else:
            depthwise_conv_cfg = conv_cfg
        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=mid_channels,
            conv_cfg=depthwise_conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        if self.with_se:
            self.se = SELayer(**se_cfg)

        # Projection back to out_channels, deliberately without activation.
        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, x):

        def _run_block(inp):
            feat = self.expand_conv(inp) if self.with_expand_conv else inp
            feat = self.depthwise_conv(feat)
            if self.with_se:
                feat = self.se(feat)
            feat = self.linear_conv(feat)
            return inp + feat if self.with_res_shortcut else feat

        # Checkpointing is only used when enabled and the input needs grads.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_run_block, x)
        return _run_block(x)
|
||||
|
|
|
@ -1,18 +1,21 @@
|
|||
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
|
||||
"""Make divisible function.
|
||||
|
||||
This function rounds the channel number down to the nearest value that can
|
||||
be divisible by the divisor.
|
||||
This function rounds the channel number to the nearest value that can be
|
||||
divisible by the divisor. It is taken from the original tf repo. It ensures
|
||||
that all layers have a channel number that is divisible by divisor. It can
|
||||
be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa
|
||||
|
||||
Args:
|
||||
value (int): The original channel number.
|
||||
divisor (int): The divisor to fully divide the channel number.
|
||||
min_value (int, optional): The minimum value of the output channel.
|
||||
min_value (int): The minimum value of the output channel.
|
||||
Default: None, means that the minimum value equal to the divisor.
|
||||
min_ratio (float, optional): The minimum ratio of the rounded channel
|
||||
number to the original channel number. Default: 0.9.
|
||||
min_ratio (float): The minimum ratio of the rounded channel number to
|
||||
the original channel number. Default: 0.9.
|
||||
|
||||
Returns:
|
||||
int: The modified output channel number
|
||||
int: The modified output channel number.
|
||||
"""
|
||||
|
||||
if min_value is None:
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
import mmcv
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from .make_divisible import make_divisible
|
||||
|
||||
|
||||
class SELayer(nn.Module):
    """Squeeze-and-Excitation Module.

    Args:
        channels (int): The input (and output) channels of the SE layer.
        ratio (int): Squeeze ratio in SELayer, the intermediate channel will be
            ``make_divisible(channels // ratio, 8)``. Default: 16.
        conv_cfg (None or dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        act_cfg (dict or tuple[dict]): Config dict for activation layer.
            If act_cfg is a dict, two activation layers will be configured
            by this dict. If act_cfg is a tuple of two dicts, the first
            activation layer will be configured by the first dict and the
            second activation layer will be configured by the second dict.
            Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
            divisor=6.0)).
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 act_cfg=(dict(type='ReLU'),
                          dict(type='HSigmoid', bias=3.0, divisor=6.0))):
        super(SELayer, self).__init__()
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)
        # Width of the squeezed representation, shared by both convs.
        squeeze_channels = make_divisible(channels // ratio, 8)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=squeeze_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=squeeze_channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        """Scale ``x`` by gates computed from its globally pooled features."""
        out = self.global_avgpool(x)
        out = self.conv1(out)
        out = self.conv2(out)
        return x * out
|
|
@ -4,8 +4,8 @@ from mmcv.ops import DeformConv2dPack
|
|||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from torch.nn.modules import AvgPool2d, GroupNorm
|
||||
|
||||
from mmseg.models.backbones import (CGNet, FastSCNN, ResNeSt, ResNet,
|
||||
ResNetV1d, ResNeXt)
|
||||
from mmseg.models.backbones import (CGNet, FastSCNN, MobileNetV3, ResNeSt,
|
||||
ResNet, ResNetV1d, ResNeXt)
|
||||
from mmseg.models.backbones.cgnet import (ContextGuidedBlock,
|
||||
GlobalContextExtractor)
|
||||
from mmseg.models.backbones.resnest import Bottleneck as BottleneckS
|
||||
|
@ -875,3 +875,65 @@ def test_cgnet_backbone():
|
|||
assert feat[0].shape == torch.Size([2, 35, 112, 112])
|
||||
assert feat[1].shape == torch.Size([2, 131, 56, 56])
|
||||
assert feat[2].shape == torch.Size([2, 256, 28, 28])
|
||||
|
||||
|
||||
def test_mobilenet_v3():
    """Check MobileNetV3 argument validation and output feature shapes."""

    def _check_forward(net, expected_shapes):
        # Init weights, switch to train mode, run a 224x224 batch of two
        # images and compare the three returned feature shapes.
        net.init_weights()
        net.train()

        imgs = torch.randn(2, 3, 224, 224)
        feat = net(imgs)
        assert len(feat) == 3
        assert feat[0].shape == expected_shapes[0]
        assert feat[1].shape == expected_shapes[1]
        assert feat[2].shape == expected_shapes[2]

    with pytest.raises(AssertionError):
        # check invalid arch
        MobileNetV3('big')

    with pytest.raises(AssertionError):
        # check invalid reduction_factor
        MobileNetV3(reduction_factor=0)

    with pytest.raises(ValueError):
        # check invalid out_indices
        MobileNetV3(out_indices=(0, 1, 15))

    with pytest.raises(ValueError):
        # check invalid frozen_stages
        MobileNetV3(frozen_stages=15)

    with pytest.raises(TypeError):
        # check invalid pretrained
        model = MobileNetV3()
        model.init_weights(pretrained=8)

    # Test MobileNetV3 with default settings
    small_shapes = ((2, 16, 112, 112), (2, 16, 56, 56), (2, 576, 28, 28))
    _check_forward(MobileNetV3(), small_shapes)

    # Test MobileNetV3 with arch = 'large'
    large_shapes = ((2, 16, 112, 112), (2, 24, 56, 56), (2, 960, 28, 28))
    _check_forward(
        MobileNetV3(arch='large', out_indices=(1, 3, 16)), large_shapes)

    # Test MobileNetV3 with norm_eval True, with_cp True and frozen_stages=5
    model = MobileNetV3(norm_eval=True, with_cp=True, frozen_stages=5)
    with pytest.raises(TypeError):
        # check invalid pretrained
        model.init_weights(pretrained=8)
    _check_forward(model, small_shapes)
|
||||
|
|
|
@ -10,8 +10,8 @@ from mmseg.models.decode_heads import (ANNHead, APCHead, ASPPHead, CCHead,
|
|||
DAHead, DepthwiseSeparableASPPHead,
|
||||
DepthwiseSeparableFCNHead, DMHead,
|
||||
DNLHead, EMAHead, EncHead, FCNHead,
|
||||
GCHead, NLHead, OCRHead, PointHead,
|
||||
PSAHead, PSPHead, UPerHead)
|
||||
GCHead, LRASPPHead, NLHead, OCRHead,
|
||||
PointHead, PSAHead, PSPHead, UPerHead)
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -769,3 +769,66 @@ def test_point_head():
|
|||
subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
|
||||
output = point_head.forward_test(inputs, prev_output, None, test_cfg)
|
||||
assert output.shape == (1, point_head.num_classes, 180, 180)
|
||||
|
||||
|
||||
def test_lraspp_head():
    """Check LRASPPHead argument validation and its output shape."""

    def head_kwargs(**overrides):
        # Fresh config dicts on every call so no instance shares them.
        kwargs = dict(
            in_channels=(16, 16, 576),
            in_index=(0, 1, 2),
            channels=128,
            input_transform='multiple_select',
            dropout_ratio=0.1,
            num_classes=19,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'),
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
        kwargs.update(overrides)
        return kwargs

    # check invalid input_transform
    with pytest.raises(ValueError):
        LRASPPHead(**head_kwargs(input_transform='resize_concat'))

    # check invalid branch_channels
    with pytest.raises(AssertionError):
        LRASPPHead(**head_kwargs(branch_channels=64))

    # test with default settings
    lraspp_head = LRASPPHead(**head_kwargs())

    # check invalid inputs: these spatial sizes are expected to be rejected
    bad_sizes = ((16, 45), (16, 28), (576, 14))
    inputs = [torch.randn(2, chans, side, side) for chans, side in bad_sizes]
    with pytest.raises(RuntimeError):
        lraspp_head(inputs)

    good_sizes = ((16, 111), (16, 77), (576, 55))
    inputs = [torch.randn(2, chans, side, side) for chans, side in good_sizes]
    output = lraspp_head(inputs)
    assert output.shape == (2, 19, 111, 111)
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import mmcv
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.utils import InvertedResidual
|
||||
from mmseg.models.utils import InvertedResidual, InvertedResidualV3
|
||||
|
||||
|
||||
def test_inv_residual():
|
||||
|
@ -38,3 +39,82 @@ def test_inv_residual():
|
|||
x = torch.rand(1, 32, 64, 64)
|
||||
output = inv_module(x)
|
||||
assert output.shape == (1, 32, 64, 64)
|
||||
|
||||
# test with checkpoint forward
|
||||
inv_module = InvertedResidual(32, 32, 1, 1, with_cp=True)
|
||||
assert inv_module.with_cp
|
||||
x = torch.rand(1, 32, 64, 64, requires_grad=True)
|
||||
output = inv_module(x)
|
||||
assert output.shape == (1, 32, 64, 64)
|
||||
|
||||
|
||||
def test_inv_residualv3():
    """Check InvertedResidualV3 assertions, layer layout and output shapes."""

    def conv_geometry(conv_module):
        # (kernel_size, stride, padding) of the conv wrapped by the module.
        inner = conv_module.conv
        return inner.kernel_size, inner.stride, inner.padding

    pointwise = ((1, 1), (1, 1), (0, 0))

    # test stride assertion.
    with pytest.raises(AssertionError):
        InvertedResidualV3(32, 32, 16, stride=3)

    # test assertion raised for with_expand_conv=False with these channels.
    with pytest.raises(AssertionError):
        InvertedResidualV3(32, 32, 16, with_expand_conv=False)

    # test with se_cfg=None, with_expand_conv=False
    block = InvertedResidualV3(32, 32, 32, with_expand_conv=False)

    assert block.with_res_shortcut is True
    assert block.with_se is False
    assert block.with_expand_conv is False
    assert not hasattr(block, 'expand_conv')

    depthwise = block.depthwise_conv
    assert isinstance(depthwise.conv, torch.nn.Conv2d)
    assert conv_geometry(depthwise) == ((3, 3), (1, 1), (1, 1))
    assert isinstance(depthwise.bn, torch.nn.BatchNorm2d)
    assert isinstance(depthwise.activate, torch.nn.ReLU)

    linear = block.linear_conv
    assert conv_geometry(linear) == pointwise
    assert isinstance(linear.bn, torch.nn.BatchNorm2d)

    result = block(torch.rand(1, 32, 64, 64))
    assert result.shape == (1, 32, 64, 64)

    # test with se_cfg and with_expand_conv
    se_cfg = dict(
        channels=16,
        ratio=4,
        act_cfg=(dict(type='ReLU'),
                 dict(type='HSigmoid', bias=3.0, divisor=6.0)))
    act_cfg = dict(type='HSwish')
    block = InvertedResidualV3(
        32, 40, 16, 3, 2, se_cfg=se_cfg, act_cfg=act_cfg)

    assert block.with_res_shortcut is False
    assert block.with_se is True
    assert block.with_expand_conv is True

    expand = block.expand_conv
    assert conv_geometry(expand) == pointwise
    assert isinstance(expand.activate, mmcv.cnn.HSwish)

    depthwise = block.depthwise_conv
    assert isinstance(depthwise.conv, mmcv.cnn.bricks.Conv2dAdaptivePadding)
    assert conv_geometry(depthwise) == ((3, 3), (2, 2), (0, 0))
    assert isinstance(depthwise.bn, torch.nn.BatchNorm2d)
    assert isinstance(depthwise.activate, mmcv.cnn.HSwish)

    linear = block.linear_conv
    assert conv_geometry(linear) == pointwise
    assert isinstance(linear.bn, torch.nn.BatchNorm2d)

    result = block(torch.rand(1, 32, 64, 64))
    assert result.shape == (1, 40, 32, 32)

    # test with checkpoint forward
    block = InvertedResidualV3(
        32, 40, 16, 3, 2, se_cfg=se_cfg, act_cfg=act_cfg, with_cp=True)
    assert block.with_cp
    result = block(torch.randn(2, 32, 64, 64, requires_grad=True))
    assert result.shape == (2, 40, 32, 32)
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
from mmseg.models.utils import make_divisible
|
||||
|
||||
|
||||
def test_make_divisible():
    """Check make_divisible with and without an explicit min_value."""
    # Each row: (positional arguments, expected result).
    cases = (
        # test with min_value = None
        ((10, 4), 12),
        ((9, 4), 12),
        ((1, 4), 4),
        # test with min_value = 8
        ((10, 4, 8), 12),
        ((9, 4, 8), 12),
        ((1, 4, 8), 8),
    )
    for args, expected in cases:
        assert make_divisible(*args) == expected
|
|
@ -0,0 +1,41 @@
|
|||
import mmcv
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.utils.se_layer import SELayer
|
||||
|
||||
|
||||
def test_se_layer():
    """Check SELayer act_cfg validation, conv layout and output shape."""
    # test act_cfg assertion.
    with pytest.raises(AssertionError):
        SELayer(32, act_cfg=(dict(type='ReLU'), ))

    pointwise = ((1, 1), (1, 1), (0, 0))
    # Each row: (extra constructor kwargs, expected type of conv2.activate).
    variants = (
        # test config with channels = 16.
        (dict(), mmcv.cnn.HSigmoid),
        # test config with channels = 16, act_cfg = dict(type='ReLU').
        (dict(act_cfg=dict(type='ReLU')), torch.nn.ReLU),
    )
    for extra_kwargs, second_activation in variants:
        se_layer = SELayer(16, **extra_kwargs)

        first = se_layer.conv1
        geometry = (first.conv.kernel_size, first.conv.stride,
                    first.conv.padding)
        assert geometry == pointwise
        assert isinstance(first.activate, torch.nn.ReLU)

        second = se_layer.conv2
        geometry = (second.conv.kernel_size, second.conv.stride,
                    second.conv.padding)
        assert geometry == pointwise
        assert isinstance(second.activate, second_activation)

        output = se_layer(torch.rand(1, 16, 64, 64))
        assert output.shape == (1, 16, 64, 64)
|
Loading…
Reference in New Issue