Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
[Feature] Support EMANet (#34)
* add emanet
* fixed bug and typos
* add emanet config
* fixed padding
* fixed identity
* rename
* rename
* add concat_input
* fallback to update last
* Fixed concat
* update EMANet
* Add tests
* remove self-implement norm

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
parent 3c6dd9e6a4
commit dbca8b44a9
configs/_base_/models/emanet_r50-d8.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='EMAHead',
        in_channels=2048,
        in_index=3,
        channels=256,
        ema_channels=512,
        num_bases=64,
        num_stages=3,
        momentum=0.1,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)))
# model training and testing settings
train_cfg = dict()
test_cfg = dict(mode='whole')
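For orientation, a minimal sketch (not part of this commit) of how a base config like the one above can be turned into a model object with the mmsegmentation/mmcv APIs of this era; the use of `mmcv.Config.fromfile` and `build_segmentor` follows the v0.5-style layout where `train_cfg`/`test_cfg` sit at the top level of the config.

```python
# A minimal sketch, assuming mmcv and a v0.5-era mmsegmentation are installed.
import mmcv
from mmseg.models import build_segmentor

cfg = mmcv.Config.fromfile('configs/_base_/models/emanet_r50-d8.py')
cfg.model.pretrained = None  # skip downloading the ImageNet backbone for a quick check
model = build_segmentor(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
print(model.decode_head)  # EMAHead with ema_channels=512, num_bases=64, num_stages=3
```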
configs/emanet/README.md (new file, 22 lines)
@@ -0,0 +1,22 @@
# Expectation-Maximization Attention Networks for Semantic Segmentation

## Introduction

```
@inproceedings{li2019expectation,
  title={Expectation-maximization attention networks for semantic segmentation},
  author={Li, Xia and Zhong, Zhisheng and Wu, Jianlong and Yang, Yibo and Lin, Zhouchen and Liu, Hong},
  booktitle={Proceedings of the IEEE International Conference on Computer Vision},
  pages={9167--9176},
  year={2019}
}
```

## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|--------|----------|-----------|--------:|---------:|---------------:|------:|---------------|----------|
| EMANet | R-50-D8  | 512x1024  | 80000   | 5.4      | 4.58           | 77.59 | 79.44         | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes_20200901_100301-c43fcef1.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes-20200901_100301.log.json) |
| EMANet | R-101-D8 | 512x1024  | 80000   | 6.2      | 2.87           | 79.10 | 81.21         | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes_20200901_100301-2d970745.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes-20200901_100301.log.json) |
| EMANet | R-50-D8  | 769x769   | 80000   | 8.9      | 1.97           | 79.33 | 80.49         | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes_20200901_100301-16f8de52.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes-20200901_100301.log.json) |
| EMANet | R-101-D8 | 769x769   | 80000   | 10.1     | 1.22           | 79.62 | 81.00         | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes_20200901_100301-47a324ce.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes-20200901_100301.log.json) |
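As a hedged illustration of how a checkpoint from the table above could be used (not part of the commit): the snippet below assumes a v0.5-era mmsegmentation install, which provides `mmseg.apis.init_segmentor`/`inference_segmentor`, and uses the repository's demo image; the checkpoint URL is the R-50-D8 512x1024 entry from the table.

```python
# A hedged usage sketch, assuming mmsegmentation v0.5-era APIs are available.
from mmseg.apis import inference_segmentor, init_segmentor

config_file = 'configs/emanet/emanet_r50-d8_512x1024_80k_cityscapes.py'
checkpoint = ('https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/'
              'v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/'
              'emanet_r50-d8_512x1024_80k_cityscapes_20200901_100301-c43fcef1.pth')

model = init_segmentor(config_file, checkpoint, device='cuda:0')
result = inference_segmentor(model, 'demo/demo.png')  # list with one HxW array of class ids
```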
configs/emanet/emanet_r101-d8_512x1024_80k_cityscapes.py (new file, 2 lines)
@@ -0,0 +1,2 @@
_base_ = './emanet_r50-d8_512x1024_80k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
configs/emanet/emanet_r101-d8_769x769_80k_cityscapes.py (new file, 2 lines)
@@ -0,0 +1,2 @@
_base_ = './emanet_r50-d8_769x769_80k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
configs/emanet/emanet_r50-d8_512x1024_80k_cityscapes.py (new file, 4 lines)
@@ -0,0 +1,4 @@
_base_ = [
    '../_base_/models/emanet_r50-d8.py', '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
configs/emanet/emanet_r50-d8_769x769_80k_cityscapes.py (new file, 9 lines)
@@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/emanet_r50-d8.py',
    '../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py'
]
model = dict(
    decode_head=dict(align_corners=True),
    auxiliary_head=dict(align_corners=True))
test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
mmseg/models/decode_heads/__init__.py (modified)
@@ -2,6 +2,7 @@ from .ann_head import ANNHead
 from .aspp_head import ASPPHead
 from .cc_head import CCHead
 from .da_head import DAHead
+from .ema_head import EMAHead
 from .enc_head import EncHead
 from .fcn_head import FCNHead
 from .fpn_head import FPNHead
@@ -17,5 +18,5 @@ from .uper_head import UPerHead
 __all__ = [
     'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
     'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
-    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead'
+    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead'
 ]
mmseg/models/decode_heads/ema_head.py (new file, 168 lines)
@@ -0,0 +1,168 @@
import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from ..builder import HEADS
from .decode_head import BaseDecodeHead


def reduce_mean(tensor):
    """Reduce mean when distributed training."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor


class EMAModule(nn.Module):
    """Expectation Maximization Attention Module used in EMANet.

    Args:
        channels (int): Channels of the whole module.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        momentum (float): Momentum to update the bases.
    """

    def __init__(self, channels, num_bases, num_stages, momentum):
        super(EMAModule, self).__init__()
        assert num_stages >= 1, 'num_stages must be at least 1!'
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.momentum = momentum

        bases = torch.zeros(1, channels, self.num_bases)
        bases.normal_(0, math.sqrt(2. / self.num_bases))
        # [1, channels, num_bases]
        bases = F.normalize(bases, dim=1, p=2)
        self.register_buffer('bases', bases)

    def forward(self, feats):
        """Forward function."""
        batch_size, channels, height, width = feats.size()
        # [batch_size, channels, height*width]
        feats = feats.view(batch_size, channels, height * width)
        # [batch_size, channels, num_bases]
        bases = self.bases.repeat(batch_size, 1, 1)

        with torch.no_grad():
            for i in range(self.num_stages):
                # [batch_size, height*width, num_bases]
                attention = torch.einsum('bcn,bck->bnk', feats, bases)
                attention = F.softmax(attention, dim=2)
                # l1 norm
                attention_normed = F.normalize(attention, dim=1, p=1)
                # [batch_size, channels, num_bases]
                bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
                # l2 norm
                bases = F.normalize(bases, dim=1, p=2)

        feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
        feats_recon = feats_recon.view(batch_size, channels, height, width)

        if self.training:
            bases = bases.mean(dim=0, keepdim=True)
            bases = reduce_mean(bases)
            # l2 norm
            bases = F.normalize(bases, dim=1, p=2)
            self.bases = (1 -
                          self.momentum) * self.bases + self.momentum * bases

        return feats_recon


@HEADS.register_module()
class EMAHead(BaseDecodeHead):
    """Expectation Maximization Attention Networks for Semantic Segmentation.

    This head is the implementation of `EMANet
    <https://arxiv.org/abs/1907.13426>`_.

    Args:
        ema_channels (int): EMA module channels.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        concat_input (bool): Whether to concat the input and output of convs
            before the classification layer. Default: True.
        momentum (float): Momentum to update the bases. Default: 0.1.
    """

    def __init__(self,
                 ema_channels,
                 num_bases,
                 num_stages,
                 concat_input=True,
                 momentum=0.1,
                 **kwargs):
        super(EMAHead, self).__init__(**kwargs)
        self.ema_channels = ema_channels
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.concat_input = concat_input
        self.momentum = momentum
        self.ema_module = EMAModule(self.ema_channels, self.num_bases,
                                    self.num_stages, self.momentum)

        self.ema_in_conv = ConvModule(
            self.in_channels,
            self.ema_channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # project (0, inf) -> (-inf, inf)
        self.ema_mid_conv = ConvModule(
            self.ema_channels,
            self.ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=None,
            act_cfg=None)
        for param in self.ema_mid_conv.parameters():
            param.requires_grad = False

        self.ema_out_conv = ConvModule(
            self.ema_channels,
            self.ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.bottleneck = ConvModule(
            self.ema_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        feats = self.ema_in_conv(x)
        identity = feats
        feats = self.ema_mid_conv(feats)
        recon = self.ema_module(feats)
        recon = F.relu(recon, inplace=True)
        recon = self.ema_out_conv(recon)
        output = F.relu(identity + recon, inplace=True)
        output = self.bottleneck(output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
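To make the EM loop in `EMAModule.forward` concrete, here is a small standalone sketch (assuming a mmsegmentation install that includes the module added above): the E-step computes each pixel's soft attention over the bases, the M-step re-estimates the bases from the attention-weighted features, and the output is the reconstruction of the feature map from the final bases, with the same shape as the input.

```python
# Standalone sketch, assuming mmsegmentation with this commit is installed.
import torch
from mmseg.models.decode_heads.ema_head import EMAModule

ema = EMAModule(channels=64, num_bases=16, num_stages=3, momentum=0.1)
ema.eval()  # eval mode: the running `bases` buffer is left untouched

feats = torch.randn(2, 64, 32, 32)  # [batch, channels, height, width]
with torch.no_grad():
    recon = ema(feats)  # EM iterations, then reconstruction from the bases
print(recon.shape)  # torch.Size([2, 64, 32, 32]) -- same shape as the input
```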
@@ -162,6 +162,11 @@ def test_mobilenet_v2_forward():
         'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py')
 
 
+def test_emanet_forward():
+    _test_encoder_decoder_forward(
+        'emanet/emanet_r50-d8_512x1024_80k_cityscapes.py')
+
+
 def get_world_size(process_group):
 
     return 1
@@ -6,9 +6,9 @@ from mmcv.cnn import ConvModule
 from mmcv.utils.parrots_wrapper import SyncBatchNorm
 
 from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
-                                       DepthwiseSeparableASPPHead, EncHead,
-                                       FCNHead, GCHead, NLHead, OCRHead,
-                                       PSAHead, PSPHead, UPerHead)
+                                       DepthwiseSeparableASPPHead, EMAHead,
+                                       EncHead, FCNHead, GCHead, NLHead,
+                                       OCRHead, PSAHead, PSPHead, UPerHead)
 from mmseg.models.decode_heads.decode_head import BaseDecodeHead
 
 
@@ -539,3 +539,21 @@ def test_dw_aspp_head():
     assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
     outputs = head(inputs)
     assert outputs.shape == (1, head.num_classes, 45, 45)
+
+
+def test_emanet_head():
+    head = EMAHead(
+        in_channels=32,
+        ema_channels=24,
+        channels=16,
+        num_stages=3,
+        num_bases=16,
+        num_classes=19)
+    for param in head.ema_mid_conv.parameters():
+        assert not param.requires_grad
+    assert hasattr(head, 'ema_module')
+    inputs = [torch.randn(1, 32, 45, 45)]
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+    outputs = head(inputs)
+    assert outputs.shape == (1, head.num_classes, 45, 45)