From e3f6f655d69b777341aec2fe8829871cc0beadcb Mon Sep 17 00:00:00 2001 From: Junjun2016 Date: Fri, 18 Dec 2020 15:23:45 +0800 Subject: [PATCH] Support APCNet (#299) * Support APCNet * code optimization * add apcnet configs * add benchmark * add readme and model zoo * fix doc --- README.md | 1 + configs/_base_/models/apcnet_r50-d8.py | 44 +++++ configs/apcnet/README.md | 37 ++++ .../apcnet_r101-d8_512x1024_40k_cityscapes.py | 2 + .../apcnet_r101-d8_512x1024_80k_cityscapes.py | 2 + .../apcnet_r101-d8_512x512_160k_ade20k.py | 2 + .../apcnet_r101-d8_512x512_80k_ade20k.py | 2 + .../apcnet_r101-d8_769x769_40k_cityscapes.py | 2 + .../apcnet_r101-d8_769x769_80k_cityscapes.py | 2 + .../apcnet_r50-d8_512x1024_40k_cityscapes.py | 4 + .../apcnet_r50-d8_512x1024_80k_cityscapes.py | 4 + .../apcnet_r50-d8_512x512_160k_ade20k.py | 7 + .../apcnet_r50-d8_512x512_80k_ade20k.py | 7 + .../apcnet_r50-d8_769x769_40k_cityscapes.py | 9 + .../apcnet_r50-d8_769x769_80k_cityscapes.py | 9 + docs/model_zoo.md | 4 + mmseg/models/decode_heads/__init__.py | 3 +- mmseg/models/decode_heads/apc_head.py | 158 ++++++++++++++++++ tests/test_models/test_heads.py | 57 ++++++- 19 files changed, 353 insertions(+), 3 deletions(-) create mode 100644 configs/_base_/models/apcnet_r50-d8.py create mode 100644 configs/apcnet/README.md create mode 100644 configs/apcnet/apcnet_r101-d8_512x1024_40k_cityscapes.py create mode 100644 configs/apcnet/apcnet_r101-d8_512x1024_80k_cityscapes.py create mode 100644 configs/apcnet/apcnet_r101-d8_512x512_160k_ade20k.py create mode 100644 configs/apcnet/apcnet_r101-d8_512x512_80k_ade20k.py create mode 100644 configs/apcnet/apcnet_r101-d8_769x769_40k_cityscapes.py create mode 100644 configs/apcnet/apcnet_r101-d8_769x769_80k_cityscapes.py create mode 100644 configs/apcnet/apcnet_r50-d8_512x1024_40k_cityscapes.py create mode 100644 configs/apcnet/apcnet_r50-d8_512x1024_80k_cityscapes.py create mode 100644 configs/apcnet/apcnet_r50-d8_512x512_160k_ade20k.py create mode 100644 
configs/apcnet/apcnet_r50-d8_512x512_80k_ade20k.py create mode 100644 configs/apcnet/apcnet_r50-d8_769x769_40k_cityscapes.py create mode 100644 configs/apcnet/apcnet_r50-d8_769x769_80k_cityscapes.py create mode 100644 mmseg/models/decode_heads/apc_head.py diff --git a/README.md b/README.md index 0ea731fb8..ba9184c90 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,7 @@ Supported methods: - [x] [EncNet](configs/encnet) - [x] [CCNet](configs/ccnet) - [x] [DANet](configs/danet) +- [x] [APCNet](configs/apcnet) - [x] [GCNet](configs/gcnet) - [x] [ANN](configs/ann) - [x] [OCRNet](configs/ocrnet) diff --git a/configs/_base_/models/apcnet_r50-d8.py b/configs/_base_/models/apcnet_r50-d8.py new file mode 100644 index 000000000..451cbc419 --- /dev/null +++ b/configs/_base_/models/apcnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='APCHead', + in_channels=2048, + in_index=3, + channels=512, + pool_scales=(1, 2, 3, 6), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4))) +# model training and testing settings +train_cfg = dict() +test_cfg = dict(mode='whole') diff --git a/configs/apcnet/README.md b/configs/apcnet/README.md new file mode 100644 index 
000000000..2dc55a379 --- /dev/null +++ b/configs/apcnet/README.md @@ -0,0 +1,37 @@ +# Adaptive Pyramid Context Network for Semantic Segmentation + +## Introduction + +```latex +@InProceedings{He_2019_CVPR, +author = {He, Junjun and Deng, Zhongying and Zhou, Lei and Wang, Yali and Qiao, Yu}, +title = {Adaptive Pyramid Context Network for Semantic Segmentation}, +booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, +month = {June}, +year = {2019} +} +``` + +## Results and models + +### Cityscapes + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download | +|--------|----------|-----------|--------:|----------|----------------|------:|--------------:|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| APCNet | R-50-D8 | 512x1024 | 40000 | 7.7 | 3.57 | 78.02 | 79.26 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x1024_40k_cityscapes/apcnet_r50-d8_512x1024_40k_cityscapes_20201214_115717-5e88fa33.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x1024_40k_cityscapes/apcnet_r50-d8_512x1024_40k_cityscapes-20201214_115717.log.json) | +| APCNet | R-101-D8 | 512x1024 | 40000 | 11.2 | 2.15 | 79.08 | 80.34 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x1024_40k_cityscapes/apcnet_r101-d8_512x1024_40k_cityscapes_20201214_115716-abc9d111.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x1024_40k_cityscapes/apcnet_r101-d8_512x1024_40k_cityscapes-20201214_115716.log.json) | +| APCNet | 
R-50-D8 | 769x769 | 40000 | 8.7 | 1.52 | 77.89 | 79.75 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_769x769_40k_cityscapes/apcnet_r50-d8_769x769_40k_cityscapes_20201214_115717-2a2628d7.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_769x769_40k_cityscapes/apcnet_r50-d8_769x769_40k_cityscapes-20201214_115717.log.json) | +| APCNet | R-101-D8 | 769x769 | 40000 | 12.7 | 1.03 | 77.96 | 79.24 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_769x769_40k_cityscapes/apcnet_r101-d8_769x769_40k_cityscapes_20201214_115718-b650de90.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_769x769_40k_cityscapes/apcnet_r101-d8_769x769_40k_cityscapes-20201214_115718.log.json) | +| APCNet | R-50-D8 | 512x1024 | 80000 | - | - | 78.96 | 79.94 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x1024_80k_cityscapes/apcnet_r50-d8_512x1024_80k_cityscapes_20201214_115716-987f51e3.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x1024_80k_cityscapes/apcnet_r50-d8_512x1024_80k_cityscapes-20201214_115716.log.json) | +| APCNet | R-101-D8 | 512x1024 | 80000 | - | - | 79.64 | 80.61 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x1024_80k_cityscapes/apcnet_r101-d8_512x1024_80k_cityscapes_20201214_115705-b1ff208a.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x1024_80k_cityscapes/apcnet_r101-d8_512x1024_80k_cityscapes-20201214_115705.log.json) | +| APCNet | R-50-D8 | 769x769 | 80000 | - | - | 78.79 | 80.35 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_769x769_80k_cityscapes/apcnet_r50-d8_769x769_80k_cityscapes_20201214_115718-7ea9fa12.pth) | 
[log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_769x769_80k_cityscapes/apcnet_r50-d8_769x769_80k_cityscapes-20201214_115718.log.json) | +| APCNet | R-101-D8 | 769x769 | 80000 | - | - | 78.45 | 79.91 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_769x769_80k_cityscapes/apcnet_r101-d8_769x769_80k_cityscapes_20201214_115716-a7fbc2ab.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_769x769_80k_cityscapes/apcnet_r101-d8_769x769_80k_cityscapes-20201214_115716.log.json) | + +### ADE20K + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download | +|--------|----------|-----------|--------:|----------|----------------|------:|--------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| APCNet | R-50-D8 | 512x512 | 80000 | 10.1 | 19.61 | 42.20 | 43.30 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x512_80k_ade20k/apcnet_r50-d8_512x512_80k_ade20k_20201214_115705-a8626293.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x512_80k_ade20k/apcnet_r50-d8_512x512_80k_ade20k-20201214_115705.log.json) | +| APCNet | R-101-D8 | 512x512 | 80000 | 13.6 | 13.10 | 45.54 | 46.65 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x512_80k_ade20k/apcnet_r101-d8_512x512_80k_ade20k_20201214_115704-c656c3fb.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x512_80k_ade20k/apcnet_r101-d8_512x512_80k_ade20k-20201214_115704.log.json) | +| APCNet | R-50-D8 | 512x512 | 
160000 | - | - | 43.40 | 43.94 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x512_160k_ade20k/apcnet_r50-d8_512x512_160k_ade20k_20201214_115706-25fb92c2.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x512_160k_ade20k/apcnet_r50-d8_512x512_160k_ade20k-20201214_115706.log.json) | +| APCNet | R-101-D8 | 512x512 | 160000 | - | - | 45.41 | 46.63 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x512_160k_ade20k/apcnet_r101-d8_512x512_160k_ade20k_20201214_115705-73f9a8d7.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x512_160k_ade20k/apcnet_r101-d8_512x512_160k_ade20k-20201214_115705.log.json) | diff --git a/configs/apcnet/apcnet_r101-d8_512x1024_40k_cityscapes.py b/configs/apcnet/apcnet_r101-d8_512x1024_40k_cityscapes.py new file mode 100644 index 000000000..1e1cec673 --- /dev/null +++ b/configs/apcnet/apcnet_r101-d8_512x1024_40k_cityscapes.py @@ -0,0 +1,2 @@ +_base_ = './apcnet_r50-d8_512x1024_40k_cityscapes.py' +model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) diff --git a/configs/apcnet/apcnet_r101-d8_512x1024_80k_cityscapes.py b/configs/apcnet/apcnet_r101-d8_512x1024_80k_cityscapes.py new file mode 100644 index 000000000..04cb006ba --- /dev/null +++ b/configs/apcnet/apcnet_r101-d8_512x1024_80k_cityscapes.py @@ -0,0 +1,2 @@ +_base_ = './apcnet_r50-d8_512x1024_80k_cityscapes.py' +model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) diff --git a/configs/apcnet/apcnet_r101-d8_512x512_160k_ade20k.py b/configs/apcnet/apcnet_r101-d8_512x512_160k_ade20k.py new file mode 100644 index 000000000..1ce2279a0 --- /dev/null +++ b/configs/apcnet/apcnet_r101-d8_512x512_160k_ade20k.py @@ -0,0 +1,2 @@ +_base_ = './apcnet_r50-d8_512x512_160k_ade20k.py' +model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) diff --git 
a/configs/apcnet/apcnet_r101-d8_512x512_80k_ade20k.py b/configs/apcnet/apcnet_r101-d8_512x512_80k_ade20k.py new file mode 100644 index 000000000..8f10b9840 --- /dev/null +++ b/configs/apcnet/apcnet_r101-d8_512x512_80k_ade20k.py @@ -0,0 +1,2 @@ +_base_ = './apcnet_r50-d8_512x512_80k_ade20k.py' +model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) diff --git a/configs/apcnet/apcnet_r101-d8_769x769_40k_cityscapes.py b/configs/apcnet/apcnet_r101-d8_769x769_40k_cityscapes.py new file mode 100644 index 000000000..5c44ebcaf --- /dev/null +++ b/configs/apcnet/apcnet_r101-d8_769x769_40k_cityscapes.py @@ -0,0 +1,2 @@ +_base_ = './apcnet_r50-d8_769x769_40k_cityscapes.py' +model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) diff --git a/configs/apcnet/apcnet_r101-d8_769x769_80k_cityscapes.py b/configs/apcnet/apcnet_r101-d8_769x769_80k_cityscapes.py new file mode 100644 index 000000000..616984575 --- /dev/null +++ b/configs/apcnet/apcnet_r101-d8_769x769_80k_cityscapes.py @@ -0,0 +1,2 @@ +_base_ = './apcnet_r50-d8_769x769_80k_cityscapes.py' +model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) diff --git a/configs/apcnet/apcnet_r50-d8_512x1024_40k_cityscapes.py b/configs/apcnet/apcnet_r50-d8_512x1024_40k_cityscapes.py new file mode 100644 index 000000000..99c61a942 --- /dev/null +++ b/configs/apcnet/apcnet_r50-d8_512x1024_40k_cityscapes.py @@ -0,0 +1,4 @@ +_base_ = [ + '../_base_/models/apcnet_r50-d8.py', '../_base_/datasets/cityscapes.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' +] diff --git a/configs/apcnet/apcnet_r50-d8_512x1024_80k_cityscapes.py b/configs/apcnet/apcnet_r50-d8_512x1024_80k_cityscapes.py new file mode 100644 index 000000000..62a0627ae --- /dev/null +++ b/configs/apcnet/apcnet_r50-d8_512x1024_80k_cityscapes.py @@ -0,0 +1,4 @@ +_base_ = [ + '../_base_/models/apcnet_r50-d8.py', '../_base_/datasets/cityscapes.py', + '../_base_/default_runtime.py', 
'../_base_/schedules/schedule_80k.py' +] diff --git a/configs/apcnet/apcnet_r50-d8_512x512_160k_ade20k.py b/configs/apcnet/apcnet_r50-d8_512x512_160k_ade20k.py new file mode 100644 index 000000000..aa45e35d3 --- /dev/null +++ b/configs/apcnet/apcnet_r50-d8_512x512_160k_ade20k.py @@ -0,0 +1,7 @@ +_base_ = [ + '../_base_/models/apcnet_r50-d8.py', '../_base_/datasets/ade20k.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' +] +model = dict( + decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150)) +test_cfg = dict(mode='whole') diff --git a/configs/apcnet/apcnet_r50-d8_512x512_80k_ade20k.py b/configs/apcnet/apcnet_r50-d8_512x512_80k_ade20k.py new file mode 100644 index 000000000..6b40d1f7a --- /dev/null +++ b/configs/apcnet/apcnet_r50-d8_512x512_80k_ade20k.py @@ -0,0 +1,7 @@ +_base_ = [ + '../_base_/models/apcnet_r50-d8.py', '../_base_/datasets/ade20k.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' +] +model = dict( + decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150)) +test_cfg = dict(mode='whole') diff --git a/configs/apcnet/apcnet_r50-d8_769x769_40k_cityscapes.py b/configs/apcnet/apcnet_r50-d8_769x769_40k_cityscapes.py new file mode 100644 index 000000000..d0134e31e --- /dev/null +++ b/configs/apcnet/apcnet_r50-d8_769x769_40k_cityscapes.py @@ -0,0 +1,9 @@ +_base_ = [ + '../_base_/models/apcnet_r50-d8.py', + '../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py', + '../_base_/schedules/schedule_40k.py' +] +model = dict( + decode_head=dict(align_corners=True), + auxiliary_head=dict(align_corners=True)) +test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513)) diff --git a/configs/apcnet/apcnet_r50-d8_769x769_80k_cityscapes.py b/configs/apcnet/apcnet_r50-d8_769x769_80k_cityscapes.py new file mode 100644 index 000000000..1d863c4f1 --- /dev/null +++ b/configs/apcnet/apcnet_r50-d8_769x769_80k_cityscapes.py @@ -0,0 +1,9 @@ +_base_ = [ + 
'../_base_/models/apcnet_r50-d8.py', + '../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py', + '../_base_/schedules/schedule_80k.py' +] +model = dict( + decode_head=dict(align_corners=True), + auxiliary_head=dict(align_corners=True)) +test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513)) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 0dd1b410b..c130baf6a 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -67,6 +67,10 @@ Please refer to [CCNet](https://github.com/open-mmlab/mmsegmentation/blob/master Please refer to [DANet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/danet) for details. +### APCNet + +Please refer to [APCNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/apcnet) for details. + ### HRNet Please refer to [HRNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet) for details. diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py index 6f3217ec0..1ac8c1ae3 100644 --- a/mmseg/models/decode_heads/__init__.py +++ b/mmseg/models/decode_heads/__init__.py @@ -1,4 +1,5 @@ from .ann_head import ANNHead +from .apc_head import APCHead from .aspp_head import ASPPHead from .cc_head import CCHead from .da_head import DAHead @@ -21,5 +22,5 @@ __all__ = [ 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', - 'PointHead' + 'PointHead', 'APCHead' ] diff --git a/mmseg/models/decode_heads/apc_head.py b/mmseg/models/decode_heads/apc_head.py new file mode 100644 index 000000000..b453db394 --- /dev/null +++ b/mmseg/models/decode_heads/apc_head.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.ops import resize +from ..builder import HEADS +from .decode_head import 
BaseDecodeHead + + +class ACM(nn.Module): + """Adaptive Context Module used in APCNet. + + Args: + pool_scale (int): Pooling scale used in Adaptive Context + Module to extract region features. + fusion (bool): Add one conv to fuse residual feature. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, + norm_cfg, act_cfg): + super(ACM, self).__init__() + self.pool_scale = pool_scale + self.fusion = fusion + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.pooled_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.input_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.global_info = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0) + + self.residual_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + if self.fusion: + self.fusion_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, x): + """Forward function.""" + pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale) + # [batch_size, channels, h, w] + x = self.input_redu_conv(x) + # [batch_size, channels, pool_scale, pool_scale] + pooled_x = self.pooled_redu_conv(pooled_x) + batch_size = x.size(0) + # [batch_size, pool_scale * 
pool_scale, channels] + pooled_x = pooled_x.view(batch_size, self.channels, + -1).permute(0, 2, 1).contiguous() + # [batch_size, h * w, pool_scale * pool_scale] + affinity_matrix = self.gla(x + resize( + self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:]) + ).permute(0, 2, 3, 1).reshape( + batch_size, -1, self.pool_scale**2) + affinity_matrix = F.sigmoid(affinity_matrix) + # [batch_size, h * w, channels] + z_out = torch.matmul(affinity_matrix, pooled_x) + # [batch_size, channels, h * w] + z_out = z_out.permute(0, 2, 1).contiguous() + # [batch_size, channels, h, w] + z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3)) + z_out = self.residual_conv(z_out) + z_out = F.relu(z_out + x) + if self.fusion: + z_out = self.fusion_conv(z_out) + + return z_out + + +@HEADS.register_module() +class APCHead(BaseDecodeHead): + """Adaptive Pyramid Context Network for Semantic Segmentation. + + This head is the implementation of + `APCNet <https://openaccess.thecvf.com/content_CVPR_2019/papers/He_Adaptive_Pyramid_Context_Network_for_Semantic_Segmentation_CVPR_2019_paper.pdf>`_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Adaptive Context + Module. Default: (1, 2, 3, 6). + fusion (bool): Add one conv to fuse residual feature. 
+ """ + + def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs): + super(APCHead, self).__init__(**kwargs) + assert isinstance(pool_scales, (list, tuple)) + self.pool_scales = pool_scales + self.fusion = fusion + acm_modules = [] + for pool_scale in self.pool_scales: + acm_modules.append( + ACM(pool_scale, + self.fusion, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.acm_modules = nn.ModuleList(acm_modules) + self.bottleneck = ConvModule( + self.in_channels + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + acm_outs = [x] + for acm_module in self.acm_modules: + acm_outs.append(acm_module(x)) + acm_outs = torch.cat(acm_outs, dim=1) + output = self.bottleneck(acm_outs) + output = self.cls_seg(output) + return output diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index acf290226..5a8ab7463 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -6,8 +6,8 @@ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule from mmcv.utils import ConfigDict from mmcv.utils.parrots_wrapper import SyncBatchNorm -from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead, - DepthwiseSeparableASPPHead, +from mmseg.models.decode_heads import (ANNHead, APCHead, ASPPHead, CCHead, + DAHead, DepthwiseSeparableASPPHead, DepthwiseSeparableFCNHead, DNLHead, EMAHead, EncHead, FCNHead, GCHead, NLHead, OCRHead, PointHead, PSAHead, @@ -223,6 +223,59 @@ def test_psp_head(): assert outputs.shape == (1, head.num_classes, 45, 45) +def test_apc_head(): + + with pytest.raises(AssertionError): + # pool_scales must be list|tuple + APCHead(in_channels=32, channels=16, num_classes=19, pool_scales=1) + + # test no norm_cfg + head = 
APCHead(in_channels=32, channels=16, num_classes=19) + assert not _conv_has_norm(head, sync_bn=False) + + # test with norm_cfg + head = APCHead( + in_channels=32, + channels=16, + num_classes=19, + norm_cfg=dict(type='SyncBN')) + assert _conv_has_norm(head, sync_bn=True) + + # fusion=True + inputs = [torch.randn(1, 32, 45, 45)] + head = APCHead( + in_channels=32, + channels=16, + num_classes=19, + pool_scales=(1, 2, 3), + fusion=True) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.fusion is True + assert head.acm_modules[0].pool_scale == 1 + assert head.acm_modules[1].pool_scale == 2 + assert head.acm_modules[2].pool_scale == 3 + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # fusion=False + inputs = [torch.randn(1, 32, 45, 45)] + head = APCHead( + in_channels=32, + channels=16, + num_classes=19, + pool_scales=(1, 2, 3), + fusion=False) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.fusion is False + assert head.acm_modules[0].pool_scale == 1 + assert head.acm_modules[1].pool_scale == 2 + assert head.acm_modules[2].pool_scale == 3 + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + def test_aspp_head(): with pytest.raises(AssertionError):