[DEST] add DEST model (#2482)

## Motivation

We are from NVIDIA and we have developed a simplified and
inference-efficient transformer for dense prediction tasks. The method
is based on SegFormer with hardware-friendly design choices, resulting
in better accuracy and over 2x reduction in inference speed as compared
to the baseline. We believe this model would be of particular interests
to those who want to deploy an efficient vision transformer for
production, and it is easily adaptable to other tasks. Therefore, we
would like to contribute our method to mmsegmentation in order to
benefit a larger audience.

The paper was accepted to [Transformer for Vision
workshop](https://nam11.safelinks.protection.outlook.com/?url=https%3A%2F%2Fsites.google.com%2Fview%2Ft4v-cvpr22%2Fpapers%3Fauthuser%3D0&data=05%7C01%7Cboyinz%40nvidia.com%7Cbf078d69821449d1f4c908dab5e8c7da%7C43083d15727340c1b7db39efd9ccc17a%7C0%7C0%7C638022308636438546%7CUnknown%7CTWFpbGZsb3d8eyJWIjoiMC4wLjAwMDAiLCJQIjoiV2luMzIiLCJBTiI6Ik1haWwiLCJXVCI6Mn0%3D%7C3000%7C%7C%7C&sdata=XtSgPQrbVgHxt5L9XkXF%2BGWvc95haB3kKPcHnsVIF3M%3D&reserved=0)
at CVPR 2022, here below are some resource links:
Paper
[https://arxiv.org/pdf/2204.13791.pdf](https://nam11.safelinks.protection.outlook.com/?url=https%3A%2F%2Farxiv.org%2Fpdf%2F2204.13791.pdf&data=05%7C01%7Cboyinz%40nvidia.com%7Cbf078d69821449d1f4c908dab5e8c7da%7C43083d15727340c1b7db39efd9ccc17a%7C0%7C0%7C638022308636438546%7CUnknown%7CTWFpbGZsb3d8eyJWIjoiMC4wLjAwMDAiLCJQIjoiV2luMzIiLCJBTiI6Ik1haWwiLCJXVCI6Mn0%3D%7C3000%7C%7C%7C&sdata=X%2FCVoa6PFA09EHfClES36QOa5NvbZu%2F6IDfBVwiYywU%3D&reserved=0)
(Table 3 shows the semseg results)
Code
[https://github.com/NVIDIA/DL4AGX/tree/master/DEST](https://nam11.safelinks.protection.outlook.com/?url=https%3A%2F%2Fgithub.com%2FNVIDIA%2FDL4AGX%2Ftree%2Fmaster%2FDEST&data=05%7C01%7Cboyinz%40nvidia.com%7Cbf078d69821449d1f4c908dab5e8c7da%7C43083d15727340c1b7db39efd9ccc17a%7C0%7C0%7C638022308636438546%7CUnknown%7CTWFpbGZsb3d8eyJWIjoiMC4wLjAwMDAiLCJQIjoiV2luMzIiLCJBTiI6Ik1haWwiLCJXVCI6Mn0%3D%7C3000%7C%7C%7C&sdata=9DLQZpEq1cN75%2FDf%2FniUOOUFS1ABX8FEUH02O6isGVQ%3D&reserved=0)
A webinar on its application
[https://www.nvidia.com/en-us/on-demand/session/other2022-drivetraining/](https://nam11.safelinks.protection.outlook.com/?url=https%3A%2F%2Fwww.nvidia.com%2Fen-us%2Fon-demand%2Fsession%2Fother2022-drivetraining%2F&data=05%7C01%7Cboyinz%40nvidia.com%7Cbf078d69821449d1f4c908dab5e8c7da%7C43083d15727340c1b7db39efd9ccc17a%7C0%7C0%7C638022308636438546%7CUnknown%7CTWFpbGZsb3d8eyJWIjoiMC4wLjAwMDAiLCJQIjoiV2luMzIiLCJBTiI6Ik1haWwiLCJXVCI6Mn0%3D%7C3000%7C%7C%7C&sdata=8jrBC%2Bp3jGxiaW4vtSfhh6GozC3tRqGNjNoALM%2FOYxs%3D&reserved=0)

## Modification

Add backbone(smit.py) and head(dest_head.py) of DEST

## BC-breaking (Optional)

N/A

## Use cases (Optional)

N/A

---------

Co-authored-by: MeowZheng <meowzheng@outlook.com>
pull/2622/head
Boyin Zhang 2023-02-16 01:42:34 -08:00 committed by GitHub
parent 486a40995e
commit 409caf8548
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 924 additions and 3 deletions

View File

@ -2,6 +2,7 @@
In this file, we list the features with other licenses instead of Apache 2.0. Users should be careful about adopting these features in any commercial matters.
| Feature | Files | License |
| :-------: | :-------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------: |
| SegFormer | [mmseg/models/decode_heads/segformer_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py) | [NVIDIA License](https://github.com/NVlabs/SegFormer#license) |
| Feature | Files | License |
| :-------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------: |
| SegFormer | [mmseg/models/decode_heads/segformer_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py) | [NVIDIA License](https://github.com/NVlabs/SegFormer#license) |
| DEST | [mmseg/models/backbones/smit.py](https://github.com/open-mmlab/mmsegmentation/blob/master/projects/dest/models/smit.py) [mmseg/models/decode_heads/dest_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/projects/dest/models/dest_head.py) | [NVIDIA License](https://github.com/NVIDIA/DL4AGX/blob/master/DEST/LICENSE) |

View File

@ -182,6 +182,7 @@ Supported methods:
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
- [x] [K-Net (NeurIPS'2021)](configs/knet)
- [x] [DEST (CVPRW'2022)](projects/dest)
Supported datasets:

View File

@ -0,0 +1,99 @@
# DEST
[DEST: Depth Estimation with Simplified Transformer](https://arxiv.org/abs/2204.13791)
## Description
Transformer and its variants have shown state-of-the-art results in many vision tasks recently, ranging from image classification to dense prediction. Despite of their success, limited work has been reported on improving the model efficiency for deployment in latency-critical applications, such as autonomous driving and robotic navigation. In this paper, we aim at improving upon the existing transformers in vision, and propose a method for Dense Estimation with Simplified Transformer (DEST), which is efficient and particularly suitable for deployment on GPU-based platforms. Through strategic design choices, our model leads to significant reduction in model size, complexity, as well as inference latency, while achieving superior accuracy as compared to state-of-the-art in the task of self-supervised monocular depth estimation. We also show that our design generalize well to other dense prediction task such as semantic segmentation without bells and whistles.
## Usage
### Prerequisites
- Python 3.8.12
- PyTorch 1.11
- mmcv v1.7.0
- Install [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) from source
All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the mmsegmentaions directory so that Python can locate the configuration files in mmsegmentation.
### Dataset preparing
Preparing `cityscapes` dataset following this [Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#prepare-datasets)
### Training commands
```shell
mim train mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest
```
To train on multiple GPUs, e.g. 8 GPUs, run the following command:
```shell
mim train mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest --launcher pytorch --gpus 8
```
### Testing commands
```shell
mim test mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest --checkpoint ${CHECKPOINT_PATH} --eval mIoU
```
## Results and models
### Cityscapes
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------: | -------: | -------------- | ----: | ------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| DEST | SMIT-B0 | 1024x1024 | 160000 | - | - | 64.34 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b0_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b0_1024x1024_160k_cityscapes_20230105_232025-11f73f34.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b0_1024x1024_160k_cityscapes_20230105_232025.log) |
| DEST | SMIT-B1 | 1024x1024 | 160000 | - | - | 68.21 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358-0dd4e86e.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358.logmmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358.log) |
| DEST | SMIT-B2 | 1024x1024 | 160000 | - | - | 71.89 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b2_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b2_1024x1024_160k_cityscapes_20230105_231943-b06319ae.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b2_1024x1024_160k_cityscapes_20230105_231943.log) |
| DEST | SMIT-B3 | 1024x1024 | 160000 | - | - | 73.51 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b3_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b3_1024x1024_160k_cityscapes_20230105_231800-ee4cec5c.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b3_1024x1024_160k_cityscapes_20230105_231800.log) |
| DEST | SMIT-B4 | 1024x1024 | 160000 | - | - | 73.99 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b4_1024x1024_160k_cityscapes_20230105_232155-3ca9f4fc.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b4_1024x1024_160k_cityscapes_20230105_232155.log) |
| DEST | SMIT-B5 | 1024x1024 | 160000 | - | - | 75.28 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b5_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b5_1024x1024_160k_cityscapes_20230105_231411-e83819b5.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b5_1024x1024_160k_cityscapes_20230105_231411.log) |
Note:
- The above models are all training from scratch without pretrained backbones. Accuracy can be further enhanced by appropriate pretraining.
- Training of DEST is not very stable, which is sensitive to random seeds.
## Citation
```bibtex
@article{YangDEST,
title={Depth Estimation with Simplified Transformer},
author={Yang, John and An, Le and Dixit, Anurag and Koo, Jinkyu and Park, Su Inn},
journal={arXiv preprint arXiv:2204.13791},
year={2022}
}
```
## Checklist
- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
- [x] Finish the code
- [x] Basic docstrings & proper citation
- [x] Test-time correctness
- [x] A full README
- [x] Milestone 2: Indicates a successful model implementation.
- [x] Training-time correctness
- [ ] Milestone 3: Good to be a part of our core package!
- [ ] Type hints and docstrings
- [ ] Unit tests
- [ ] Code polishing
- [ ] Metafile.yml
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.

View File

@ -0,0 +1,50 @@
# DEST
[DEST: Depth Estimation with Simplified Transformer](https://arxiv.org/abs/2204.13791)
## Introduction
<!-- [ALGORITHM] -->
<a href="https://github.com/NVIDIA/DL4AGX/tree/master/DEST">Official Repo</a>
## Abstract
<!-- [ABSTRACT] -->
Transformer and its variants have shown state-of-the-art results in many vision tasks recently, ranging from image classification to dense prediction. Despite of their success, limited work has been reported on improving the model efficiency for deployment in latency-critical applications, such as autonomous driving and robotic navigation. In this paper, we aim at improving upon the existing transformers in vision, and propose a method for Dense Estimation with Simplified Transformer (DEST), which is efficient and particularly suitable for deployment on GPU-based platforms. Through strategic design choices, our model leads to significant reduction in model size, complexity, as well as inference latency, while achieving superior accuracy as compared to state-of-the-art in the task of self-supervised monocular depth estimation. We also show that our design generalize well to other dense prediction task such as semantic segmentation without bells and whistles.
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/76149310/219313665-49fa89ed-4973-4496-bb33-3256f107e82d.png" width="70%"/>
</div>
## Citation
```bibtex
@article{YangDEST,
title={Depth Estimation with Simplified Transformer},
author={Yang, John and An, Le and Dixit, Anurag and Koo, Jinkyu and Park, Su Inn},
journal={arXiv preprint arXiv:2204.13791},
year={2022}
}
```
## Results and models
### Cityscapes
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------: | -------: | -------------- | ----: | ------------- | ---------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
| DEST | SMIT-B0 | 1024x1024 | 160000 | - | - | 64.34 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b0_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b0_1024x1024_160k_cityscapes_20230105_232025-11f73f34.pth) |
| DEST | SMIT-B1 | 1024x1024 | 160000 | - | - | 68.21 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358-0dd4e86e.pth) |
| DEST | SMIT-B2 | 1024x1024 | 160000 | - | - | 71.89 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b2_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b2_1024x1024_160k_cityscapes_20230105_231943-b06319ae.pth) |
| DEST | SMIT-B3 | 1024x1024 | 160000 | - | - | 73.51 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b3_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b3_1024x1024_160k_cityscapes_20230105_231800-ee4cec5c.pth) |
| DEST | SMIT-B4 | 1024x1024 | 160000 | - | - | 73.99 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b4_1024x1024_160k_cityscapes_20230105_232155-3ca9f4fc.pth) |
| DEST | SMIT-B5 | 1024x1024 | 160000 | - | - | 75.28 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b5_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b5_1024x1024_160k_cityscapes_20230105_231411-e83819b5.pth) |
Note:
- The above models are all training from scratch without pretrained backbones. Accuracy can be further enhanced by appropriate pretraining.
- Training of DEST is not very stable, which is sensitive to random seeds.

View File

@ -0,0 +1,37 @@
# model settings
embed_dims = [32, 64, 160, 256]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='SimplifiedMixTransformer',
in_channels=3,
embed_dims=embed_dims,
num_stages=4,
num_layers=[2, 2, 2, 2],
num_heads=[1, 2, 5, 8],
patch_sizes=[7, 3, 3, 3],
strides=[4, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
out_indices=(0, 1, 2, 3),
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_cfg=norm_cfg),
decode_head=dict(
type='DESTHead',
in_channels=[32, 64, 160, 256],
in_index=[0, 1, 2, 3],
channels=32,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))

View File

@ -0,0 +1,33 @@
_base_ = [
'./dest_simpatt-b0.py',
'../../../configs/_base_/datasets/cityscapes_1024x1024.py',
'../../../configs/_base_/default_runtime.py',
'../../../configs/_base_/schedules/schedule_160k.py'
]
custom_imports = dict(imports=['projects.dest.models'])
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=10.)
}))
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
data = dict(samples_per_gpu=1, workers_per_gpu=1)

View File

@ -0,0 +1,9 @@
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']
embed_dims = [64, 128, 250, 320]
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(embed_dims=embed_dims),
decode_head=dict(in_channels=embed_dims, channels=64))

View File

@ -0,0 +1,9 @@
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']
embed_dims = [64, 128, 250, 320]
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(embed_dims=embed_dims, num_layers=[3, 3, 6, 3]),
decode_head=dict(in_channels=embed_dims, channels=64))

View File

@ -0,0 +1,22 @@
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']
embed_dims = [64, 128, 250, 320]
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=1.)
}))
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(embed_dims=embed_dims, num_layers=[3, 6, 8, 3]),
decode_head=dict(in_channels=embed_dims, channels=64))

View File

@ -0,0 +1,22 @@
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']
embed_dims = [64, 128, 250, 320]
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=1.)
}))
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(embed_dims=embed_dims, num_layers=[3, 8, 12, 5]),
decode_head=dict(in_channels=embed_dims, channels=64))

View File

@ -0,0 +1,22 @@
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']
embed_dims = [64, 128, 250, 320]
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=1.)
}))
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(embed_dims=embed_dims, num_layers=[3, 10, 16, 5]),
decode_head=dict(in_channels=embed_dims, channels=64))

View File

@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dest_head import DESTHead
from .smit import SimplifiedMixTransformer
__all__ = ['SimplifiedMixTransformer', 'DESTHead']

View File

@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.models import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
@HEADS.register_module()
class DESTHead(BaseDecodeHead):
def __init__(self, interpolate_mode='bilinear', **kwargs):
super().__init__(input_transform='multiple_select', **kwargs)
self.interpolate_mode = interpolate_mode
num_inputs = len(self.in_channels)
assert num_inputs == len(self.in_index)
self.fuse_in_channels = self.in_channels.copy()
for i in range(num_inputs - 1):
self.fuse_in_channels[i] += self.fuse_in_channels[i + 1]
self.convs = nn.ModuleList()
for i in range(num_inputs):
self.convs.append(
ConvModule(
in_channels=self.in_channels[i],
out_channels=self.in_channels[i],
kernel_size=1,
stride=1,
act_cfg=self.act_cfg))
self.fuse_convs = nn.ModuleList()
for i in range(num_inputs):
self.fuse_convs.append(
ConvModule(
in_channels=self.fuse_in_channels[i],
out_channels=self.in_channels[i],
kernel_size=3,
stride=1,
padding=1,
act_cfg=self.act_cfg))
self.upsample = nn.ModuleList([
nn.Sequential(nn.Upsample(scale_factor=2, mode=interpolate_mode))
] * len(self.in_channels))
def forward(self, inputs):
feat = None
for idx in reversed(range(len(inputs))):
x = self.convs[idx](inputs[idx])
if idx != len(inputs) - 1:
x = torch.concat([feat, x], dim=1)
x = self.upsample[idx](x)
feat = self.fuse_convs[idx](x)
return self.cls_seg(feat)

View File

@ -0,0 +1,557 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.utils.weight_init import trunc_normal_init
from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.utils import to_2tuple
from mmseg.models import BACKBONES
from mmseg.models.utils.embed import AdaptivePadding
class SimplifiedPatchEmbed(BaseModule):
"""Image to Patch Embedding.
We use a conv layer to implement SimplifiedPatchEmbed.
Args:
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
conv_type (str): The config dict for embedding
conv layer type selection. Default: "Conv2d".
kernel_size (int): The kernel_size of embedding conv. Default: 16.
stride (int, optional): The slide stride of embedding conv.
Default: None (Would be set as `kernel_size`).
padding (int | tuple | string ): The padding length of
embedding conv. When it is a string, it means the mode
of adaptive padding, support "same" and "corner" now.
Default: "corner".
dilation (int): The dilation rate of embedding conv. Default: 1.
bias (bool): Bias of embed conv. Default: True.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: None.
input_size (int | tuple | None): The size of input, which will be
used to calculate the out size. Only work when `dynamic_size`
is False. Default: None.
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
in_channels=3,
embed_dims=768,
conv_type='Conv2d',
kernel_size=16,
stride=None,
padding='corner',
dilation=1,
bias=True,
norm_cfg=None,
input_size=None,
init_cfg=None):
super(SimplifiedPatchEmbed, self).__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
if stride is None:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
if isinstance(padding, str):
self.adap_padding = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
# disable the padding of conv
padding = 0
else:
self.adap_padding = None
padding = to_2tuple(padding)
self.projection = build_conv_layer(
dict(type=conv_type),
in_channels=in_channels,
out_channels=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
if input_size:
input_size = to_2tuple(input_size)
# `init_out_size` would be used outside to
# calculate the num_patches
# when `use_abs_pos_embed` outside
self.init_input_size = input_size
if self.adap_padding:
pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
input_h, input_w = input_size
input_h = input_h + pad_h
input_w = input_w + pad_w
input_size = (input_h, input_w)
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
(kernel_size[0] - 1) - 1) // stride[0] + 1
w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
(kernel_size[1] - 1) - 1) // stride[1] + 1
self.init_out_size = (h_out, w_out)
else:
self.init_input_size = None
self.init_out_size = None
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_init(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
"""
Args:
x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, embed_dims, out_h * out_w)
- out_size (tuple[int]): Spatial shape of x, arrange as
(out_h, out_w).
"""
if self.adap_padding:
x = self.adap_padding(x)
x = self.projection(x)
out_size = (x.shape[2], x.shape[3])
if self.norm is not None:
x = self.norm(x)
x = x.flatten(2)
return x, out_size
class DWConv(nn.Module):
def __init__(self, dims):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dims, dims, 3, 1, 1, bias=True, groups=dims)
def forward(self, x, H, W):
B, C, N = x.shape
x = x.reshape(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2)
return x
class MixFFN(nn.Module):
"""An implementation of MixFFN of DEST.
The differences between MixFFN & FFN:
1. Use 1X1 Conv to replace Linear layer.
2. Introduce 3X3 Conv to encode positional information.
Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`. Defaults: 256.
feedforward_channels (int): The hidden dimension of FFNs.
Defaults: 1024.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='ReLU')
ffn_drop (float, optional): Probability of an element to be
zeroed in FFN. Default 0.0.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
feedforward_channels,
act_cfg=dict(type='ReLU'),
ffn_drop=0.,
norm_cfg=dict(type='SyncBN', requires_grad=True),
dropout_layer=None,
init_cfg=None):
super(MixFFN, self).__init__()
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.act_cfg = act_cfg
activate = build_activation_layer(act_cfg)
in_channels = embed_dims
fc1 = nn.Conv1d(
in_channels=in_channels,
out_channels=feedforward_channels,
kernel_size=1,
stride=1)
norm1 = build_norm_layer(norm_cfg, feedforward_channels)[1]
self.dwconv = DWConv(feedforward_channels)
norm2 = build_norm_layer(norm_cfg, feedforward_channels)[1]
fc2 = nn.Conv1d(
in_channels=feedforward_channels,
out_channels=in_channels,
kernel_size=1,
stride=1)
drop = nn.Dropout(ffn_drop)
pre_layers = [fc1, norm1]
post_layers = [norm2, activate, drop, fc2, drop]
self.pre_layers = Sequential(*pre_layers)
self.post_layers = Sequential(*post_layers)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else torch.nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv1d):
trunc_normal_init(m, std=.02, bias=0.)
def forward(self, x, hw_shape, identity):
out = self.pre_layers(x)
out = self.dwconv(out, hw_shape[0], hw_shape[1])
out = self.post_layers(out)
return identity + self.dropout_layer(out)
class SimplifiedAttention(nn.Module):
"""An implementation of Simplified Multi-head Attention of DEST.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
Default: 0.0.
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
Attention of Segformer. Default: 1.
qkv_bias (bool): enable bias for qkv if True. Default True.
qk_scale (float, optional): scales for query and key. Default: None.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='SyncBN', requires_grad=True).
"""
def __init__(self,
embed_dims,
num_heads,
attn_drop=0.,
proj_drop=0.,
sr_ratio=1,
qkv_bias=False,
qk_scale=None,
dropout_layer=None,
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super().__init__()
self.embed_dims = embed_dims
self.num_heads = num_heads
head_dim = embed_dims // num_heads
self.scale = qk_scale or head_dim**-0.5
self.q = nn.Conv1d(embed_dims, embed_dims, 1, bias=qkv_bias)
self.k = nn.Conv1d(embed_dims, embed_dims, 1, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Conv1d(embed_dims, embed_dims, 1)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(
embed_dims, embed_dims, kernel_size=sr_ratio, stride=sr_ratio)
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else torch.nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Conv1d):
trunc_normal_init(m, std=.02, bias=0.)
def forward(self, x, hw_shape, identity):
H, W = hw_shape
B, C, N = x.shape
q = self.q(x)
q = q.reshape(B, self.num_heads, C // self.num_heads, N)
q = q.permute(0, 1, 3, 2)
if self.sr_ratio > 1:
x_ = x.reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1)
x_ = self.norm1(x_)
k = self.k(x_).reshape(B, self.num_heads, C // self.num_heads, -1)
else:
k = self.k(x).reshape(B, self.num_heads, C // self.num_heads, -1)
v = torch.mean(x, 2, True).repeat(1, 1,
self.num_heads).transpose(-2, -1)
attn = (q @ k) * self.scale
attn, _ = torch.max(attn, -1)
out = (attn.transpose(-2, -1) @ v)
out = out.transpose(-2, -1)
out = self.proj(out)
return identity + self.dropout_layer(out)
class SimpliefiedTransformerEncoderLayer(BaseModule):
"""Implements one encoder layer in DEST.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed.
after the feed forward layer. Default 0.0.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0.
drop_path_rate (float): stochastic depth rate. Default 0.0.
qkv_bias (bool): enable bias for qkv if True.
Default: True.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default: False.
qk_scale (float, optional): scales for query and key. Default: None.
init_cfg (dict, optional): Initialization config dict.
Default:None.
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
Attention of Segformer. Default: 1.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
qkv_bias=True,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='SyncBN'),
batch_first=True,
qk_scale=None,
sr_ratio=1,
with_cp=False):
super(SimpliefiedTransformerEncoderLayer, self).__init__()
# The ret[0] of build_norm_layer is norm name.
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = SimplifiedAttention(
embed_dims=embed_dims,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
sr_ratio=sr_ratio,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate))
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = MixFFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate))
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, hw_shape):
x = self.attn(self.norm1(x), hw_shape, identity=x)
x = self.ffn(self.norm2(x), hw_shape, identity=x)
return x
@BACKBONES.register_module()
class SimplifiedMixTransformer(BaseModule):
"""The backbone of DEST.
This backbone is the implementation of `SegFormer: Simple and
Efficient Design for Semantic Segmentation with
Transformers <https://arxiv.org/abs/2105.15203>`_.
Args:
in_channels (int): Number of input channels. Default: 3.
embed_dims (Sequence[int]): Embedding dimensions of each transformer
encode layer. Default: [32, 64, 160, 256].
num_stags (int): The num of stages. Default: 4.
num_layers (Sequence[int]): The layer number of each transformer encode
layer. Default: [3, 4, 6, 3].
num_heads (Sequence[int]): The attention heads of each transformer
encode layer. Default: [1, 2, 4, 8].
patch_sizes (Sequence[int]): The patch_size of each overlapped patch
embedding. Default: [7, 3, 3, 3].
strides (Sequence[int]): The stride of each overlapped patch embedding.
Default: [4, 2, 2, 2].
sr_ratios (Sequence[int]): The spatial reduction rate of each
transformer encode layer. Default: [8, 4, 2, 1].
out_indices (Sequence[int] | int): Output from which stages.
Default: (0, 1, 2, 3).
mlp_ratios (Sequence[int]): ratios of mlp hidden dim to embedding dim.
Default: [8, 8, 4, 4].
qkv_bias (bool): Enable bias for qkv if True. Default: True.
drop_rate (float): Probability of an element to be zeroed.
Default 0.0
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): stochastic depth rate. Default 0.0
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
"""
def __init__(self,
in_channels=3,
embed_dims=[32, 64, 160, 256],
num_stages=4,
num_layers=[2, 2, 2, 2],
num_heads=[1, 2, 4, 8],
patch_sizes=[7, 3, 3, 3],
strides=[4, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
out_indices=(0, 1, 2, 3),
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='SyncBN', requires_grad=True),
pretrained=None,
init_cfg=None,
with_cp=False):
super(SimplifiedMixTransformer, self).__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.embed_dims = embed_dims
self.num_stages = num_stages
self.num_layers = num_layers
self.num_heads = num_heads
self.patch_sizes = patch_sizes
self.strides = strides
self.sr_ratios = sr_ratios
self.with_cp = with_cp
assert num_stages == len(num_layers) == len(num_heads) \
== len(patch_sizes) == len(strides) == len(sr_ratios)
self.out_indices = out_indices
assert max(out_indices) < self.num_stages
# transformer encoder
dpr = [
x.item()
for x in torch.linspace(0, drop_path_rate, sum(num_layers))
] # stochastic num_layer decay rule
cur = 0
self.layers = ModuleList()
for i, num_layer in enumerate(num_layers):
patch_embed = SimplifiedPatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims[i],
kernel_size=patch_sizes[i],
stride=strides[i],
padding=patch_sizes[i] // 2,
norm_cfg=norm_cfg)
layer = ModuleList([
SimpliefiedTransformerEncoderLayer(
embed_dims=embed_dims[i],
num_heads=num_heads[i],
feedforward_channels=mlp_ratios[i] * embed_dims[i],
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[cur + idx],
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
sr_ratio=sr_ratios[i]) for idx in range(num_layer)
])
in_channels = embed_dims[i]
# The ret[0] of build_norm_layer is norm name.
norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
self.layers.append(ModuleList([patch_embed, layer, norm]))
cur += num_layer
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.GroupNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
outs = []
for i, layer in enumerate(self.layers):
x, (H, W) = layer[0](x)
for block in layer[1]:
x = block(x, (H, W))
x = layer[2](x)
N, C, L = x.shape
x = x.reshape(N, C, H, W)
outs.append(x)
return outs