John Zhu c1f46a69f4
Fast-SCNN implemented (#58)
* init commit: fast_scnn

* 247917iters

* 4x8_80k

* configs placed in configs_unify.  4x8_80k exp.running.

* mmseg/utils/collect_env.py modified to support Windows

* study on lr

* bug in configs_unify/***/cityscapes.py fixed.

* lr0.08_100k

* lr_power changed to 1.2

* log_config by_epoch set to False.

* lr1.2

* doc strings added

* add fast_scnn backbone  test

* 80k 0.08,0.12

* add 450k

* fast_scnn test: fix BN bug.

* Add different config files into configs/

* .gitignore recovered.

* configs_unify del

* .gitignore recovered.

* delete sub-optimal config files of fast-scnn

* Code style improved.

* add docstrings to component modules of fast-scnn

* relevant files modified according to Jerry's instructions

* relevant files modified according to Jerry's instructions

* lint problems fixed.

* fast_scnn config extremely simplified.

* InvertedResidual

* fixed padding problems

* add unit test for inverted_residual

* add unit test for inverted_residual: debug 0

* add unit test for inverted_residual: debug 1

* add unit test for inverted_residual: debug 2

* add unit test for inverted_residual: debug 3

* add unit test for sep_fcn_head: debug 0

* add unit test for sep_fcn_head: debug 1

* add unit test for sep_fcn_head: debug 2

* add unit test for sep_fcn_head: debug 3

* add unit test for sep_fcn_head: debug 4

* add unit test for sep_fcn_head: debug 5

* FastSCNN type(dwchannels) changed to tuple.

* t changed to expand_ratio.

* Spaces fixed.

* Update mmseg/models/backbones/fast_scnn.py

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>

* Update mmseg/models/decode_heads/sep_fcn_head.py

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>

* Update mmseg/models/decode_heads/sep_fcn_head.py

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>

* Docstrings fixed.

* Docstrings fixed.

* Inverted Residual kept coherent with mmcl.

* Inverted Residual kept coherent with mmcl. Debug 0

* _make_layer parameters renamed.

* final commit

* Arg scale_factor deleted.

* Expand_ratio docstrings updated.

* final commit

* Readme for Fast-SCNN added.

* model-zoo.md modified.

* fast_scnn README updated.

* Move InvertedResidual module into mmseg/utils.

* test_inverted_residual module corrected.

* test_inverted_residual.py moved.

* encoder_decoder modified to avoid bugs when running PSPNet.
getting_started.md bug fixed.

* Revert "encoder_decoder modified to avoid bugs when running PSPNet. "

This reverts commit dd0aadfb

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
2020-08-18 23:33:05 +08:00


import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class FCNHead(BaseDecodeHead):
    """Fully Convolutional Networks for Semantic Segmentation.

    This head implements `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether to concatenate the input and output of
            the convs before the classification layer. Default: True.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 **kwargs):
        assert num_convs > 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(FCNHead, self).__init__(**kwargs)

        # The first conv maps in_channels -> channels; the remaining
        # num_convs - 1 convs keep the width at channels. With the default
        # stride of 1 and padding = kernel_size // 2, odd kernels keep H, W.
        convs = []
        convs.append(
            ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        for _ in range(num_convs - 1):
            convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=kernel_size // 2,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        self.convs = nn.Sequential(*convs)

        # Optional fusion of the raw input with the conv-stack output.
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        # Pick the feature map(s) selected by in_index / input_transform.
        x = self._transform_inputs(inputs)
        output = self.convs(x)
        if self.concat_input:
            # Channel-wise concat requires x and output to share H and W.
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
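To make the data flow of `FCNHead` concrete, here is a minimal sketch in plain PyTorch that runs without mmcv or mmseg. It assumes each `ConvModule` is configured as conv + BatchNorm + ReLU, replaces `BaseDecodeHead.cls_seg` with a bare 1x1 convolution (the real one may also apply dropout), and takes a single feature map instead of calling `_transform_inputs`. `TinyFCNHead` and `conv_bn_relu` are hypothetical names, not part of mmseg.

```python
import torch
import torch.nn as nn


def conv_bn_relu(in_ch: int, out_ch: int, kernel_size: int) -> nn.Sequential:
    """Hypothetical stand-in for ConvModule, assuming a BN + ReLU config."""
    return nn.Sequential(
        # stride 1 with padding = kernel_size // 2 keeps H, W for odd kernels;
        # bias is dropped because BatchNorm follows the conv
        nn.Conv2d(in_ch, out_ch, kernel_size,
                  padding=kernel_size // 2, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class TinyFCNHead(nn.Module):
    """Hypothetical, simplified mirror of FCNHead's structure."""

    def __init__(self, in_channels, channels, num_classes,
                 num_convs=2, kernel_size=3, concat_input=True):
        super().__init__()
        assert num_convs > 0
        self.concat_input = concat_input
        # first conv changes the width, the rest keep it at `channels`
        layers = [conv_bn_relu(in_channels, channels, kernel_size)]
        layers += [conv_bn_relu(channels, channels, kernel_size)
                   for _ in range(num_convs - 1)]
        self.convs = nn.Sequential(*layers)
        if concat_input:
            # fuses the raw input with the conv-stack output
            self.conv_cat = conv_bn_relu(in_channels + channels, channels,
                                         kernel_size)
        # 1x1 classifier standing in for BaseDecodeHead.cls_seg
        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)

    def forward(self, x):
        output = self.convs(x)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        return self.conv_seg(output)


head = TinyFCNHead(in_channels=64, channels=32, num_classes=19).eval()
with torch.no_grad():
    logits = head(torch.randn(2, 64, 17, 23))
print(logits.shape)  # torch.Size([2, 19, 17, 23])
```

Because every conv uses stride 1 and `padding=kernel_size // 2`, an odd kernel size leaves the spatial size unchanged, which is what allows `torch.cat` to join `x` and the conv-stack output along the channel dimension; the concat branch therefore sees `in_channels + channels` input channels.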