[Feature] Support EMANet (#34)
* add emanet * fixed bug and typos * add emanet config * fixed padding * fixed identity * rename * rename * add concat_input * fallback to update last * Fixed concat * update EMANet * Add tests * remove self-implement norm Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>pull/37/head^2
parent
3c6dd9e6a4
commit
dbca8b44a9
|
@ -0,0 +1,47 @@
|
|||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained='open-mmlab://resnet50_v1c',
|
||||
backbone=dict(
|
||||
type='ResNetV1c',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
dilations=(1, 1, 2, 4),
|
||||
strides=(1, 2, 1, 1),
|
||||
norm_cfg=norm_cfg,
|
||||
norm_eval=False,
|
||||
style='pytorch',
|
||||
contract_dilation=True),
|
||||
decode_head=dict(
|
||||
type='EMAHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=256,
|
||||
ema_channels=512,
|
||||
num_bases=64,
|
||||
num_stages=3,
|
||||
momentum=0.1,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
auxiliary_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=1024,
|
||||
in_index=2,
|
||||
channels=256,
|
||||
num_convs=1,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)))
|
||||
# model training and testing settings
|
||||
train_cfg = dict()
|
||||
test_cfg = dict(mode='whole')
|
|
@ -0,0 +1,22 @@
|
|||
# Expectation-Maximization Attention Networks for Semantic Segmentation
|
||||
|
||||
## Introduction
|
||||
```
|
||||
@inproceedings{li2019expectation,
|
||||
title={Expectation-maximization attention networks for semantic segmentation},
|
||||
author={Li, Xia and Zhong, Zhisheng and Wu, Jianlong and Yang, Yibo and Lin, Zhouchen and Liu, Hong},
|
||||
booktitle={Proceedings of the IEEE International Conference on Computer Vision},
|
||||
pages={9167--9176},
|
||||
year={2019}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
### Cityscapes
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|
||||
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| EMANet | R-50-D8 | 512x1024 | 80000 | 5.4 | 4.58 | 77.59 | 79.44 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes_20200901_100301-c43fcef1.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes-20200901_100301.log.json) |
|
||||
| EMANet | R-101-D8 | 512x1024 | 80000 | 6.2 | 2.87 | 79.10 | 81.21 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes_20200901_100301-2d970745.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes-20200901_100301.log.json) |
|
||||
| EMANet | R-50-D8 | 769x769 | 80000 | 8.9 | 1.97 | 79.33 | 80.49 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes_20200901_100301-16f8de52.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes-20200901_100301.log.json) |
|
||||
| EMANet | R-101-D8 | 769x769 | 80000 | 10.1 | 1.22 | 79.62 | 81.00 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes_20200901_100301-47a324ce.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes-20200901_100301.log.json) |
|
|
@ -0,0 +1,2 @@
|
|||
_base_ = './emanet_r50-d8_512x1024_80k_cityscapes.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
|
@ -0,0 +1,2 @@
|
|||
_base_ = './emanet_r50-d8_769x769_80k_cityscapes.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
|
@ -0,0 +1,4 @@
|
|||
_base_ = [
|
||||
'../_base_/models/emanet_r50-d8.py', '../_base_/datasets/cityscapes.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
|
||||
]
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = [
|
||||
'../_base_/models/emanet_r50-d8.py',
|
||||
'../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
|
||||
'../_base_/schedules/schedule_80k.py'
|
||||
]
|
||||
model = dict(
|
||||
decode_head=dict(align_corners=True),
|
||||
auxiliary_head=dict(align_corners=True))
|
||||
test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
|
|
@ -2,6 +2,7 @@ from .ann_head import ANNHead
|
|||
from .aspp_head import ASPPHead
|
||||
from .cc_head import CCHead
|
||||
from .da_head import DAHead
|
||||
from .ema_head import EMAHead
|
||||
from .enc_head import EncHead
|
||||
from .fcn_head import FCNHead
|
||||
from .fpn_head import FPNHead
|
||||
|
@ -17,5 +18,5 @@ from .uper_head import UPerHead
|
|||
__all__ = [
|
||||
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
|
||||
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
|
||||
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead'
|
||||
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from ..builder import HEADS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
def reduce_mean(tensor):
|
||||
"""Reduce mean when distributed training."""
|
||||
if not (dist.is_available() and dist.is_initialized()):
|
||||
return tensor
|
||||
tensor = tensor.clone()
|
||||
dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
|
||||
return tensor
|
||||
|
||||
|
||||
class EMAModule(nn.Module):
|
||||
"""Expectation Maximization Attention Module used in EMANet.
|
||||
|
||||
Args:
|
||||
channels (int): Channels of the whole module.
|
||||
num_bases (int): Number of bases.
|
||||
num_stages (int): Number of the EM iterations.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, num_bases, num_stages, momentum):
|
||||
super(EMAModule, self).__init__()
|
||||
assert num_stages >= 1, 'num_stages must be at least 1!'
|
||||
self.num_bases = num_bases
|
||||
self.num_stages = num_stages
|
||||
self.momentum = momentum
|
||||
|
||||
bases = torch.zeros(1, channels, self.num_bases)
|
||||
bases.normal_(0, math.sqrt(2. / self.num_bases))
|
||||
# [1, channels, num_bases]
|
||||
bases = F.normalize(bases, dim=1, p=2)
|
||||
self.register_buffer('bases', bases)
|
||||
|
||||
def forward(self, feats):
|
||||
"""Forward function."""
|
||||
batch_size, channels, height, width = feats.size()
|
||||
# [batch_size, channels, height*width]
|
||||
feats = feats.view(batch_size, channels, height * width)
|
||||
# [batch_size, channels, num_bases]
|
||||
bases = self.bases.repeat(batch_size, 1, 1)
|
||||
|
||||
with torch.no_grad():
|
||||
for i in range(self.num_stages):
|
||||
# [batch_size, height*width, num_bases]
|
||||
attention = torch.einsum('bcn,bck->bnk', feats, bases)
|
||||
attention = F.softmax(attention, dim=2)
|
||||
# l1 norm
|
||||
attention_normed = F.normalize(attention, dim=1, p=1)
|
||||
# [batch_size, channels, num_bases]
|
||||
bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
|
||||
# l2 norm
|
||||
bases = F.normalize(bases, dim=1, p=2)
|
||||
|
||||
feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
|
||||
feats_recon = feats_recon.view(batch_size, channels, height, width)
|
||||
|
||||
if self.training:
|
||||
bases = bases.mean(dim=0, keepdim=True)
|
||||
bases = reduce_mean(bases)
|
||||
# l2 norm
|
||||
bases = F.normalize(bases, dim=1, p=2)
|
||||
self.bases = (1 -
|
||||
self.momentum) * self.bases + self.momentum * bases
|
||||
|
||||
return feats_recon
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class EMAHead(BaseDecodeHead):
|
||||
"""Expectation Maximization Attention Networks for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `EMANet
|
||||
<https://arxiv.org/abs/1907.13426>`_.
|
||||
|
||||
Args:
|
||||
ema_channels (int): EMA module channels
|
||||
num_bases (int): Number of bases.
|
||||
num_stages (int): Number of the EM iterations.
|
||||
concat_input (bool): Whether concat the input and output of convs
|
||||
before classification layer. Default: True
|
||||
momentum (float): Momentum to update the base. Default: 0.1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ema_channels,
|
||||
num_bases,
|
||||
num_stages,
|
||||
concat_input=True,
|
||||
momentum=0.1,
|
||||
**kwargs):
|
||||
super(EMAHead, self).__init__(**kwargs)
|
||||
self.ema_channels = ema_channels
|
||||
self.num_bases = num_bases
|
||||
self.num_stages = num_stages
|
||||
self.concat_input = concat_input
|
||||
self.momentum = momentum
|
||||
self.ema_module = EMAModule(self.ema_channels, self.num_bases,
|
||||
self.num_stages, self.momentum)
|
||||
|
||||
self.ema_in_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.ema_channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
# project (0, inf) -> (-inf, inf)
|
||||
self.ema_mid_conv = ConvModule(
|
||||
self.ema_channels,
|
||||
self.ema_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=None,
|
||||
act_cfg=None)
|
||||
for param in self.ema_mid_conv.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.ema_out_conv = ConvModule(
|
||||
self.ema_channels,
|
||||
self.ema_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
self.bottleneck = ConvModule(
|
||||
self.ema_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
if self.concat_input:
|
||||
self.conv_cat = ConvModule(
|
||||
self.in_channels + self.channels,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
feats = self.ema_in_conv(x)
|
||||
identity = feats
|
||||
feats = self.ema_mid_conv(feats)
|
||||
recon = self.ema_module(feats)
|
||||
recon = F.relu(recon, inplace=True)
|
||||
recon = self.ema_out_conv(recon)
|
||||
output = F.relu(identity + recon, inplace=True)
|
||||
output = self.bottleneck(output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
|
@ -162,6 +162,11 @@ def test_mobilenet_v2_forward():
|
|||
'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py')
|
||||
|
||||
|
||||
def test_emanet_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'emanet/emanet_r50-d8_512x1024_80k_cityscapes.py')
|
||||
|
||||
|
||||
def get_world_size(process_group):
|
||||
|
||||
return 1
|
||||
|
|
|
@ -6,9 +6,9 @@ from mmcv.cnn import ConvModule
|
|||
from mmcv.utils.parrots_wrapper import SyncBatchNorm
|
||||
|
||||
from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
|
||||
DepthwiseSeparableASPPHead, EncHead,
|
||||
FCNHead, GCHead, NLHead, OCRHead,
|
||||
PSAHead, PSPHead, UPerHead)
|
||||
DepthwiseSeparableASPPHead, EMAHead,
|
||||
EncHead, FCNHead, GCHead, NLHead,
|
||||
OCRHead, PSAHead, PSPHead, UPerHead)
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -539,3 +539,21 @@ def test_dw_aspp_head():
|
|||
assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
|
||||
def test_emanet_head():
|
||||
head = EMAHead(
|
||||
in_channels=32,
|
||||
ema_channels=24,
|
||||
channels=16,
|
||||
num_stages=3,
|
||||
num_bases=16,
|
||||
num_classes=19)
|
||||
for param in head.ema_mid_conv.parameters():
|
||||
assert not param.requires_grad
|
||||
assert hasattr(head, 'ema_module')
|
||||
inputs = [torch.randn(1, 32, 45, 45)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
|
Loading…
Reference in New Issue