[Feature] Support EMANet (#34)

* add emanet

* fixed bug and typos

* add emanet config

* fixed padding

* fixed identity

* rename

* rename

* add concat_input

* fallback to update last

* Fixed concat

* update EMANet

* Add tests

* remove self-implement norm

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
pull/37/head^2
Xia Li 李夏 2020-09-07 13:06:59 +08:00 committed by GitHub
parent 3c6dd9e6a4
commit dbca8b44a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 282 additions and 4 deletions

View File

@ -0,0 +1,47 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='EMAHead',
in_channels=2048,
in_index=3,
channels=256,
ema_channels=512,
num_bases=64,
num_stages=3,
momentum=0.1,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)))
# model training and testing settings
train_cfg = dict()
test_cfg = dict(mode='whole')

View File

@ -0,0 +1,22 @@
# Expectation-Maximization Attention Networks for Semantic Segmentation
## Introduction
```
@inproceedings{li2019expectation,
title={Expectation-maximization attention networks for semantic segmentation},
author={Li, Xia and Zhong, Zhisheng and Wu, Jianlong and Yang, Yibo and Lin, Zhouchen and Liu, Hong},
booktitle={Proceedings of the IEEE International Conference on Computer Vision},
pages={9167--9176},
year={2019}
}
```
## Results and models
### Cityscapes
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| EMANet | R-50-D8 | 512x1024 | 80000 | 5.4 | 4.58 | 77.59 | 79.44 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes_20200901_100301-c43fcef1.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes-20200901_100301.log.json) |
| EMANet | R-101-D8 | 512x1024 | 80000 | 6.2 | 2.87 | 79.10 | 81.21 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes_20200901_100301-2d970745.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes-20200901_100301.log.json) |
| EMANet | R-50-D8 | 769x769 | 80000 | 8.9 | 1.97 | 79.33 | 80.49 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes_20200901_100301-16f8de52.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes-20200901_100301.log.json) |
| EMANet | R-101-D8 | 769x769 | 80000 | 10.1 | 1.22 | 79.62 | 81.00 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes_20200901_100301-47a324ce.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes-20200901_100301.log.json) |

View File

@ -0,0 +1,2 @@
_base_ = './emanet_r50-d8_512x1024_80k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

View File

@ -0,0 +1,2 @@
_base_ = './emanet_r50-d8_769x769_80k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

View File

@ -0,0 +1,4 @@
_base_ = [
'../_base_/models/emanet_r50-d8.py', '../_base_/datasets/cityscapes.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]

View File

@ -0,0 +1,9 @@
_base_ = [
'../_base_/models/emanet_r50-d8.py',
'../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_80k.py'
]
model = dict(
decode_head=dict(align_corners=True),
auxiliary_head=dict(align_corners=True))
test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))

View File

@ -2,6 +2,7 @@ from .ann_head import ANNHead
from .aspp_head import ASPPHead
from .cc_head import CCHead
from .da_head import DAHead
from .ema_head import EMAHead
from .enc_head import EncHead
from .fcn_head import FCNHead
from .fpn_head import FPNHead
@ -17,5 +18,5 @@ from .uper_head import UPerHead
__all__ = [
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead'
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead'
]

View File

@ -0,0 +1,168 @@
import math
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from ..builder import HEADS
from .decode_head import BaseDecodeHead
def reduce_mean(tensor):
"""Reduce mean when distributed training."""
if not (dist.is_available() and dist.is_initialized()):
return tensor
tensor = tensor.clone()
dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
return tensor
class EMAModule(nn.Module):
"""Expectation Maximization Attention Module used in EMANet.
Args:
channels (int): Channels of the whole module.
num_bases (int): Number of bases.
num_stages (int): Number of the EM iterations.
"""
def __init__(self, channels, num_bases, num_stages, momentum):
super(EMAModule, self).__init__()
assert num_stages >= 1, 'num_stages must be at least 1!'
self.num_bases = num_bases
self.num_stages = num_stages
self.momentum = momentum
bases = torch.zeros(1, channels, self.num_bases)
bases.normal_(0, math.sqrt(2. / self.num_bases))
# [1, channels, num_bases]
bases = F.normalize(bases, dim=1, p=2)
self.register_buffer('bases', bases)
def forward(self, feats):
"""Forward function."""
batch_size, channels, height, width = feats.size()
# [batch_size, channels, height*width]
feats = feats.view(batch_size, channels, height * width)
# [batch_size, channels, num_bases]
bases = self.bases.repeat(batch_size, 1, 1)
with torch.no_grad():
for i in range(self.num_stages):
# [batch_size, height*width, num_bases]
attention = torch.einsum('bcn,bck->bnk', feats, bases)
attention = F.softmax(attention, dim=2)
# l1 norm
attention_normed = F.normalize(attention, dim=1, p=1)
# [batch_size, channels, num_bases]
bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
# l2 norm
bases = F.normalize(bases, dim=1, p=2)
feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
feats_recon = feats_recon.view(batch_size, channels, height, width)
if self.training:
bases = bases.mean(dim=0, keepdim=True)
bases = reduce_mean(bases)
# l2 norm
bases = F.normalize(bases, dim=1, p=2)
self.bases = (1 -
self.momentum) * self.bases + self.momentum * bases
return feats_recon
@HEADS.register_module()
class EMAHead(BaseDecodeHead):
"""Expectation Maximization Attention Networks for Semantic Segmentation.
This head is the implementation of `EMANet
<https://arxiv.org/abs/1907.13426>`_.
Args:
ema_channels (int): EMA module channels
num_bases (int): Number of bases.
num_stages (int): Number of the EM iterations.
concat_input (bool): Whether concat the input and output of convs
before classification layer. Default: True
momentum (float): Momentum to update the base. Default: 0.1.
"""
def __init__(self,
ema_channels,
num_bases,
num_stages,
concat_input=True,
momentum=0.1,
**kwargs):
super(EMAHead, self).__init__(**kwargs)
self.ema_channels = ema_channels
self.num_bases = num_bases
self.num_stages = num_stages
self.concat_input = concat_input
self.momentum = momentum
self.ema_module = EMAModule(self.ema_channels, self.num_bases,
self.num_stages, self.momentum)
self.ema_in_conv = ConvModule(
self.in_channels,
self.ema_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
# project (0, inf) -> (-inf, inf)
self.ema_mid_conv = ConvModule(
self.ema_channels,
self.ema_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=None,
act_cfg=None)
for param in self.ema_mid_conv.parameters():
param.requires_grad = False
self.ema_out_conv = ConvModule(
self.ema_channels,
self.ema_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=None)
self.bottleneck = ConvModule(
self.ema_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.concat_input:
self.conv_cat = ConvModule(
self.in_channels + self.channels,
self.channels,
kernel_size=3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
feats = self.ema_in_conv(x)
identity = feats
feats = self.ema_mid_conv(feats)
recon = self.ema_module(feats)
recon = F.relu(recon, inplace=True)
recon = self.ema_out_conv(recon)
output = F.relu(identity + recon, inplace=True)
output = self.bottleneck(output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output

View File

@ -162,6 +162,11 @@ def test_mobilenet_v2_forward():
'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py')
def test_emanet_forward():
_test_encoder_decoder_forward(
'emanet/emanet_r50-d8_512x1024_80k_cityscapes.py')
def get_world_size(process_group):
return 1

View File

@ -6,9 +6,9 @@ from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import SyncBatchNorm
from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
DepthwiseSeparableASPPHead, EncHead,
FCNHead, GCHead, NLHead, OCRHead,
PSAHead, PSPHead, UPerHead)
DepthwiseSeparableASPPHead, EMAHead,
EncHead, FCNHead, GCHead, NLHead,
OCRHead, PSAHead, PSPHead, UPerHead)
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
@ -539,3 +539,21 @@ def test_dw_aspp_head():
assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 45, 45)
def test_emanet_head():
head = EMAHead(
in_channels=32,
ema_channels=24,
channels=16,
num_stages=3,
num_bases=16,
num_classes=19)
for param in head.ema_mid_conv.parameters():
assert not param.requires_grad
assert hasattr(head, 'ema_module')
inputs = [torch.randn(1, 32, 45, 45)]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 45, 45)