[CodeCamp2023-367] Add pp_mobileseg model (#3239)

Yang-ChangHui 2023-08-09 23:57:01 +08:00 committed by GitHub
parent 817c18bf2c
commit 1e937961b3
11 changed files with 1326 additions and 0 deletions


@ -0,0 +1,58 @@
# PP-MobileSeg: Exploring Transformer Blocks for Efficient Mobile Segmentation.
## Reference
> [PP-MobileSeg: Explore the Fast and Accurate Semantic Segmentation Model on Mobile Devices](https://arxiv.org/abs/2304.05152)
## Introduction
<a href="https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.8">Official Repo</a>
<a href="https://github.com/open-mmlab/mmsegmentation/tree/main/projects/pp_mobileseg">Code Snippet</a>
## <img src="https://user-images.githubusercontent.com/34859558/190043857-bfbdaf8b-d2dc-4fff-81c7-e0aac50851f9.png" width="25"/> Abstract
With the success of transformers in computer vision, several attempts have been made to adapt transformers to mobile devices. However, their performance is not satisfactory for some real-world applications. Therefore, we propose PP-MobileSeg, a SOTA semantic segmentation model for mobile devices.
It is composed of three newly proposed parts: the StrideFormer backbone, the Aggregated Attention Module (AAM), and the Valid Interpolate Module (VIM):
- With four stages of MobileNetV3 blocks as the feature extractor, we extract rich local features at different receptive fields with little parameter overhead. We then efficiently enrich the features from the last two stages with a global view using strided sea attention.
- To fuse the features effectively, the AAM filters the detail features with ensemble voting and adds the semantic feature to them to strengthen the semantic information as much as possible.
- Finally, the VIM upsamples the downsampled feature back to the original resolution and significantly reduces latency at inference time. It only interpolates the classes present in the final prediction, which is typically around 10% of the classes on the ADE20K dataset, a common situation for datasets with a large number of classes. This greatly reduces the latency of the final upsampling step, which accounts for the largest share of the model's overall latency (a minimal sketch of this idea follows the figure below).
Extensive experiments show that PP-MobileSeg achieves a superior params-accuracy-latency tradeoff compared to other SOTA methods.
<div align="center">
<img src="https://user-images.githubusercontent.com/34859558/227450728-1338fcb1-3b8a-4453-a155-da60abcacb88.png" width = "1000" />
</div>
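As a concrete illustration of the VIM idea, the sketch below upsamples only the class channels that appear in the low-resolution prediction and then maps the compact prediction back to the original class ids. Function name and shapes are illustrative only; the project's actual implementation lives in the decode head.

```python
import torch
import torch.nn.functional as F


def vim_upsample(logits: torch.Tensor, out_hw: tuple) -> torch.Tensor:
    """Upsample only the class channels present in the coarse prediction."""
    # logits: (N, num_classes, h, w) low-resolution class scores
    present = torch.unique(torch.argmax(logits, dim=1))  # classes in the coarse prediction
    selected = torch.index_select(logits, 1, present)  # keep only those channels
    selected = F.interpolate(
        selected, size=out_hw, mode='bilinear', align_corners=False)
    compact_pred = torch.argmax(selected, dim=1)  # indices into ``present``
    return present[compact_pred]  # map back to the original class ids


# With 150 ADE20K classes but only a dozen or so present in a typical image,
# only ~10% of the channels are interpolated to full resolution.
pred = vim_upsample(torch.randn(1, 150, 64, 64), (512, 512))
print(pred.shape)  # torch.Size([1, 512, 512])
```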
## <img src="https://user-images.githubusercontent.com/34859558/190044217-8f6befc2-7f20-473d-b356-148e06265205.png" width="25"/> Performance
### ADE20K
| Model | Backbone | Training Iters | Batchsize | Train Resolution | mIoU(%) | latency(ms)\* | params(M) | config | Links |
| ----------------- | ----------------- | -------------- | --------- | ---------------- | ------- | ------------- | --------- | ------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| PP-MobileSeg-Base | StrideFormer-Base | 80000 | 32 | 512x512 | 41.57% | 265.5 | 5.62 | [config](https://github.com/Yang-Changhui/mmsegmentation/tree/add_ppmobileseg/projects/pp_mobileseg/configs/pp_mobileseg) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_base-ed0be681.pth)\|[log](https://bj.bcebos.com/paddleseg/dygraph/ade20k/pp_mobileseg_base/train.log) |
| PP-MobileSeg-Tiny | StrideFormer-Tiny | 80000 | 32 | 512x512 | 36.39% | 215.3 | 1.61 | [config](https://github.com/Yang-Changhui/mmsegmentation/tree/add_ppmobileseg/projects/pp_mobileseg/configs/pp_mobileseg) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny-e4b35e96.pth)\|[log](https://bj.bcebos.com/paddleseg/dygraph/ade20k/pp_mobileseg_tiny/train.log) |
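The released checkpoints should work with the standard MMSegmentation inference API, roughly as sketched below. The config path is a placeholder for the file linked in the table, and the project's backbone and head modules must be importable so that `StrideFormer` and `PPMobileSegHead` are registered.

```python
from mmseg.apis import inference_model, init_model

# Placeholder paths: substitute the config and checkpoint linked in the table above.
config_file = 'projects/pp_mobileseg/configs/pp_mobileseg/<pp_mobileseg_tiny_config>.py'
checkpoint_file = 'pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny-e4b35e96.pth'

model = init_model(config_file, checkpoint_file, device='cuda:0')
result = inference_model(model, 'demo.png')  # any RGB image path
print(result.pred_sem_seg.data.shape)  # (1, H, W) tensor of predicted class ids
```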
## Citation
If you find our project useful in your research, please consider citing:
```
@misc{liu2021paddleseg,
title={PaddleSeg: A High-Efficient Development Toolkit for Image Segmentation},
author={Yi Liu and Lutao Chu and Guowei Chen and Zewu Wu and Zeyu Chen and Baohua Lai and Yuying Hao},
year={2021},
eprint={2101.06175},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
@misc{paddleseg2019,
title={PaddleSeg, End-to-end image segmentation kit based on PaddlePaddle},
author={PaddlePaddle Contributors},
howpublished = {\url{https://github.com/PaddlePaddle/PaddleSeg}},
year={2019}
}
```


@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .strideformer import StrideFormer
__all__ = ['StrideFormer']


@ -0,0 +1,958 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer
from mmcv.cnn.bricks.transformer import build_dropout
from mmengine.logging import print_log
from mmengine.model import BaseModule
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from mmseg.registry import MODELS
@MODELS.register_module()
class StrideFormer(BaseModule):
"""The StrideFormer implementation based on torch.
The original article refers to: https://arxiv.org/abs/2304.05152
Args:
mobileV3_cfg (list): Each sublist describes the config of a
MobileNetV3 block.
channels (list): The input channels of each MobileNetV3 block.
embed_dims (list): The channels of the features fed into the sea
attention blocks.
key_dims (list, optional): The embedding dims of each attention
head.
depths (list, optional): The depth of each attention block,
i.e. [M, N].
num_heads (int, optional): The number of heads of the attention
blocks.
attn_ratios (int, optional): The expand ratio of V.
mlp_ratios (list, optional): The expand ratio of the MLP blocks.
drop_path_rate (float, optional): The drop path rate of the
attention blocks.
act_cfg (dict, optional): The activation layer of the AAM
(Aggregated Attention Module).
inj_type (string, optional): The type of injection/AAM.
out_channels (int, optional): The output channels of the AAM.
dims (list, optional): The dimensions of the fusion blocks.
out_feat_chs (list, optional): The input channels of the AAM.
stride_attention (bool, optional): Whether to use stride attention
in each attention layer.
pretrained (str, optional): The path of the pretrained model.
"""
def __init__(
self,
mobileV3_cfg,
channels,
embed_dims,
key_dims=[16, 24],
depths=[2, 2],
num_heads=8,
attn_ratios=2,
mlp_ratios=[2, 4],
drop_path_rate=0.1,
act_cfg=dict(type='ReLU'),
inj_type='AAM',
out_channels=256,
dims=(128, 160),
out_feat_chs=None,
stride_attention=True,
pretrained=None,
init_cfg=None,
):
super().__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained
), 'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.depths = depths
self.cfgs = mobileV3_cfg
self.dims = dims
for i in range(len(self.cfgs)):
smb = StackedMV3Block(
cfgs=self.cfgs[i],
stem=True if i == 0 else False,
in_channels=channels[i],
)
setattr(self, f'smb{i + 1}', smb)
for i in range(len(depths)):
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depths[i])
]
trans = BasicLayer(
block_num=depths[i],
embedding_dim=embed_dims[i],
key_dim=key_dims[i],
num_heads=num_heads,
mlp_ratio=mlp_ratios[i],
attn_ratio=attn_ratios,
drop=0,
attn_drop=0.0,
drop_path=dpr,
act_cfg=act_cfg,
stride_attention=stride_attention,
)
setattr(self, f'trans{i + 1}', trans)
self.inj_type = inj_type
if self.inj_type == 'AAM':
self.inj_module = InjectionMultiSumallmultiallsum(
in_channels=out_feat_chs, out_channels=out_channels)
self.feat_channels = [
out_channels,
]
elif self.inj_type == 'AAMSx8':
self.inj_module = InjectionMultiSumallmultiallsumSimpx8(
in_channels=out_feat_chs, out_channels=out_channels)
self.feat_channels = [
out_channels,
]
elif self.inj_type == 'origin':
for i in range(len(dims)):
fuse = FusionBlock(
out_feat_chs[0] if i == 0 else dims[i - 1],
out_feat_chs[i + 1],
embed_dim=dims[i],
act_cfg=None,
)
setattr(self, f'fuse{i + 1}', fuse)
self.feat_channels = [
dims[i],
]
else:
raise NotImplementedError(self.inj_type + ' is not implemented')
self.pretrained = pretrained
# self.init_weights()
def init_weights(self):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
if 'pos_embed' in state_dict.keys():
if self.pos_embed.shape != state_dict['pos_embed'].shape:
print_log(msg=f'Resize the pos_embed shape from '
f'{state_dict["pos_embed"].shape} to '
f'{self.pos_embed.shape}')
h, w = self.img_size
pos_size = int(
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
state_dict['pos_embed'] = self.resize_pos_embed(
state_dict['pos_embed'],
(h // self.patch_size, w // self.patch_size),
(pos_size, pos_size),
self.interpolate_mode,
)
load_state_dict(self, state_dict, strict=False, logger=None)
def forward(self, x):
x_hw = x.shape[2:]
outputs = []
num_smb_stage = len(self.cfgs)
num_trans_stage = len(self.depths)
for i in range(num_smb_stage):
smb = getattr(self, f'smb{i + 1}')
x = smb(x)
# 1/8 shared feat
if i == 1:
outputs.append(x)
if num_trans_stage + i >= num_smb_stage:
trans = getattr(
self, f'trans{i + num_trans_stage - num_smb_stage + 1}')
x = trans(x)
outputs.append(x)
if self.inj_type == 'origin':
x_detail = outputs[0]
for i in range(len(self.dims)):
fuse = getattr(self, f'fuse{i + 1}')
x_detail = fuse(x_detail, outputs[i + 1])
output = x_detail
else:
output = self.inj_module(outputs)
return [output, x_hw]
class StackedMV3Block(nn.Module):
"""The MobileNetV3 block.
Args:
cfgs (list): The MobileNetV3 config list of a stage.
stem (bool): Whether this is the first stage or not.
in_channels (int, optional): The channels of the input image. Default: 3.
scale (float, optional): The coefficient that controls the size of
network parameters. Default: 1.0.
Returns:
nn.Module: A stage of a specific MobileNetV3 model, depending on args.
"""
def __init__(self,
cfgs,
stem,
in_channels,
scale=1.0,
norm_cfg=dict(type='BN')):
super().__init__()
self.scale = scale
self.stem = stem
if self.stem:
self.conv = ConvModule(
in_channels=3,
out_channels=_make_divisible(in_channels * self.scale),
kernel_size=3,
stride=2,
padding=1,
groups=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=dict(type='HSwish'),
)
self.blocks = nn.ModuleList()
for i, (k, exp, c, se, act, s) in enumerate(cfgs):
self.blocks.append(
ResidualUnit(
in_channel=_make_divisible(in_channels * self.scale),
mid_channel=_make_divisible(self.scale * exp),
out_channel=_make_divisible(self.scale * c),
kernel_size=k,
stride=s,
use_se=se,
act=act,
dilation=1,
))
in_channels = _make_divisible(self.scale * c)
def forward(self, x):
if self.stem:
x = self.conv(x)
for i, block in enumerate(self.blocks):
x = block(x)
return x
class ResidualUnit(nn.Module):
"""The Residual module.
Args:
in_channel (int, optional): The channels of input feature.
mid_channel (int, optional): The channels of middle process.
out_channel (int, optional): The channels of output feature.
kernel_size (int, optional): The size of the convolving kernel.
stride (int, optional): The stride size.
use_se (bool, optional): Whether to use the SEModule.
act (string, optional): The type of the activation layer.
dilation (int, optional): The dilation size.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN', requires_grad=True).
"""
def __init__(
self,
in_channel,
mid_channel,
out_channel,
kernel_size,
stride,
use_se,
act=None,
dilation=1,
norm_cfg=dict(type='BN'),
):
super().__init__()
self.if_shortcut = stride == 1 and in_channel == out_channel
self.if_se = use_se
self.expand_conv = ConvModule(
in_channels=in_channel,
out_channels=mid_channel,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=dict(type=act) if act is not None else None,
)
self.bottleneck_conv = ConvModule(
in_channels=mid_channel,
out_channels=mid_channel,
kernel_size=kernel_size,
stride=stride,
padding=int((kernel_size - 1) // 2) * dilation,
bias=False,
groups=mid_channel,
dilation=dilation,
norm_cfg=norm_cfg,
act_cfg=dict(type=act) if act is not None else None,
)
if self.if_se:
self.mid_se = SEModule(mid_channel)
self.linear_conv = ConvModule(
in_channels=mid_channel,
out_channels=out_channel,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None,
)
def forward(self, x):
identity = x
x = self.expand_conv(x)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = torch.add(identity, x)
return x
class SEModule(nn.Module):
"""SE Module.
Args:
channel (int, optional): The channels of input feature.
reduction (int, optional): The channel reduction rate.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
"""
def __init__(self, channel, reduction=4, act_cfg=dict(type='ReLU')):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv_act1 = ConvModule(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
norm_cfg=None,
act_cfg=act_cfg,
)
self.conv_act2 = ConvModule(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
norm_cfg=None,
act_cfg=dict(type='Hardsigmoid', slope=0.2, offset=0.5),
)
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv_act1(x)
x = self.conv_act2(x)
return torch.mul(identity, x)
class BasicLayer(nn.Module):
"""The transformer basic layer.
Args:
block_num (int): The number of blocks in the transformer basic layer.
embedding_dim (int): The feature dimension.
key_dim (int): The key dimension.
num_heads (int): Parallel attention heads.
mlp_ratio (float): The expand ratio of the MLP block.
attn_ratio (float): The expand ratio of V.
drop (float): Probability of an element to be zeroed
after the feed forward layer. Default: 0.0.
attn_drop (float): The dropout rate of the attention layer.
Default: 0.0.
drop_path (float): Stochastic depth rate. Default: 0.0.
act_cfg (dict): Config dict for activation layer.
Default: None.
stride_attention (bool, optional): Whether to use stride attention
in each attention layer.
"""
def __init__(
self,
block_num,
embedding_dim,
key_dim,
num_heads,
mlp_ratio=4.0,
attn_ratio=2.0,
drop=0.0,
attn_drop=0.0,
drop_path=None,
act_cfg=None,
stride_attention=None,
):
super().__init__()
self.block_num = block_num
self.transformer_blocks = nn.ModuleList()
for i in range(self.block_num):
self.transformer_blocks.append(
Block(
embedding_dim,
key_dim=key_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
attn_ratio=attn_ratio,
drop=drop,
drop_path=drop_path[i]
if isinstance(drop_path, list) else drop_path,
act_cfg=act_cfg,
stride_attention=stride_attention,
))
def forward(self, x):
for i in range(self.block_num):
x = self.transformer_blocks[i](x)
return x
class Block(nn.Module):
"""the block of the transformer basic layer.
Args:
dim (int): The feature dimension.
key_dim (int): The key dimension.
num_heads (int): Parallel attention heads.
mlp_ratio (float): The expand ratio of the MLP block.
attn_ratio (float): The expand ratio of V.
drop (float): Probability of an element to be zeroed
after the feed forward layer. Default: 0.0.
drop_path (float): Stochastic depth rate. Default: 0.0.
act_cfg (dict): Config dict for activation layer.
Default: None.
stride_attention (bool, optional): Whether to use stride attention
in each attention layer.
"""
def __init__(
self,
dim,
key_dim,
num_heads,
mlp_ratio=4.0,
attn_ratio=2.0,
drop=0.0,
drop_path=0.0,
act_cfg=None,
stride_attention=None,
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.attn = SeaAttention(
dim,
key_dim=key_dim,
num_heads=num_heads,
attn_ratio=attn_ratio,
act_cfg=act_cfg,
stride_attention=stride_attention,
)
self.drop_path = (
build_dropout(dict(type='DropPath', drop_prob=drop_path))
if drop_path > 0.0 else nn.Identity())
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = MLP(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_cfg=act_cfg,
drop=drop,
)
def forward(self, x1):
x1 = x1 + self.drop_path(self.attn(x1))
x1 = x1 + self.drop_path(self.mlp(x1))
return x1
class SqueezeAxialPositionalEmbedding(nn.Module):
"""the Squeeze Axial Positional Embedding.
Args:
dim (int): The feature dimension.
shape (int): The patch size.
"""
def __init__(self, dim, shape):
super().__init__()
self.pos_embed = nn.init.normal_(
nn.Parameter(torch.zeros(1, dim, shape)))
def forward(self, x):
B, C, N = x.shape
x = x + F.interpolate(
self.pos_embed, size=(N, ), mode='linear', align_corners=False)
return x
class SeaAttention(nn.Module):
"""The sea attention.
Args:
dim (int): The feature dimension.
key_dim (int): The key dimension.
num_heads (int): number of attention heads.
attn_ratio (float): The expand ratio of V.
act_cfg (dict): Config dict for activation layer.
Default: None.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
stride_attention (bool, optional): Whether to use stride attention
in each attention layer.
"""
def __init__(
self,
dim,
key_dim,
num_heads,
attn_ratio=4.0,
act_cfg=None,
norm_cfg=dict(type='BN'),
stride_attention=False,
):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim**-0.5
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
self.to_q = ConvModule(
dim, nh_kd, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
self.to_k = ConvModule(
dim, nh_kd, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
self.to_v = ConvModule(
dim, self.dh, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
self.stride_attention = stride_attention
if self.stride_attention:
self.stride_conv = ConvModule(
dim,
dim,
kernel_size=3,
stride=2,
padding=1,
bias=True,
groups=dim,
norm_cfg=norm_cfg,
act_cfg=None,
)
self.proj = ConvModule(
self.dh,
dim,
1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
order=('act', 'conv', 'norm'),
)
self.proj_encode_row = ConvModule(
self.dh,
self.dh,
1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
order=('act', 'conv', 'norm'),
)
self.pos_emb_rowq = SqueezeAxialPositionalEmbedding(nh_kd, 16)
self.pos_emb_rowk = SqueezeAxialPositionalEmbedding(nh_kd, 16)
self.proj_encode_column = ConvModule(
self.dh,
self.dh,
1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
order=('act', 'conv', 'norm'),
)
self.pos_emb_columnq = SqueezeAxialPositionalEmbedding(nh_kd, 16)
self.pos_emb_columnk = SqueezeAxialPositionalEmbedding(nh_kd, 16)
self.dwconv = ConvModule(
2 * self.dh,
2 * self.dh,
3,
padding=1,
groups=2 * self.dh,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
self.pwconv = ConvModule(
2 * self.dh, dim, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
self.sigmoid = build_activation_layer(dict(type='HSigmoid'))
def forward(self, x):
B, C, H_ori, W_ori = x.shape
if self.stride_attention:
x = self.stride_conv(x)
B, C, H, W = x.shape
q = self.to_q(x) # [B, nhead*dim, H, W]
k = self.to_k(x)
v = self.to_v(x)
qkv = torch.cat([q, k, v], dim=1)
qkv = self.dwconv(qkv)
qkv = self.pwconv(qkv)
qrow = (self.pos_emb_rowq(q.mean(-1)).reshape(
[B, self.num_heads, -1, H]).permute(
(0, 1, 3, 2))) # [B, nhead, H, dim]
krow = self.pos_emb_rowk(k.mean(-1)).reshape(
[B, self.num_heads, -1, H]) # [B, nhead, dim, H]
vrow = (v.mean(-1).reshape([B, self.num_heads, -1,
H]).permute([0, 1, 3, 2])
) # [B, nhead, H, dim*attn_ratio]
attn_row = torch.matmul(qrow, krow) * self.scale # [B, nhead, H, H]
attn_row = nn.functional.softmax(attn_row, dim=-1)
xx_row = torch.matmul(attn_row, vrow) # [B, nhead, H, dim*attn_ratio]
xx_row = self.proj_encode_row(
xx_row.permute([0, 1, 3, 2]).reshape([B, self.dh, H, 1]))
# squeeze column
qcolumn = (
self.pos_emb_columnq(q.mean(-2)).reshape(
[B, self.num_heads, -1, W]).permute([0, 1, 3, 2]))
kcolumn = self.pos_emb_columnk(k.mean(-2)).reshape(
[B, self.num_heads, -1, W])
vcolumn = (
torch.mean(v, -2).reshape([B, self.num_heads, -1,
W]).permute([0, 1, 3, 2]))
attn_column = torch.matmul(qcolumn, kcolumn) * self.scale
attn_column = nn.functional.softmax(attn_column, dim=-1)
xx_column = torch.matmul(attn_column, vcolumn) # B nH W C
xx_column = self.proj_encode_column(
xx_column.permute([0, 1, 3, 2]).reshape([B, self.dh, 1, W]))
xx = torch.add(xx_row, xx_column) # [B, self.dh, H, W]
xx = torch.add(v, xx)
xx = self.proj(xx)
xx = self.sigmoid(xx) * qkv
if self.stride_attention:
xx = F.interpolate(xx, size=(H_ori, W_ori), mode='bilinear')
return xx
class MLP(nn.Module):
"""the Multilayer Perceptron.
Args:
in_features (int): The number of input features.
hidden_features (int): The number of hidden features.
out_features (int): The number of output features.
act_cfg (dict): Config dict for activation layer.
Default: None.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
drop (float): Probability of an element to be zeroed.
Default: 0.0.
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_cfg=None,
norm_cfg=dict(type='BN'),
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = ConvModule(
in_features,
hidden_features,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None,
)
self.dwconv = ConvModule(
hidden_features,
hidden_features,
kernel_size=3,
padding=1,
groups=hidden_features,
norm_cfg=None,
act_cfg=act_cfg,
)
self.fc2 = ConvModule(
hidden_features,
out_features,
1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None,
)
self.drop = build_dropout(dict(type='Dropout', drop_prob=drop))
def forward(self, x):
x = self.fc1(x)
x = self.dwconv(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class FusionBlock(nn.Module):
"""The feature fusion block.
Args:
in_channel (int): the input channel.
out_channel (int): the output channel.
embed_dim (int): embedding dimension.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN')
"""
def __init__(
self,
in_channel,
out_channel,
embed_dim,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
) -> None:
super().__init__()
self.local_embedding = ConvModule(
in_channels=in_channel,
out_channels=embed_dim,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None,
)
self.global_act = ConvModule(
in_channels=out_channel,
out_channels=embed_dim,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg if act_cfg is not None else None,
)
def forward(self, x_l, x_g):
"""
x_g: global features
x_l: local features
"""
B, C, H, W = x_l.shape
local_feat = self.local_embedding(x_l)
global_act = self.global_act(x_g)
sig_act = F.interpolate(
global_act, size=(H, W), mode='bilinear', align_corners=False)
out = local_feat * sig_act
return out
class InjectionMultiSumallmultiallsum(nn.Module):
"""the Aggregate Attention Module.
Args:
in_channels (tuple): the input channel.
out_channels (int): the output channel.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN')
"""
def __init__(
self,
in_channels=(64, 128, 256, 384),
out_channels=256,
act_cfg=dict(type='Sigmoid'),
norm_cfg=dict(type='BN'),
):
super().__init__()
self.embedding_list = nn.ModuleList()
self.act_embedding_list = nn.ModuleList()
self.act_list = nn.ModuleList()
for i in range(len(in_channels)):
self.embedding_list.append(
ConvModule(
in_channels=in_channels[i],
out_channels=out_channels,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None,
))
self.act_embedding_list.append(
ConvModule(
in_channels=in_channels[i],
out_channels=out_channels,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
))
def forward(self, inputs): # x_x8, x_x16, x_x32, x_x64
low_feat1 = F.interpolate(inputs[0], scale_factor=0.5, mode='bilinear')
low_feat1_act = self.act_embedding_list[0](low_feat1)
low_feat1 = self.embedding_list[0](low_feat1)
low_feat2 = F.interpolate(
inputs[1], size=low_feat1.shape[-2:], mode='bilinear')
low_feat2_act = self.act_embedding_list[1](low_feat2) # x16
low_feat2 = self.embedding_list[1](low_feat2)
high_feat_act = F.interpolate(
self.act_embedding_list[2](inputs[2]),
size=low_feat2.shape[2:],
mode='bilinear',
)
high_feat = F.interpolate(
self.embedding_list[2](inputs[2]),
size=low_feat2.shape[2:],
mode='bilinear')
res = (
low_feat1_act * low_feat2_act * high_feat_act *
(low_feat1 + low_feat2) + high_feat)
return res
class InjectionMultiSumallmultiallsumSimpx8(nn.Module):
"""the Aggregate Attention Module.
Args:
in_channels (tuple): the input channel.
out_channels (int): the output channel.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN')
"""
def __init__(
self,
in_channels=(64, 128, 256, 384),
out_channels=256,
act_cfg=dict(type='Sigmoid'),
norm_cfg=dict(type='BN'),
):
super().__init__()
self.embedding_list = nn.ModuleList()
self.act_embedding_list = nn.ModuleList()
self.act_list = nn.ModuleList()
for i in range(len(in_channels)):
if i != 1:
self.embedding_list.append(
ConvModule(
in_channels=in_channels[i],
out_channels=out_channels,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None,
))
if i != 0:
self.act_embedding_list.append(
ConvModule(
in_channels=in_channels[i],
out_channels=out_channels,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
))
def forward(self, inputs):
# x_x8, x_x16, x_x32
low_feat1 = self.embedding_list[0](inputs[0])
low_feat2 = F.interpolate(
inputs[1], size=low_feat1.shape[-2:], mode='bilinear')
low_feat2_act = self.act_embedding_list[0](low_feat2)
high_feat_act = F.interpolate(
self.act_embedding_list[1](inputs[2]),
size=low_feat2.shape[2:],
mode='bilinear',
)
high_feat = F.interpolate(
self.embedding_list[1](inputs[2]),
size=low_feat2.shape[2:],
mode='bilinear')
res = low_feat2_act * high_feat_act * low_feat1 + high_feat
return res
def _make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
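# For reference, with the default divisor of 8, _make_divisible(30) == 32 and
# _make_divisible(12) == 16; the final check keeps the rounded value from
# dropping below 90% of the input.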
@MODELS.register_module()
class Hardsigmoid(nn.Module):
"""the hardsigmoid activation.
Args:
slope (float, optional): The slope of hardsigmoid function.
Default is 0.1666667.
offset (float, optional): The offset of hardsigmoid function.
Default is 0.5.
inplace (bool): can optionally do the operation in-place.
Default: ``False``
"""
def __init__(self, slope=0.1666667, offset=0.5, inplace=False):
super().__init__()
self.slope = slope
self.offset = offset
def forward(self, x):
return (x * self.slope + self.offset).clamp(0, 1)


@ -0,0 +1,68 @@
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(
type='RandomResize',
scale=(2048, 512),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', backend_args=None),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training', seg_map_path='annotations/training'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/validation',
seg_map_path='annotations/validation'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator


@ -0,0 +1,15 @@
default_scope = 'mmseg'
env_cfg = dict(
cudnn_benchmark=True,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False
tta_model = dict(type='SegTTAModel')


@ -0,0 +1,47 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
# pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='StrideFormer',
mobileV3_cfg=[
# kernel_size, expand_channel, out_channel, use_se, act, stride
[[3, 16, 16, True, 'ReLU', 1], [3, 64, 32, False, 'ReLU', 2],
[3, 96, 32, False, 'ReLU', 1]], # cfg1
[[5, 128, 64, True, 'HSwish', 2], [5, 240, 64, True, 'HSwish',
1]], # cfg2
[[5, 384, 128, True, 'HSwish', 2],
[5, 384, 128, True, 'HSwish', 1]], # cfg3
[[5, 768, 192, True, 'HSwish', 2],
[5, 768, 192, True, 'HSwish', 1]], # cfg4
],
channels=[16, 32, 64, 128, 192],
depths=[3, 3],
embed_dims=[128, 192],
num_heads=8,
inj_type='AAMSx8',
out_feat_chs=[64, 128, 192],
act_cfg=dict(type='ReLU6'),
),
decode_head=dict(
type='PPMobileSegHead',
num_classes=150,
in_channels=256,
dropout_ratio=0.1,
use_dw=True,
act_cfg=dict(type='ReLU'),
align_corners=False),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))


@ -0,0 +1,24 @@
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=1e-4,
power=0.9,
begin=0,
end=80000,
by_epoch=False)
]
# training schedule for 80k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))


@ -0,0 +1,13 @@
_base_ = [
'../_base_/models/pp_mobile.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
checkpoint = './models/pp_mobile_base.pth'
crop_size = (512, 512)
data_preprocessor = dict(size=crop_size, test_cfg=dict(size_divisor=32))
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
data_preprocessor=data_preprocessor,
backbone=dict(init_cfg=dict(type='Pretrained', checkpoint=checkpoint), ),
decode_head=dict(num_classes=150, upsample='intepolate'),
)


@ -0,0 +1,39 @@
_base_ = [
'../_base_/models/pp_mobile.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
checkpoint = './models/pp_mobile_tiny.pth'
crop_size = (512, 512)
data_preprocessor = dict(size=crop_size, test_cfg=dict(size_divisor=32))
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
data_preprocessor=data_preprocessor,
backbone=dict(
init_cfg=dict(type='Pretrained', checkpoint=checkpoint),
type='StrideFormer',
mobileV3_cfg=[
# kernel_size, expand_channel, out_channel, use_se, act, stride
[[3, 16, 16, True, 'ReLU', 1], [3, 64, 32, False, 'ReLU', 2],
[3, 48, 24, False, 'ReLU', 1]], # cfg1
[[5, 96, 32, True, 'HSwish', 2], [5, 96, 32, True, 'HSwish',
1]], # cfg2
[[5, 160, 64, True, 'HSwish', 2], [5, 160, 64, True, 'HSwish',
1]], # cfg3
[[3, 384, 128, True, 'HSwish', 2],
[3, 384, 128, True, 'HSwish', 1]], # cfg4
],
channels=[16, 24, 32, 64, 128],
depths=[2, 2],
embed_dims=[64, 128],
num_heads=4,
inj_type='AAM',
out_feat_chs=[32, 64, 128],
act_cfg=dict(type='ReLU6'),
),
decode_head=dict(
num_classes=150,
in_channels=256,
use_dw=True,
act_cfg=dict(type='ReLU'),
upsample='intepolate'),
)


@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .pp_mobileseg_head import PPMobileSegHead
__all__ = [
'PPMobileSegHead',
]


@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_conv_layer
from torch import Tensor
from mmseg.registry import MODELS
@MODELS.register_module()
class PPMobileSegHead(nn.Module):
"""the segmentation head.
Args:
num_classes (int): the classes num.
in_channels (int): the input channels.
use_dw (bool): if to use deepwith convolution.
dropout_ratio (float): Probability of an element to be zeroed.
Default 0.0
align_corners (bool, optional): Geometrically, we consider the pixels
of the input and output as squares rather than points.
upsample (str): the upsample method.
out_channels (int): the output channel.
conv_cfg (dict): Config dict for convolution layer.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
"""
def __init__(self,
num_classes,
in_channels,
use_dw=True,
dropout_ratio=0.1,
align_corners=False,
upsample='intepolate',
out_channels=None,
conv_cfg=dict(type='Conv'),
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN')):
super().__init__()
self.align_corners = align_corners
self.last_channels = in_channels
self.upsample = upsample
self.num_classes = num_classes
self.out_channels = out_channels
self.linear_fuse = ConvModule(
in_channels=self.last_channels,
out_channels=self.last_channels,
kernel_size=1,
bias=False,
groups=self.last_channels if use_dw else 1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.dropout = nn.Dropout2d(dropout_ratio)
self.conv_seg = build_conv_layer(
conv_cfg, self.last_channels, self.num_classes, kernel_size=1)
def forward(self, x):
x, x_hw = x[0], x[1]
x = self.linear_fuse(x)
x = self.dropout(x)
x = self.conv_seg(x)
if self.upsample == 'intepolate' or self.training or \
self.num_classes < 30:
x = F.interpolate(
x, x_hw, mode='bilinear', align_corners=self.align_corners)
elif self.upsample == 'vim':
labelset = torch.unique(torch.argmax(x, 1))
x = torch.index_select(x, 1, labelset)
x = F.interpolate(
x, x_hw, mode='bilinear', align_corners=self.align_corners)
pred = torch.argmax(x, 1)
pred_retrieve = torch.zeros(
pred.shape, dtype=torch.int32, device=pred.device)
for i, label in enumerate(labelset):
pred_retrieve[pred == i] = label.to(torch.int32)
x = pred_retrieve
else:
raise NotImplementedError(self.upsample, ' is not implemented')
return [x]
def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
**kwargs) -> List[Tensor]:
"""Forward function for testing, only ``pam_cam`` is used."""
seg_logits = self.forward(inputs)[0]
return seg_logits