Add "disentangled non-local (DNL) neural networks" [ECCV2020] (#37)

* Add DNLHead

* add configs

* add weight decay mult

* add norm back

* Update README.md

* matched inference performance

* Fixed shape

* sep conv_out

* no norm

* add norm back

* complete model zoo

* add tests

* Add test forward

* Add more test

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
pull/1801/head
Han Hu 2020-09-07 16:22:00 +08:00 committed by GitHub
parent 4b883ab717
commit d7ae15c7f7
19 changed files with 324 additions and 4 deletions

View File

@ -72,6 +72,9 @@ Supported methods:
- [x] [ANN](configs/ann)
- [x] [OCRNet](configs/ocrnet)
- [x] [Fast-SCNN](configs/fastscnn)
- [x] [Semantic FPN](configs/sem_fpn)
- [x] [EMANet](configs/emanet)
- [x] [DNLNet](configs/dnlnet)
- [x] [Mixed Precision (FP16) Training](configs/fp16/README.md)
## Installation

View File

@ -0,0 +1,46 @@
# model settings
# Normalisation layer config, reused by the backbone and both heads below.
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    # Backbone: the last two stages keep stride 1 and use dilation 2 and 4.
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    # Main head: DNL block applied to backbone output index 3.
    decode_head=dict(
        type='DNLHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        dropout_ratio=0.1,
        reduction=2,
        use_scale=True,
        mode='embedded_gaussian',
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0)),
    # Auxiliary head: single-conv FCN on backbone output index 2, with a
    # smaller loss weight than the main head.
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=0.4)))
# model training and testing settings
train_cfg = dict()
test_cfg = dict(mode='whole')

View File

@ -0,0 +1,40 @@
# Disentangled Non-Local Neural Networks
## Introduction
This example reproduces ["Disentangled Non-Local Neural Networks"](https://arxiv.org/abs/2006.06668) for semantic segmentation. The reproduction is still in progress.
## Citation
```
@inproceedings{yin2020disentangled,
title={Disentangled Non-Local Neural Networks},
author={Minghao Yin and Zhuliang Yao and Yue Cao and Xiu Li and Zheng Zhang and Stephen Lin and Han Hu},
year={2020},
booktitle={ECCV}
}
```
## Results and models (in progress)
### Cityscapes
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| dnl | R-50-D8 | 512x1024 | 40000 | 7.3 | 2.56 | 78.61 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x1024_40k_cityscapes/dnl_r50-d8_512x1024_40k_cityscapes_20200904_233629-53d4ea93.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x1024_40k_cityscapes/dnl_r50-d8_512x1024_40k_cityscapes-20200904_233629.log.json) |
| dnl | R-101-D8 | 512x1024 | 40000 | 10.9 | 1.96 | 78.31 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x1024_40k_cityscapes/dnl_r101-d8_512x1024_40k_cityscapes_20200904_233629-9928ffef.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x1024_40k_cityscapes/dnl_r101-d8_512x1024_40k_cityscapes-20200904_233629.log.json) |
| dnl | R-50-D8 | 769x769 | 40000 | 9.2 | 1.50 | 78.44 | 80.27 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_769x769_40k_cityscapes/dnl_r50-d8_769x769_40k_cityscapes_20200820_232206-0f283785.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_769x769_40k_cityscapes/dnl_r50-d8_769x769_40k_cityscapes-20200820_232206.log.json) |
| dnl | R-101-D8 | 769x769 | 40000 | 12.6 | 1.02 | 76.39 | 77.77 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_769x769_40k_cityscapes/dnl_r101-d8_769x769_40k_cityscapes_20200820_171256-76c596df.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_769x769_40k_cityscapes/dnl_r101-d8_769x769_40k_cityscapes-20200820_171256.log.json) |
| dnl | R-50-D8 | 512x1024 | 80000 | - | - | 79.33 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x1024_80k_cityscapes/dnl_r50-d8_512x1024_80k_cityscapes_20200904_233629-58b2f778.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x1024_80k_cityscapes/dnl_r50-d8_512x1024_80k_cityscapes-20200904_233629.log.json) |
| dnl | R-101-D8 | 512x1024 | 80000 | - | - | 80.41 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x1024_80k_cityscapes/dnl_r101-d8_512x1024_80k_cityscapes_20200904_233629-758e2dd4.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x1024_80k_cityscapes/dnl_r101-d8_512x1024_80k_cityscapes-20200904_233629.log.json) |
| dnl | R-50-D8 | 769x769 | 80000 | - | - | 79.36 | 80.70 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_769x769_80k_cityscapes/dnl_r50-d8_769x769_80k_cityscapes_20200820_011925-366bc4c7.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_769x769_80k_cityscapes/dnl_r50-d8_769x769_80k_cityscapes-20200820_011925.log.json) |
| dnl | R-101-D8 | 769x769 | 80000 | - | - | 79.41 | 80.68 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_769x769_80k_cityscapes/dnl_r101-d8_769x769_80k_cityscapes_20200821_051111-95ff84ab.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_769x769_80k_cityscapes/dnl_r101-d8_769x769_80k_cityscapes-20200821_051111.log.json) |
### ADE20K
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| DNL | R-50-D8 | 512x512 | 80000 | 8.8 | 20.66 | 41.76 | 42.99 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x512_80k_ade20k/dnl_r50-d8_512x512_80k_ade20k_20200826_183354-1cf6e0c1.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x512_80k_ade20k/dnl_r50-d8_512x512_80k_ade20k-20200826_183354.log.json) |
| DNL | R-101-D8 | 512x512 | 80000 | 12.8 | 12.54 | 43.76 | 44.91 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x512_80k_ade20k/dnl_r101-d8_512x512_80k_ade20k_20200826_183354-d820d6ea.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x512_80k_ade20k/dnl_r101-d8_512x512_80k_ade20k-20200826_183354.log.json) |
| DNL | R-50-D8 | 512x512 | 160000 | - | - | 41.87 | 43.01 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x512_160k_ade20k/dnl_r50-d8_512x512_160k_ade20k_20200826_183350-37837798.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x512_160k_ade20k/dnl_r50-d8_512x512_160k_ade20k-20200826_183350.log.json) |
| DNL | R-101-D8 | 512x512 | 160000 | - | - | 44.25 | 45.78 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x512_160k_ade20k/dnl_r101-d8_512x512_160k_ade20k_20200826_183350-ed522c61.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x512_160k_ade20k/dnl_r101-d8_512x512_160k_ade20k-20200826_183350.log.json) |

View File

@ -0,0 +1,2 @@
# ResNet-101 variant of the config named in `_base_`: only the pretrained
# checkpoint and the backbone depth are overridden.
_base_ = './dnl_r50-d8_512x1024_40k_cityscapes.py'
model = dict(
    pretrained='open-mmlab://resnet101_v1c',
    backbone=dict(depth=101),
)

View File

@ -0,0 +1,2 @@
# ResNet-101 variant of the config named in `_base_`: only the pretrained
# checkpoint and the backbone depth are overridden.
_base_ = './dnl_r50-d8_512x1024_80k_cityscapes.py'
model = dict(
    pretrained='open-mmlab://resnet101_v1c',
    backbone=dict(depth=101),
)

View File

@ -0,0 +1,2 @@
# ResNet-101 variant of the config named in `_base_`: only the pretrained
# checkpoint and the backbone depth are overridden.
_base_ = './dnl_r50-d8_512x512_160k_ade20k.py'
model = dict(
    pretrained='open-mmlab://resnet101_v1c',
    backbone=dict(depth=101),
)

View File

@ -0,0 +1,2 @@
# ResNet-101 variant of the config named in `_base_`: only the pretrained
# checkpoint and the backbone depth are overridden.
_base_ = './dnl_r50-d8_512x512_80k_ade20k.py'
model = dict(
    pretrained='open-mmlab://resnet101_v1c',
    backbone=dict(depth=101),
)

View File

@ -0,0 +1,2 @@
# ResNet-101 variant of the config named in `_base_`: only the pretrained
# checkpoint and the backbone depth are overridden.
_base_ = './dnl_r50-d8_769x769_40k_cityscapes.py'
model = dict(
    pretrained='open-mmlab://resnet101_v1c',
    backbone=dict(depth=101),
)

View File

@ -0,0 +1,2 @@
# ResNet-101 variant of the config named in `_base_`: only the pretrained
# checkpoint and the backbone depth are overridden.
_base_ = './dnl_r50-d8_769x769_80k_cityscapes.py'
model = dict(
    pretrained='open-mmlab://resnet101_v1c',
    backbone=dict(depth=101),
)

View File

@ -0,0 +1,4 @@
# Composed purely from base files: model, dataset, runtime and schedule.
_base_ = [
    '../_base_/models/dnl_r50-d8.py',
    '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40k.py',
]

View File

@ -0,0 +1,4 @@
# Composed purely from base files: model, dataset, runtime and schedule.
_base_ = [
    '../_base_/models/dnl_r50-d8.py',
    '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py',
]

View File

@ -0,0 +1,6 @@
_base_ = [
    '../_base_/models/dnl_r50-d8.py',
    '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py',
]
# Both heads are set to 150 output classes for this dataset.
model = dict(
    decode_head=dict(num_classes=150),
    auxiliary_head=dict(num_classes=150),
)

View File

@ -0,0 +1,6 @@
_base_ = [
    '../_base_/models/dnl_r50-d8.py',
    '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py',
]
# Both heads are set to 150 output classes for this dataset.
model = dict(
    decode_head=dict(num_classes=150),
    auxiliary_head=dict(num_classes=150),
)

View File

@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/dnl_r50-d8.py',
    '../_base_/datasets/cityscapes_769x769.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40k.py',
]
# Both heads switch to align_corners=True for the 769x769 setting.
model = dict(
    decode_head=dict(align_corners=True),
    auxiliary_head=dict(align_corners=True),
)
# Sliding-window testing: 769x769 crops taken with a stride of 513.
test_cfg = dict(
    mode='slide',
    crop_size=(769, 769),
    stride=(513, 513),
)

View File

@ -0,0 +1,12 @@
_base_ = [
    '../_base_/models/dnl_r50-d8.py',
    '../_base_/datasets/cityscapes_769x769.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py',
]
# Both heads switch to align_corners=True for the 769x769 setting.
model = dict(
    decode_head=dict(align_corners=True),
    auxiliary_head=dict(align_corners=True),
)
# Sliding-window testing: 769x769 crops taken with a stride of 513.
test_cfg = dict(
    mode='slide',
    crop_size=(769, 769),
    stride=(513, 513),
)
# Weight-decay multiplier of 0 for parameters matched by the keys `theta`
# and `phi`.
optimizer = dict(
    paramwise_cfg=dict(
        custom_keys=dict(
            theta=dict(wd_mult=0.),
            phi=dict(wd_mult=0.))))

View File

@ -2,6 +2,7 @@ from .ann_head import ANNHead
from .aspp_head import ASPPHead
from .cc_head import CCHead
from .da_head import DAHead
from .dnl_head import DNLHead
from .ema_head import EMAHead
from .enc_head import EncHead
from .fcn_head import FCNHead
@ -18,5 +19,5 @@ from .uper_head import UPerHead
__all__ = [
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead'
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead'
]

View File

@ -0,0 +1,131 @@
import torch
from mmcv.cnn import NonLocal2d
from torch import nn
from ..builder import HEADS
from .fcn_head import FCNHead
class DisentangledNonLocal2d(NonLocal2d):
    """Disentangled Non-Local Blocks.

    On top of ``NonLocal2d`` this block subtracts the mean from ``theta_x``
    and ``phi_x`` before computing the pairwise weights, and adds a
    separate unary term whose attention mask is predicted by a 1x1
    convolution (``conv_mask``).

    Args:
        temperature (float): Temperature to adjust attention. Default: 0.05
    """

    def __init__(self, *arg, temperature=0.05, **kwargs):
        super().__init__(*arg, **kwargs)
        self.temperature = temperature
        # Single-channel 1x1 conv producing the logits of the unary mask.
        self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)

    def embedded_gaussian(self, theta_x, phi_x):
        """Embedded gaussian with temperature.

        Args:
            theta_x (Tensor): Query embedding, [N, HxW, C].
            phi_x (Tensor): Key embedding, [N, C, HxW].

        Returns:
            Tensor: Softmax-normalised pairwise weights, [N, HxW, HxW].
        """
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= theta_x.shape[-1]**0.5
        # A temperature below 1 sharpens the softmax distribution.
        pairwise_weight /= self.temperature
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def forward(self, x):
        """Apply the disentangled non-local block.

        Args:
            x (Tensor): Input feature map, [N, C, H, W]. It is not modified.

        Returns:
            Tensor: ``x`` plus the transformed sum of the pairwise and unary
                terms.
        """
        # x: [N, C, H, W]
        n = x.size(0)

        # g_x: [N, HxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
        if self.mode == 'gaussian':
            theta_x = x.view(n, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            if self.sub_sample:
                phi_x = self.phi(x).view(n, self.in_channels, -1)
            else:
                phi_x = x.view(n, self.in_channels, -1)
        elif self.mode == 'concatenation':
            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
        else:
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, -1)

        # subtract mean
        # This must be out of place: in 'gaussian' mode theta_x / phi_x are
        # views of `x`, so an in-place `-=` would overwrite the input that is
        # still needed for `conv_mask(x)` and the residual sum below.
        theta_x = theta_x - theta_x.mean(dim=-2, keepdim=True)
        phi_x = phi_x - phi_x.mean(dim=-1, keepdim=True)

        # Dispatch to the method named after the configured mode.
        pairwise_func = getattr(self, self.mode)
        # pairwise_weight: [N, HxW, HxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # y: [N, HxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # y: [N, C, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        # unary_mask: [N, 1, HxW]
        unary_mask = self.conv_mask(x)
        unary_mask = unary_mask.view(n, 1, -1)
        unary_mask = unary_mask.softmax(dim=-1)
        # unary_x: [N, 1, C]
        unary_x = torch.matmul(unary_mask, g_x)
        # unary_x: [N, C, 1, 1]
        unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
            n, self.inter_channels, 1, 1)

        # unary_x broadcasts over the spatial dimensions of y.
        output = x + self.conv_out(y + unary_x)

        return output
@HEADS.register_module()
class DNLHead(FCNHead):
    """Disentangled Non-Local Neural Networks.

    This head is the implementation of `DNLNet
    <https://arxiv.org/abs/2006.06668>`_. It is an ``FCNHead`` with two
    convolutions and a ``DisentangledNonLocal2d`` block placed between them.

    Args:
        reduction (int): Reduction factor of projection transform. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode, forwarded unchanged to the non-local
            block. Options are 'embedded_gaussian', 'dot_product'.
            Default: 'embedded_gaussian'.
        temperature (float): Temperature to adjust attention. Default: 0.05
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 temperature=0.05,
                 **kwargs):
        # The head always uses exactly two convs: one before and one after
        # the non-local block.
        super().__init__(num_convs=2, **kwargs)
        self.reduction, self.use_scale = reduction, use_scale
        self.mode, self.temperature = mode, temperature
        block_kwargs = dict(
            in_channels=self.channels,
            reduction=reduction,
            use_scale=use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=mode,
            temperature=temperature)
        self.dnl_block = DisentangledNonLocal2d(**block_kwargs)

    def forward(self, inputs):
        """Forward function."""
        feats = self._transform_inputs(inputs)
        conv_in, conv_out = self.convs[0], self.convs[1]
        # conv -> disentangled non-local block -> conv
        refined = conv_out(self.dnl_block(conv_in(feats)))
        if self.concat_input:
            # Fuse the head input with the refined features.
            refined = self.conv_cat(torch.cat([feats, refined], dim=1))
        return self.cls_seg(refined)

View File

@ -162,6 +162,11 @@ def test_mobilenet_v2_forward():
'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py')
def test_dnlnet_forward():
    """Run the shared encoder-decoder forward check on a DNLNet config."""
    cfg_file = 'dnlnet/dnl_r50-d8_512x1024_40k_cityscapes.py'
    _test_encoder_decoder_forward(cfg_file)
def test_emanet_forward():
    """Run the shared encoder-decoder forward check on an EMANet config."""
    cfg_file = 'emanet/emanet_r50-d8_512x1024_80k_cityscapes.py'
    _test_encoder_decoder_forward(cfg_file)

View File

@ -6,9 +6,10 @@ from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import SyncBatchNorm
from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
DepthwiseSeparableASPPHead, EMAHead,
EncHead, FCNHead, GCHead, NLHead,
OCRHead, PSAHead, PSPHead, UPerHead)
DepthwiseSeparableASPPHead, DNLHead,
EMAHead, EncHead, FCNHead, GCHead,
NLHead, OCRHead, PSAHead, PSPHead,
UPerHead)
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
@ -541,6 +542,46 @@ def test_dw_aspp_head():
assert outputs.shape == (1, head.num_classes, 45, 45)
def test_dnl_head():
    """Check DNLHead construction and output shape for every mode."""
    # DNL with 'embedded_gaussian' mode (the default): also check structure.
    head = DNLHead(in_channels=32, channels=16, num_classes=19)
    assert len(head.convs) == 2
    assert hasattr(head, 'dnl_block')
    assert head.dnl_block.temperature == 0.05
    inputs = [torch.randn(1, 32, 45, 45)]
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert outputs.shape == (1, head.num_classes, 45, 45)

    # The remaining modes only need the forward-shape check.
    for mode in ('dot_product', 'gaussian', 'concatenation'):
        head = DNLHead(
            in_channels=32, channels=16, num_classes=19, mode=mode)
        inputs = [torch.randn(1, 32, 45, 45)]
        if torch.cuda.is_available():
            head, inputs = to_cuda(head, inputs)
        outputs = head(inputs)
        assert outputs.shape == (1, head.num_classes, 45, 45)
def test_emanet_head():
head = EMAHead(
in_channels=32,