diff --git a/configs/fastscnn/README.md b/configs/fastscnn/README.md
index dbfc0e427..3d8f778b9 100644
--- a/configs/fastscnn/README.md
+++ b/configs/fastscnn/README.md
@@ -15,4 +15,4 @@ ### Cityscapes
 
 | Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
 |------------|-----------|-----------|--------:|----------|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| Fast-SCNN | Fast-SCNN | 512x1024 | 80000 | 8.4 | 63.61 | 69.06 | - | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_4x8_80k_lr0.12_cityscapes-cae6c46a.pth) | [log](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_4x8_80k_lr0.12_cityscapes-20200807_165744.log.json) |
+| Fast-SCNN | Fast-SCNN | 512x1024 | 80000 | 8.4 | 63.61 | 69.06 | - | [model](https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_4x8_80k_lr0.12_cityscapes-f5096c79.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_4x8_80k_lr0.12_cityscapes-20200807_165744.log.json) |
diff --git a/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_pascal.py b/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_pascal.py
deleted file mode 100644
index 23c2ea996..000000000
--- a/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_pascal.py
+++ /dev/null
@@ -1,70 +0,0 @@
-_base_ = [
-    '../_base_/models/fast_scnn.py', '../_base_/datasets/pascal_voc12.py',
-    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
-]
-
-# Re-config the data sampler.
-data = dict(samples_per_gpu=8, workers_per_gpu=4)
-
-# Re-config the optimizer.
-optimizer = dict(type='SGD', lr=0.12, momentum=0.9, weight_decay=4e-5)
-
-# update num_classes of the segmentor.
-# model settings
-norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)
-model = dict(
-    type='EncoderDecoder',
-    backbone=dict(
-        type='FastSCNN',
-        downsample_dw_channels=(32, 48),
-        global_in_channels=64,
-        global_block_channels=(64, 96, 128),
-        global_block_strides=(2, 2, 1),
-        global_out_channels=128,
-        higher_in_channels=64,
-        lower_in_channels=128,
-        fusion_out_channels=128,
-        out_indices=(0, 1, 2),
-        norm_cfg=norm_cfg,
-        align_corners=False),
-    decode_head=dict(
-        type='DepthwiseSeparableFCNHead',
-        in_channels=128,
-        channels=128,
-        concat_input=False,
-        num_classes=21,
-        in_index=-1,
-        norm_cfg=norm_cfg,
-        align_corners=False,
-        loss_decode=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.)),
-    auxiliary_head=[
-        dict(
-            type='FCNHead',
-            in_channels=128,
-            channels=32,
-            num_convs=1,
-            num_classes=21,
-            in_index=-2,
-            norm_cfg=norm_cfg,
-            concat_input=False,
-            align_corners=False,
-            loss_decode=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
-        dict(
-            type='FCNHead',
-            in_channels=64,
-            channels=32,
-            num_convs=1,
-            num_classes=21,
-            in_index=-3,
-            norm_cfg=norm_cfg,
-            concat_input=False,
-            align_corners=False,
-            loss_decode=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
-    ])
-
-# model training and testing settings
-train_cfg = dict()
-test_cfg = dict(mode='whole')
diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py
index 4aaec2212..ee115ffda 100644
--- a/mmseg/models/backbones/fast_scnn.py
+++ b/mmseg/models/backbones/fast_scnn.py
@@ -6,8 +6,8 @@ from torch.nn.modules.batchnorm import _BatchNorm
 
 from mmseg.models.decode_heads.psp_head import PPM
 from mmseg.ops import resize
-from mmseg.utils import InvertedResidual
 from ..builder import BACKBONES
+from ..utils.inverted_residual import InvertedResidual
 
 
 class LearningToDownsample(nn.Module):
diff --git a/mmseg/models/backbones/mobilenet_v2.py b/mmseg/models/backbones/mobilenet_v2.py
index 5fff485f0..5820b4b13 100644
--- a/mmseg/models/backbones/mobilenet_v2.py
+++ b/mmseg/models/backbones/mobilenet_v2.py
@@ -1,102 +1,12 @@
 import logging
 
 import torch.nn as nn
-import torch.utils.checkpoint as cp
 from mmcv.cnn import ConvModule, constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 
 from ..builder import BACKBONES
-from ..utils import make_divisible
-
-
-class InvertedResidual(nn.Module):
-    """InvertedResidual block for MobileNetV2.
-
-    Args:
-        in_channels (int): The input channels of the InvertedResidual block.
-        out_channels (int): The output channels of the InvertedResidual block.
-        stride (int): Stride of the middle (first) 3x3 convolution.
-        expand_ratio (int): Adjusts number of channels of the hidden layer
-            in InvertedResidual by this amount.
-        dilation (int): Dilation rate of depthwise conv. Default: 1
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='ReLU6').
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-
-    Returns:
-        Tensor: The output tensor
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 stride,
-                 expand_ratio,
-                 dilation=1,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 act_cfg=dict(type='ReLU6'),
-                 with_cp=False):
-        super(InvertedResidual, self).__init__()
-        self.stride = stride
-        assert stride in [1, 2], f'stride must in [1, 2]. ' \
-            f'But received {stride}.'
-        self.with_cp = with_cp
-        self.use_res_connect = self.stride == 1 and in_channels == out_channels
-        hidden_dim = int(round(in_channels * expand_ratio))
-
-        layers = []
-        if expand_ratio != 1:
-            layers.append(
-                ConvModule(
-                    in_channels=in_channels,
-                    out_channels=hidden_dim,
-                    kernel_size=1,
-                    conv_cfg=conv_cfg,
-                    norm_cfg=norm_cfg,
-                    act_cfg=act_cfg))
-        layers.extend([
-            ConvModule(
-                in_channels=hidden_dim,
-                out_channels=hidden_dim,
-                kernel_size=3,
-                stride=stride,
-                padding=dilation,
-                dilation=dilation,
-                groups=hidden_dim,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=act_cfg),
-            ConvModule(
-                in_channels=hidden_dim,
-                out_channels=out_channels,
-                kernel_size=1,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=None)
-        ])
-        self.conv = nn.Sequential(*layers)
-
-    def forward(self, x):
-
-        def _inner_forward(x):
-            if self.use_res_connect:
-                return x + self.conv(x)
-            else:
-                return self.conv(x)
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        return out
+from ..utils import InvertedResidual, make_divisible
 
 
 @BACKBONES.register_module()
diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py
index bea300c3a..969a0c7d9 100644
--- a/mmseg/models/utils/__init__.py
+++ b/mmseg/models/utils/__init__.py
@@ -1,5 +1,8 @@
+from .inverted_residual import InvertedResidual
 from .make_divisible import make_divisible
 from .res_layer import ResLayer
 from .self_attention_block import SelfAttentionBlock
 
-__all__ = ['ResLayer', 'SelfAttentionBlock', 'make_divisible']
+__all__ = [
+    'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual'
+]
diff --git a/mmseg/utils/inverted_residual_module.py b/mmseg/models/utils/inverted_residual.py
similarity index 53%
rename from mmseg/utils/inverted_residual_module.py
rename to mmseg/models/utils/inverted_residual.py
index ff33a3604..c3de83aa2 100644
--- a/mmseg/utils/inverted_residual_module.py
+++ b/mmseg/models/utils/inverted_residual.py
@@ -1,22 +1,29 @@
-from mmcv.cnn import ConvModule, build_norm_layer
-from torch import nn
+from mmcv.cnn import ConvModule
+from torch import nn as nn
+from torch.utils import checkpoint as cp
 
 
 class InvertedResidual(nn.Module):
-    """Inverted residual module.
+    """InvertedResidual block for MobileNetV2.
 
     Args:
         in_channels (int): The input channels of the InvertedResidual block.
         out_channels (int): The output channels of the InvertedResidual block.
         stride (int): Stride of the middle (first) 3x3 convolution.
-        expand_ratio (int): adjusts number of channels of the hidden layer
+        expand_ratio (int): Adjusts number of channels of the hidden layer
             in InvertedResidual by this amount.
+        dilation (int): Dilation rate of depthwise conv. Default: 1
         conv_cfg (dict): Config dict for convolution layer.
             Default: None, which means using conv2d.
         norm_cfg (dict): Config dict for normalization layer.
             Default: dict(type='BN').
         act_cfg (dict): Config dict for activation layer.
             Default: dict(type='ReLU6').
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+
+    Returns:
+        Tensor: The output tensor
     """
 
     def __init__(self,
@@ -27,47 +34,59 @@ class InvertedResidual(nn.Module):
                  dilation=1,
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
-                 act_cfg=dict(type='ReLU6')):
+                 act_cfg=dict(type='ReLU6'),
+                 with_cp=False):
         super(InvertedResidual, self).__init__()
         self.stride = stride
-        assert stride in [1, 2]
-
+        assert stride in [1, 2], f'stride must in [1, 2]. ' \
+            f'But received {stride}.'
+        self.with_cp = with_cp
+        self.use_res_connect = self.stride == 1 and in_channels == out_channels
         hidden_dim = int(round(in_channels * expand_ratio))
-        self.use_res_connect = self.stride == 1 \
-            and in_channels == out_channels
 
         layers = []
         if expand_ratio != 1:
-            # pw
             layers.append(
                 ConvModule(
-                    in_channels,
-                    hidden_dim,
+                    in_channels=in_channels,
+                    out_channels=hidden_dim,
                     kernel_size=1,
                     conv_cfg=conv_cfg,
                     norm_cfg=norm_cfg,
                     act_cfg=act_cfg))
         layers.extend([
-            # dw
             ConvModule(
-                hidden_dim,
-                hidden_dim,
+                in_channels=hidden_dim,
+                out_channels=hidden_dim,
                 kernel_size=3,
-                padding=dilation,
                 stride=stride,
+                padding=dilation,
                 dilation=dilation,
                 groups=hidden_dim,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
                 act_cfg=act_cfg),
-            # pw-linear
-            nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias=False),
-            build_norm_layer(norm_cfg, out_channels)[1],
+            ConvModule(
+                in_channels=hidden_dim,
+                out_channels=out_channels,
+                kernel_size=1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=None)
         ])
         self.conv = nn.Sequential(*layers)
 
     def forward(self, x):
-        if self.use_res_connect:
-            return x + self.conv(x)
+
+        def _inner_forward(x):
+            if self.use_res_connect:
+                return x + self.conv(x)
+            else:
+                return self.conv(x)
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
         else:
-            return self.conv(x)
+            out = _inner_forward(x)
+
+        return out
diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py
index 5b5405956..ac489e2db 100644
--- a/mmseg/utils/__init__.py
+++ b/mmseg/utils/__init__.py
@@ -1,5 +1,4 @@
 from .collect_env import collect_env
-from .inverted_residual_module import InvertedResidual
 from .logger import get_root_logger
 
-__all__ = ['get_root_logger', 'collect_env', 'InvertedResidual']
+__all__ = ['get_root_logger', 'collect_env']
diff --git a/tests/test_utils/test_inverted_residual_module.py b/tests/test_utils/test_inverted_residual_module.py
index 827c10528..279dcf442 100644
--- a/tests/test_utils/test_inverted_residual_module.py
+++ b/tests/test_utils/test_inverted_residual_module.py
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from mmseg.utils import InvertedResidual
+from mmseg.models.utils import InvertedResidual
 
 
 def test_inv_residual():
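
Note: after this patch, InvertedResidual is shared by both backbones from mmseg/models/utils/inverted_residual.py and is imported via mmseg.models.utils; the old mmseg.utils path is removed. Below is a minimal usage sketch of the relocated block and the with_cp flag introduced here. The import path and constructor signature come from the diff above; the channel and spatial sizes are arbitrary example values, not from the patch.

    import torch

    from mmseg.models.utils import InvertedResidual

    # 1x1 expansion conv -> 3x3 depthwise conv -> linear 1x1 projection;
    # the residual shortcut is taken only when stride == 1 and
    # in_channels == out_channels.
    block = InvertedResidual(
        in_channels=16,
        out_channels=16,
        stride=1,
        expand_ratio=6,
        with_cp=False)

    x = torch.rand(1, 16, 64, 64)
    out = block(x)
    assert out.shape == (1, 16, 64, 64)

With with_cp=True, the block is wrapped in torch.utils.checkpoint: intermediate activations are discarded in the forward pass and recomputed during backward, trading extra compute for memory. The x.requires_grad guard in forward() skips checkpointing at inference time, where there is no backward pass to benefit from it.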