mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
parent
feefc6a9de
commit
94e12e8d21
@ -75,6 +75,7 @@ Supported methods:
|
||||
- [x] [DANet](configs/danet)
|
||||
- [x] [APCNet](configs/apcnet)
|
||||
- [x] [GCNet](configs/gcnet)
|
||||
- [x] [DMNet](configs/dmnet)
|
||||
- [x] [ANN](configs/ann)
|
||||
- [x] [OCRNet](configs/ocrnet)
|
||||
- [x] [Fast-SCNN](configs/fastscnn)
|
||||
|
44
configs/_base_/models/dmnet_r50-d8.py
Normal file
44
configs/_base_/models/dmnet_r50-d8.py
Normal file
@ -0,0 +1,44 @@
|
||||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained='open-mmlab://resnet50_v1c',
|
||||
backbone=dict(
|
||||
type='ResNetV1c',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
dilations=(1, 1, 2, 4),
|
||||
strides=(1, 2, 1, 1),
|
||||
norm_cfg=norm_cfg,
|
||||
norm_eval=False,
|
||||
style='pytorch',
|
||||
contract_dilation=True),
|
||||
decode_head=dict(
|
||||
type='DMHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=512,
|
||||
filter_sizes=(1, 3, 5, 7),
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True),
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
auxiliary_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=1024,
|
||||
in_index=2,
|
||||
channels=256,
|
||||
num_convs=1,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)))
|
||||
# model training and testing settings
|
||||
train_cfg = dict()
|
||||
test_cfg = dict(mode='whole')
|
37
configs/dmnet/README.md
Normal file
37
configs/dmnet/README.md
Normal file
@ -0,0 +1,37 @@
|
||||
# Dynamic Multi-scale Filters for Semantic Segmentation
|
||||
|
||||
## Introduction
|
||||
|
||||
```latex
|
||||
@InProceedings{He_2019_ICCV,
|
||||
author = {He, Junjun and Deng, Zhongying and Qiao, Yu},
|
||||
title = {Dynamic Multi-Scale Filters for Semantic Segmentation},
|
||||
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
|
||||
month = {October},
|
||||
year = {2019}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
### Cityscapes
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|
||||
|--------|----------|-----------|--------:|----------|----------------|------:|--------------:|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| DMNet | R-50-D8 | 512x1024 | 40000 | 7.0 | 3.66 | 77.78 | 79.14 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r50-d8_512x1024_40k_cityscapes/dmnet_r50-d8_512x1024_40k_cityscapes_20201214_115717-5e88fa33.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r50-d8_512x1024_40k_cityscapes/dmnet_r50-d8_512x1024_40k_cityscapes-20201214_115717.log.json) |
|
||||
| DMNet | R-101-D8 | 512x1024 | 40000 | 10.6 | 2.54 | 78.37 | 79.72 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r101-d8_512x1024_40k_cityscapes/dmnet_r101-d8_512x1024_40k_cityscapes_20201214_115716-abc9d111.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r101-d8_512x1024_40k_cityscapes/dmnet_r101-d8_512x1024_40k_cityscapes-20201214_115716.log.json) |
|
||||
| DMNet | R-50-D8 | 769x769 | 40000 | 7.9 | 1.57 | 78.49 | 80.27 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r50-d8_769x769_40k_cityscapes/dmnet_r50-d8_769x769_40k_cityscapes_20201214_115717-2a2628d7.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r50-d8_769x769_40k_cityscapes/dmnet_r50-d8_769x769_40k_cityscapes-20201214_115717.log.json) |
|
||||
| DMNet | R-101-D8 | 769x769 | 40000 | 12.0 | 1.01 | 77.62 | 78.94 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r101-d8_769x769_40k_cityscapes/dmnet_r101-d8_769x769_40k_cityscapes_20201214_115718-b650de90.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r101-d8_769x769_40k_cityscapes/dmnet_r101-d8_769x769_40k_cityscapes-20201214_115718.log.json) |
|
||||
| DMNet | R-50-D8 | 512x1024 | 80000 | - | - | 79.07 | 80.22 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r50-d8_512x1024_80k_cityscapes/dmnet_r50-d8_512x1024_80k_cityscapes_20201214_115716-987f51e3.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r50-d8_512x1024_80k_cityscapes/dmnet_r50-d8_512x1024_80k_cityscapes-20201214_115716.log.json) |
|
||||
| DMNet | R-101-D8 | 512x1024 | 80000 | - | - | 79.64 | 80.67 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r101-d8_512x1024_80k_cityscapes/dmnet_r101-d8_512x1024_80k_cityscapes_20201214_115705-b1ff208a.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r101-d8_512x1024_80k_cityscapes/dmnet_r101-d8_512x1024_80k_cityscapes-20201214_115705.log.json) |
|
||||
| DMNet | R-50-D8 | 769x769 | 80000 | - | - | 79.22 | 80.55 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r50-d8_769x769_80k_cityscapes/dmnet_r50-d8_769x769_80k_cityscapes_20201214_115718-7ea9fa12.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r50-d8_769x769_80k_cityscapes/dmnet_r50-d8_769x769_80k_cityscapes-20201214_115718.log.json) |
|
||||
| DMNet | R-101-D8 | 769x769 | 80000 | - | - | 79.19 | 80.65 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r101-d8_769x769_80k_cityscapes/dmnet_r101-d8_769x769_80k_cityscapes_20201214_115716-a7fbc2ab.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r101-d8_769x769_80k_cityscapes/dmnet_r101-d8_769x769_80k_cityscapes-20201214_115716.log.json) |
|
||||
|
||||
### ADE20K
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|
||||
|--------|----------|-----------|--------:|----------|----------------|------:|--------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| DMNet | R-50-D8 | 512x512 | 80000 | 9.4 | 20.95 | 42.37 | 43.62 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r50-d8_512x512_80k_ade20k/dmnet_r50-d8_512x512_80k_ade20k_20201214_115705-a8626293.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r50-d8_512x512_80k_ade20k/dmnet_r50-d8_512x512_80k_ade20k-20201214_115705.log.json) |
|
||||
| DMNet | R-101-D8 | 512x512 | 80000 | 13.0 | 13.88 | 45.34 | 46.13 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r101-d8_512x512_80k_ade20k/dmnet_r101-d8_512x512_80k_ade20k_20201214_115704-c656c3fb.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r101-d8_512x512_80k_ade20k/dmnet_r101-d8_512x512_80k_ade20k-20201214_115704.log.json) |
|
||||
| DMNet | R-50-D8 | 512x512 | 160000 | - | - | 43.15 | 44.17 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r50-d8_512x512_160k_ade20k/dmnet_r50-d8_512x512_160k_ade20k_20201214_115706-25fb92c2.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r50-d8_512x512_160k_ade20k/dmnet_r50-d8_512x512_160k_ade20k-20201214_115706.log.json) |
|
||||
| DMNet | R-101-D8 | 512x512 | 160000 | - | - | 45.42 | 46.76 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r101-d8_512x512_160k_ade20k/dmnet_r101-d8_512x512_160k_ade20k_20201214_115705-73f9a8d7.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dmnet/dmnet_r101-d8_512x512_160k_ade20k/dmnet_r101-d8_512x512_160k_ade20k-20201214_115705.log.json) |
|
2
configs/dmnet/dmnet_r101-d8_512x1024_40k_cityscapes.py
Normal file
2
configs/dmnet/dmnet_r101-d8_512x1024_40k_cityscapes.py
Normal file
@ -0,0 +1,2 @@
|
||||
_base_ = './dmnet_r50-d8_512x1024_40k_cityscapes.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
2
configs/dmnet/dmnet_r101-d8_512x1024_80k_cityscapes.py
Normal file
2
configs/dmnet/dmnet_r101-d8_512x1024_80k_cityscapes.py
Normal file
@ -0,0 +1,2 @@
|
||||
_base_ = './dmnet_r50-d8_512x1024_80k_cityscapes.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
2
configs/dmnet/dmnet_r101-d8_512x512_160k_ade20k.py
Normal file
2
configs/dmnet/dmnet_r101-d8_512x512_160k_ade20k.py
Normal file
@ -0,0 +1,2 @@
|
||||
_base_ = './dmnet_r50-d8_512x512_160k_ade20k.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
2
configs/dmnet/dmnet_r101-d8_512x512_80k_ade20k.py
Normal file
2
configs/dmnet/dmnet_r101-d8_512x512_80k_ade20k.py
Normal file
@ -0,0 +1,2 @@
|
||||
_base_ = './dmnet_r50-d8_512x512_80k_ade20k.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
2
configs/dmnet/dmnet_r101-d8_769x769_40k_cityscapes.py
Normal file
2
configs/dmnet/dmnet_r101-d8_769x769_40k_cityscapes.py
Normal file
@ -0,0 +1,2 @@
|
||||
_base_ = './dmnet_r50-d8_769x769_40k_cityscapes.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
2
configs/dmnet/dmnet_r101-d8_769x769_80k_cityscapes.py
Normal file
2
configs/dmnet/dmnet_r101-d8_769x769_80k_cityscapes.py
Normal file
@ -0,0 +1,2 @@
|
||||
_base_ = './dmnet_r50-d8_769x769_80k_cityscapes.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
4
configs/dmnet/dmnet_r50-d8_512x1024_40k_cityscapes.py
Normal file
4
configs/dmnet/dmnet_r50-d8_512x1024_40k_cityscapes.py
Normal file
@ -0,0 +1,4 @@
|
||||
_base_ = [
|
||||
'../_base_/models/dmnet_r50-d8.py', '../_base_/datasets/cityscapes.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
|
||||
]
|
4
configs/dmnet/dmnet_r50-d8_512x1024_80k_cityscapes.py
Normal file
4
configs/dmnet/dmnet_r50-d8_512x1024_80k_cityscapes.py
Normal file
@ -0,0 +1,4 @@
|
||||
_base_ = [
|
||||
'../_base_/models/dmnet_r50-d8.py', '../_base_/datasets/cityscapes.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
|
||||
]
|
7
configs/dmnet/dmnet_r50-d8_512x512_160k_ade20k.py
Normal file
7
configs/dmnet/dmnet_r50-d8_512x512_160k_ade20k.py
Normal file
@ -0,0 +1,7 @@
|
||||
_base_ = [
|
||||
'../_base_/models/dmnet_r50-d8.py', '../_base_/datasets/ade20k.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
|
||||
]
|
||||
model = dict(
|
||||
decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))
|
||||
test_cfg = dict(mode='whole')
|
7
configs/dmnet/dmnet_r50-d8_512x512_80k_ade20k.py
Normal file
7
configs/dmnet/dmnet_r50-d8_512x512_80k_ade20k.py
Normal file
@ -0,0 +1,7 @@
|
||||
_base_ = [
|
||||
'../_base_/models/dmnet_r50-d8.py', '../_base_/datasets/ade20k.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
|
||||
]
|
||||
model = dict(
|
||||
decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))
|
||||
test_cfg = dict(mode='whole')
|
9
configs/dmnet/dmnet_r50-d8_769x769_40k_cityscapes.py
Normal file
9
configs/dmnet/dmnet_r50-d8_769x769_40k_cityscapes.py
Normal file
@ -0,0 +1,9 @@
|
||||
_base_ = [
|
||||
'../_base_/models/dmnet_r50-d8.py',
|
||||
'../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
|
||||
'../_base_/schedules/schedule_40k.py'
|
||||
]
|
||||
model = dict(
|
||||
decode_head=dict(align_corners=True),
|
||||
auxiliary_head=dict(align_corners=True))
|
||||
test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
|
9
configs/dmnet/dmnet_r50-d8_769x769_80k_cityscapes.py
Normal file
9
configs/dmnet/dmnet_r50-d8_769x769_80k_cityscapes.py
Normal file
@ -0,0 +1,9 @@
|
||||
_base_ = [
|
||||
'../_base_/models/dmnet_r50-d8.py',
|
||||
'../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
|
||||
'../_base_/schedules/schedule_80k.py'
|
||||
]
|
||||
model = dict(
|
||||
decode_head=dict(align_corners=True),
|
||||
auxiliary_head=dict(align_corners=True))
|
||||
test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
|
@ -79,6 +79,10 @@ Please refer to [HRNet](https://github.com/open-mmlab/mmsegmentation/blob/master
|
||||
|
||||
Please refer to [GCNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/gcnet) for details.
|
||||
|
||||
### DMNet
|
||||
|
||||
Please refer to [DMNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dmnet) for details.
|
||||
|
||||
### ANN
|
||||
|
||||
Please refer to [ANN](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/ann) for details.
|
||||
|
@ -3,6 +3,7 @@ from .apc_head import APCHead
|
||||
from .aspp_head import ASPPHead
|
||||
from .cc_head import CCHead
|
||||
from .da_head import DAHead
|
||||
from .dm_head import DMHead
|
||||
from .dnl_head import DNLHead
|
||||
from .ema_head import EMAHead
|
||||
from .enc_head import EncHead
|
||||
@ -22,5 +23,5 @@ __all__ = [
|
||||
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
|
||||
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
|
||||
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
|
||||
'PointHead', 'APCHead'
|
||||
'PointHead', 'APCHead', 'DMHead'
|
||||
]
|
||||
|
140
mmseg/models/decode_heads/dm_head.py
Normal file
140
mmseg/models/decode_heads/dm_head.py
Normal file
@ -0,0 +1,140 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
|
||||
from ..builder import HEADS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class DCM(nn.Module):
|
||||
"""Dynamic Convolutional Module used in DMNet.
|
||||
|
||||
Args:
|
||||
filter_size (int): The filter size of generated convolution kernel
|
||||
used in Dynamic Convolutional Module.
|
||||
fusion (bool): Add one conv to fuse DCM output feature.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
|
||||
norm_cfg, act_cfg):
|
||||
super(DCM, self).__init__()
|
||||
self.filter_size = filter_size
|
||||
self.fusion = fusion
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,
|
||||
0)
|
||||
|
||||
self.input_redu_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
if self.norm_cfg is not None:
|
||||
self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
|
||||
else:
|
||||
self.norm = None
|
||||
self.activate = build_activation_layer(self.act_cfg)
|
||||
|
||||
if self.fusion:
|
||||
self.fusion_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
generted_filter = self.filter_gen_conv(
|
||||
F.adaptive_avg_pool2d(x, self.filter_size))
|
||||
x = self.input_redu_conv(x)
|
||||
b, c, h, w = x.shape
|
||||
# [1, b * c, h, w], c = self.channels
|
||||
x = x.view(1, b * c, h, w)
|
||||
# [b * c, 1, filter_size, filter_size]
|
||||
generted_filter = generted_filter.view(b * c, 1, self.filter_size,
|
||||
self.filter_size)
|
||||
pad = (self.filter_size - 1) // 2
|
||||
if (self.filter_size - 1) % 2 == 0:
|
||||
p2d = (pad, pad, pad, pad)
|
||||
else:
|
||||
p2d = (pad + 1, pad, pad + 1, pad)
|
||||
x = F.pad(input=x, pad=p2d, mode='constant', value=0)
|
||||
# [1, b * c, h, w]
|
||||
output = F.conv2d(input=x, weight=generted_filter, groups=b * c)
|
||||
# [b, c, h, w]
|
||||
output = output.view(b, c, h, w)
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
output = self.activate(output)
|
||||
|
||||
if self.fusion:
|
||||
output = self.fusion_conv(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class DMHead(BaseDecodeHead):
|
||||
"""Dynamic Multi-scale Filters for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of
|
||||
`DMNet <https://openaccess.thecvf.com/content_ICCV_2019/papers/\
|
||||
He_Dynamic_Multi-Scale_Filters_for_Semantic_Segmentation_\
|
||||
ICCV_2019_paper.pdf>`_.
|
||||
|
||||
Args:
|
||||
filter_sizes (tuple[int]): The size of generated convolutional filters
|
||||
used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
|
||||
fusion (bool): Add one conv to fuse DCM output feature.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
|
||||
super(DMHead, self).__init__(**kwargs)
|
||||
assert isinstance(filter_sizes, (list, tuple))
|
||||
self.filter_sizes = filter_sizes
|
||||
self.fusion = fusion
|
||||
dcm_modules = []
|
||||
for filter_size in self.filter_sizes:
|
||||
dcm_modules.append(
|
||||
DCM(filter_size,
|
||||
self.fusion,
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
self.dcm_modules = nn.ModuleList(dcm_modules)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels + len(filter_sizes) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
dcm_outs = [x]
|
||||
for dcm_module in self.dcm_modules:
|
||||
dcm_outs.append(dcm_module(x))
|
||||
dcm_outs = torch.cat(dcm_outs, dim=1)
|
||||
output = self.bottleneck(dcm_outs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
@ -8,10 +8,10 @@ from mmcv.utils.parrots_wrapper import SyncBatchNorm
|
||||
|
||||
from mmseg.models.decode_heads import (ANNHead, APCHead, ASPPHead, CCHead,
|
||||
DAHead, DepthwiseSeparableASPPHead,
|
||||
DepthwiseSeparableFCNHead, DNLHead,
|
||||
EMAHead, EncHead, FCNHead, GCHead,
|
||||
NLHead, OCRHead, PointHead, PSAHead,
|
||||
PSPHead, UPerHead)
|
||||
DepthwiseSeparableFCNHead, DMHead,
|
||||
DNLHead, EMAHead, EncHead, FCNHead,
|
||||
GCHead, NLHead, OCRHead, PointHead,
|
||||
PSAHead, PSPHead, UPerHead)
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@ -276,6 +276,59 @@ def test_apc_head():
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
|
||||
def test_dm_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# filter_sizes must be list|tuple
|
||||
DMHead(in_channels=32, channels=16, num_classes=19, filter_sizes=1)
|
||||
|
||||
# test no norm_cfg
|
||||
head = DMHead(in_channels=32, channels=16, num_classes=19)
|
||||
assert not _conv_has_norm(head, sync_bn=False)
|
||||
|
||||
# test with norm_cfg
|
||||
head = DMHead(
|
||||
in_channels=32,
|
||||
channels=16,
|
||||
num_classes=19,
|
||||
norm_cfg=dict(type='SyncBN'))
|
||||
assert _conv_has_norm(head, sync_bn=True)
|
||||
|
||||
# fusion=True
|
||||
inputs = [torch.randn(1, 32, 45, 45)]
|
||||
head = DMHead(
|
||||
in_channels=32,
|
||||
channels=16,
|
||||
num_classes=19,
|
||||
filter_sizes=(1, 3, 5),
|
||||
fusion=True)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.fusion is True
|
||||
assert head.dcm_modules[0].filter_size == 1
|
||||
assert head.dcm_modules[1].filter_size == 3
|
||||
assert head.dcm_modules[2].filter_size == 5
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
# fusion=False
|
||||
inputs = [torch.randn(1, 32, 45, 45)]
|
||||
head = DMHead(
|
||||
in_channels=32,
|
||||
channels=16,
|
||||
num_classes=19,
|
||||
filter_sizes=(1, 3, 5),
|
||||
fusion=False)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.fusion is False
|
||||
assert head.dcm_modules[0].filter_size == 1
|
||||
assert head.dcm_modules[1].filter_size == 3
|
||||
assert head.dcm_modules[2].filter_size == 5
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
|
||||
def test_aspp_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
|
Loading…
x
Reference in New Issue
Block a user