Add "disentangled non-local (DNL) neural networks" [ECCV2020] (#37)
* Add DNLHead * add configs * add weight decay mult * add norm back * Update README.md * matched inference performance * Fixed shape * sep conv_out * no norm * add norm back * complete model zoo * add tests * Add test forward * Add more test Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>pull/1801/head
parent
4b883ab717
commit
d7ae15c7f7
|
@ -72,6 +72,9 @@ Supported methods:
|
|||
- [x] [ANN](configs/ann)
|
||||
- [x] [OCRNet](configs/ocrnet)
|
||||
- [x] [Fast-SCNN](configs/fastscnn)
|
||||
- [x] [Semantic FPN](configs/sem_fpn)
|
||||
- [x] [EMANet](configs/emanet)
|
||||
- [x] [DNLNet](configs/dnlnet)
|
||||
- [x] [Mixed Precision (FP16) Training](configs/fp16/README.md)
|
||||
|
||||
## Installation
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained='open-mmlab://resnet50_v1c',
|
||||
backbone=dict(
|
||||
type='ResNetV1c',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
dilations=(1, 1, 2, 4),
|
||||
strides=(1, 2, 1, 1),
|
||||
norm_cfg=norm_cfg,
|
||||
norm_eval=False,
|
||||
style='pytorch',
|
||||
contract_dilation=True),
|
||||
decode_head=dict(
|
||||
type='DNLHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=512,
|
||||
dropout_ratio=0.1,
|
||||
reduction=2,
|
||||
use_scale=True,
|
||||
mode='embedded_gaussian',
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
auxiliary_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=1024,
|
||||
in_index=2,
|
||||
channels=256,
|
||||
num_convs=1,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)))
|
||||
# model training and testing settings
|
||||
train_cfg = dict()
|
||||
test_cfg = dict(mode='whole')
|
|
@ -0,0 +1,40 @@
|
|||
# Disentangled Non-Local Neural Networks
|
||||
|
||||
## Introduction
|
||||
|
||||
This example is to reproduce ["Disentangled Non-Local Neural Networks"](https://arxiv.org/abs/2006.06668) for semantic segmentation. It is still in progress.
|
||||
|
||||
## Citation
|
||||
```
|
||||
@misc{yin2020disentangled,
|
||||
title={Disentangled Non-Local Neural Networks},
|
||||
author={Minghao Yin and Zhuliang Yao and Yue Cao and Xiu Li and Zheng Zhang and Stephen Lin and Han Hu},
|
||||
year={2020},
|
||||
booktitle={ECCV}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models (in progress)
|
||||
|
||||
### Cityscapes
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|
||||
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| dnl | R-50-D8 | 512x1024 | 40000 | 7.3 | 2.56 | 78.61 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x1024_40k_cityscapes/dnl_r50-d8_512x1024_40k_cityscapes_20200904_233629-53d4ea93.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x1024_40k_cityscapes/dnl_r50-d8_512x1024_40k_cityscapes-20200904_233629.log.json) |
|
||||
| dnl | R-101-D8 | 512x1024 | 40000 | 10.9 | 1.96 | 78.31 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x1024_40k_cityscapes/dnl_r101-d8_512x1024_40k_cityscapes_20200904_233629-9928ffef.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x1024_40k_cityscapes/dnl_r101-d8_512x1024_40k_cityscapes-20200904_233629.log.json) |
|
||||
| dnl | R-50-D8 | 769x769 | 40000 | 9.2 | 1.50 | 78.44 | 80.27 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_769x769_40k_cityscapes/dnl_r50-d8_769x769_40k_cityscapes_20200820_232206-0f283785.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_769x769_40k_cityscapes/dnl_r50-d8_769x769_40k_cityscapes-20200820_232206.log.json) |
|
||||
| dnl | R-101-D8 | 769x769 | 40000 | 12.6 | 1.02 | 76.39 | 77.77 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_769x769_40k_cityscapes/dnl_r101-d8_769x769_40k_cityscapes_20200820_171256-76c596df.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_769x769_40k_cityscapes/dnl_r101-d8_769x769_40k_cityscapes-20200820_171256.log.json) |
|
||||
| dnl | R-50-D8 | 512x1024 | 80000 | - | - | 79.33 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x1024_80k_cityscapes/dnl_r50-d8_512x1024_80k_cityscapes_20200904_233629-58b2f778.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x1024_80k_cityscapes/dnl_r50-d8_512x1024_80k_cityscapes-20200904_233629.log.json) |
|
||||
| dnl | R-101-D8 | 512x1024 | 80000 | - | - | 80.41 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x1024_80k_cityscapes/dnl_r101-d8_512x1024_80k_cityscapes_20200904_233629-758e2dd4.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x1024_80k_cityscapes/dnl_r101-d8_512x1024_80k_cityscapes-20200904_233629.log.json) |
|
||||
| dnl | R-50-D8 | 769x769 | 80000 | - | - | 79.36 | 80.70 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_769x769_80k_cityscapes/dnl_r50-d8_769x769_80k_cityscapes_20200820_011925-366bc4c7.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_769x769_80k_cityscapes/dnl_r50-d8_769x769_80k_cityscapes-20200820_011925.log.json) |
|
||||
| dnl | R-101-D8 | 769x769 | 80000 | - | - | 79.41 | 80.68 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_769x769_80k_cityscapes/dnl_r101-d8_769x769_80k_cityscapes_20200821_051111-95ff84ab.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_769x769_80k_cityscapes/dnl_r101-d8_769x769_80k_cityscapes-20200821_051111.log.json) |
|
||||
|
||||
|
||||
### ADE20K
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|
||||
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| DNL | R-50-D8 | 512x512 | 80000 | 8.8 | 20.66 | 41.76 | 42.99 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x512_80k_ade20k/dnl_r50-d8_512x512_80k_ade20k_20200826_183354-1cf6e0c1.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x512_80k_ade20k/dnl_r50-d8_512x512_80k_ade20k-20200826_183354.log.json) |
|
||||
| DNL | R-101-D8 | 512x512 | 80000 | 12.8 | 12.54 | 43.76 | 44.91 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x512_80k_ade20k/dnl_r101-d8_512x512_80k_ade20k_20200826_183354-d820d6ea.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x512_80k_ade20k/dnl_r101-d8_512x512_80k_ade20k-20200826_183354.log.json) |
|
||||
| DNL | R-50-D8 | 512x512 | 160000 | - | - | 41.87 | 43.01 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x512_160k_ade20k/dnl_r50-d8_512x512_160k_ade20k_20200826_183350-37837798.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x512_160k_ade20k/dnl_r50-d8_512x512_160k_ade20k-20200826_183350.log.json) |
|
||||
| DNL | R-101-D8 | 512x512 | 160000 | - | - | 44.25 | 45.78 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x512_160k_ade20k/dnl_r101-d8_512x512_160k_ade20k_20200826_183350-ed522c61.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x512_160k_ade20k/dnl_r101-d8_512x512_160k_ade20k-20200826_183350.log.json) |
|
|
@ -0,0 +1,2 @@
|
|||
_base_ = './dnl_r50-d8_512x1024_40k_cityscapes.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
|
@ -0,0 +1,2 @@
|
|||
_base_ = './dnl_r50-d8_512x1024_80k_cityscapes.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
|
@ -0,0 +1,2 @@
|
|||
_base_ = './dnl_r50-d8_512x512_160k_ade20k.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
|
@ -0,0 +1,2 @@
|
|||
_base_ = './dnl_r50-d8_512x512_80k_ade20k.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
|
@ -0,0 +1,2 @@
|
|||
_base_ = './dnl_r50-d8_769x769_40k_cityscapes.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
|
@ -0,0 +1,2 @@
|
|||
_base_ = './dnl_r50-d8_769x769_80k_cityscapes.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
|
@ -0,0 +1,4 @@
|
|||
_base_ = [
|
||||
'../_base_/models/dnl_r50-d8.py', '../_base_/datasets/cityscapes.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
|
||||
]
|
|
@ -0,0 +1,4 @@
|
|||
_base_ = [
|
||||
'../_base_/models/dnl_r50-d8.py', '../_base_/datasets/cityscapes.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
|
||||
]
|
|
@ -0,0 +1,6 @@
|
|||
_base_ = [
|
||||
'../_base_/models/dnl_r50-d8.py', '../_base_/datasets/ade20k.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
|
||||
]
|
||||
model = dict(
|
||||
decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))
|
|
@ -0,0 +1,6 @@
|
|||
_base_ = [
|
||||
'../_base_/models/dnl_r50-d8.py', '../_base_/datasets/ade20k.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
|
||||
]
|
||||
model = dict(
|
||||
decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = [
|
||||
'../_base_/models/dnl_r50-d8.py',
|
||||
'../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
|
||||
'../_base_/schedules/schedule_40k.py'
|
||||
]
|
||||
model = dict(
|
||||
decode_head=dict(align_corners=True),
|
||||
auxiliary_head=dict(align_corners=True))
|
||||
test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../_base_/models/dnl_r50-d8.py',
|
||||
'../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
|
||||
'../_base_/schedules/schedule_80k.py'
|
||||
]
|
||||
model = dict(
|
||||
decode_head=dict(align_corners=True),
|
||||
auxiliary_head=dict(align_corners=True))
|
||||
test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
|
||||
optimizer = dict(
|
||||
paramwise_cfg=dict(
|
||||
custom_keys=dict(theta=dict(wd_mult=0.), phi=dict(wd_mult=0.))))
|
|
@ -2,6 +2,7 @@ from .ann_head import ANNHead
|
|||
from .aspp_head import ASPPHead
|
||||
from .cc_head import CCHead
|
||||
from .da_head import DAHead
|
||||
from .dnl_head import DNLHead
|
||||
from .ema_head import EMAHead
|
||||
from .enc_head import EncHead
|
||||
from .fcn_head import FCNHead
|
||||
|
@ -18,5 +19,5 @@ from .uper_head import UPerHead
|
|||
__all__ = [
|
||||
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
|
||||
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
|
||||
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead'
|
||||
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
import torch
|
||||
from mmcv.cnn import NonLocal2d
|
||||
from torch import nn
|
||||
|
||||
from ..builder import HEADS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
class DisentangledNonLocal2d(NonLocal2d):
|
||||
"""Disentangled Non-Local Blocks.
|
||||
|
||||
Args:
|
||||
temperature (float): Temperature to adjust attention. Default: 0.05
|
||||
"""
|
||||
|
||||
def __init__(self, *arg, temperature, **kwargs):
|
||||
super().__init__(*arg, **kwargs)
|
||||
self.temperature = temperature
|
||||
self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
|
||||
|
||||
def embedded_gaussian(self, theta_x, phi_x):
|
||||
"""Embedded gaussian with temperature."""
|
||||
|
||||
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
||||
pairwise_weight = torch.matmul(theta_x, phi_x)
|
||||
if self.use_scale:
|
||||
# theta_x.shape[-1] is `self.inter_channels`
|
||||
pairwise_weight /= theta_x.shape[-1]**0.5
|
||||
pairwise_weight /= self.temperature
|
||||
pairwise_weight = pairwise_weight.softmax(dim=-1)
|
||||
return pairwise_weight
|
||||
|
||||
def forward(self, x):
|
||||
# x: [N, C, H, W]
|
||||
n = x.size(0)
|
||||
|
||||
# g_x: [N, HxW, C]
|
||||
g_x = self.g(x).view(n, self.inter_channels, -1)
|
||||
g_x = g_x.permute(0, 2, 1)
|
||||
|
||||
# theta_x: [N, HxW, C], phi_x: [N, C, HxW]
|
||||
if self.mode == 'gaussian':
|
||||
theta_x = x.view(n, self.in_channels, -1)
|
||||
theta_x = theta_x.permute(0, 2, 1)
|
||||
if self.sub_sample:
|
||||
phi_x = self.phi(x).view(n, self.in_channels, -1)
|
||||
else:
|
||||
phi_x = x.view(n, self.in_channels, -1)
|
||||
elif self.mode == 'concatenation':
|
||||
theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
|
||||
phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
|
||||
else:
|
||||
theta_x = self.theta(x).view(n, self.inter_channels, -1)
|
||||
theta_x = theta_x.permute(0, 2, 1)
|
||||
phi_x = self.phi(x).view(n, self.inter_channels, -1)
|
||||
|
||||
# subtract mean
|
||||
theta_x -= theta_x.mean(dim=-2, keepdim=True)
|
||||
phi_x -= phi_x.mean(dim=-1, keepdim=True)
|
||||
|
||||
pairwise_func = getattr(self, self.mode)
|
||||
# pairwise_weight: [N, HxW, HxW]
|
||||
pairwise_weight = pairwise_func(theta_x, phi_x)
|
||||
|
||||
# y: [N, HxW, C]
|
||||
y = torch.matmul(pairwise_weight, g_x)
|
||||
# y: [N, C, H, W]
|
||||
y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
|
||||
*x.size()[2:])
|
||||
|
||||
# unary_mask: [N, 1, HxW]
|
||||
unary_mask = self.conv_mask(x)
|
||||
unary_mask = unary_mask.view(n, 1, -1)
|
||||
unary_mask = unary_mask.softmax(dim=-1)
|
||||
# unary_x: [N, 1, C]
|
||||
unary_x = torch.matmul(unary_mask, g_x)
|
||||
# unary_x: [N, C, 1, 1]
|
||||
unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
|
||||
n, self.inter_channels, 1, 1)
|
||||
|
||||
output = x + self.conv_out(y + unary_x)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class DNLHead(FCNHead):
|
||||
"""Disentangled Non-Local Neural Networks.
|
||||
|
||||
This head is the implementation of `DNLNet
|
||||
<https://arxiv.org/abs/2006.06668>`_.
|
||||
|
||||
Args:
|
||||
reduction (int): Reduction factor of projection transform. Default: 2.
|
||||
use_scale (bool): Whether to scale pairwise_weight by
|
||||
sqrt(1/inter_channels). Default: False.
|
||||
mode (str): The nonlocal mode. Options are 'embedded_gaussian',
|
||||
'dot_product'. Default: 'embedded_gaussian.'.
|
||||
temperature (float): Temperature to adjust attention. Default: 0.05
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reduction=2,
|
||||
use_scale=True,
|
||||
mode='embedded_gaussian',
|
||||
temperature=0.05,
|
||||
**kwargs):
|
||||
super(DNLHead, self).__init__(num_convs=2, **kwargs)
|
||||
self.reduction = reduction
|
||||
self.use_scale = use_scale
|
||||
self.mode = mode
|
||||
self.temperature = temperature
|
||||
self.dnl_block = DisentangledNonLocal2d(
|
||||
in_channels=self.channels,
|
||||
reduction=self.reduction,
|
||||
use_scale=self.use_scale,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
mode=self.mode,
|
||||
temperature=self.temperature)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
output = self.convs[0](x)
|
||||
output = self.dnl_block(output)
|
||||
output = self.convs[1](output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
|
@ -162,6 +162,11 @@ def test_mobilenet_v2_forward():
|
|||
'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py')
|
||||
|
||||
|
||||
def test_dnlnet_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'dnlnet/dnl_r50-d8_512x1024_40k_cityscapes.py')
|
||||
|
||||
|
||||
def test_emanet_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'emanet/emanet_r50-d8_512x1024_80k_cityscapes.py')
|
||||
|
|
|
@ -6,9 +6,10 @@ from mmcv.cnn import ConvModule
|
|||
from mmcv.utils.parrots_wrapper import SyncBatchNorm
|
||||
|
||||
from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
|
||||
DepthwiseSeparableASPPHead, EMAHead,
|
||||
EncHead, FCNHead, GCHead, NLHead,
|
||||
OCRHead, PSAHead, PSPHead, UPerHead)
|
||||
DepthwiseSeparableASPPHead, DNLHead,
|
||||
EMAHead, EncHead, FCNHead, GCHead,
|
||||
NLHead, OCRHead, PSAHead, PSPHead,
|
||||
UPerHead)
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -541,6 +542,46 @@ def test_dw_aspp_head():
|
|||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
|
||||
def test_dnl_head():
|
||||
# DNL with 'embedded_gaussian' mode
|
||||
head = DNLHead(in_channels=32, channels=16, num_classes=19)
|
||||
assert len(head.convs) == 2
|
||||
assert hasattr(head, 'dnl_block')
|
||||
assert head.dnl_block.temperature == 0.05
|
||||
inputs = [torch.randn(1, 32, 45, 45)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
# NonLocal2d with 'dot_product' mode
|
||||
head = DNLHead(
|
||||
in_channels=32, channels=16, num_classes=19, mode='dot_product')
|
||||
inputs = [torch.randn(1, 32, 45, 45)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
# NonLocal2d with 'gaussian' mode
|
||||
head = DNLHead(
|
||||
in_channels=32, channels=16, num_classes=19, mode='gaussian')
|
||||
inputs = [torch.randn(1, 32, 45, 45)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
# NonLocal2d with 'concatenation' mode
|
||||
head = DNLHead(
|
||||
in_channels=32, channels=16, num_classes=19, mode='concatenation')
|
||||
inputs = [torch.randn(1, 32, 45, 45)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
|
||||
def test_emanet_head():
|
||||
head = EMAHead(
|
||||
in_channels=32,
|
||||
|
|
Loading…
Reference in New Issue