[DEST] add DEST model (#2482)
## Motivation We are from NVIDIA and we have developed a simplified and inference-efficient transformer for dense prediction tasks. The method is based on SegFormer with hardware-friendly design choices, resulting in better accuracy and over 2x reduction in inference speed as compared to the baseline. We believe this model would be of particular interests to those who want to deploy an efficient vision transformer for production, and it is easily adaptable to other tasks. Therefore, we would like to contribute our method to mmsegmentation in order to benefit a larger audience. The paper was accepted to [Transformer for Vision workshop](https://nam11.safelinks.protection.outlook.com/?url=https%3A%2F%2Fsites.google.com%2Fview%2Ft4v-cvpr22%2Fpapers%3Fauthuser%3D0&data=05%7C01%7Cboyinz%40nvidia.com%7Cbf078d69821449d1f4c908dab5e8c7da%7C43083d15727340c1b7db39efd9ccc17a%7C0%7C0%7C638022308636438546%7CUnknown%7CTWFpbGZsb3d8eyJWIjoiMC4wLjAwMDAiLCJQIjoiV2luMzIiLCJBTiI6Ik1haWwiLCJXVCI6Mn0%3D%7C3000%7C%7C%7C&sdata=XtSgPQrbVgHxt5L9XkXF%2BGWvc95haB3kKPcHnsVIF3M%3D&reserved=0) at CVPR 2022, here below are some resource links: Paper [https://arxiv.org/pdf/2204.13791.pdf](https://nam11.safelinks.protection.outlook.com/?url=https%3A%2F%2Farxiv.org%2Fpdf%2F2204.13791.pdf&data=05%7C01%7Cboyinz%40nvidia.com%7Cbf078d69821449d1f4c908dab5e8c7da%7C43083d15727340c1b7db39efd9ccc17a%7C0%7C0%7C638022308636438546%7CUnknown%7CTWFpbGZsb3d8eyJWIjoiMC4wLjAwMDAiLCJQIjoiV2luMzIiLCJBTiI6Ik1haWwiLCJXVCI6Mn0%3D%7C3000%7C%7C%7C&sdata=X%2FCVoa6PFA09EHfClES36QOa5NvbZu%2F6IDfBVwiYywU%3D&reserved=0) (Table 3 shows the semseg results) Code [https://github.com/NVIDIA/DL4AGX/tree/master/DEST](https://nam11.safelinks.protection.outlook.com/?url=https%3A%2F%2Fgithub.com%2FNVIDIA%2FDL4AGX%2Ftree%2Fmaster%2FDEST&data=05%7C01%7Cboyinz%40nvidia.com%7Cbf078d69821449d1f4c908dab5e8c7da%7C43083d15727340c1b7db39efd9ccc17a%7C0%7C0%7C638022308636438546%7CUnknown%7CTWFpbGZsb3d8eyJWIjoiMC4wLjAwMDAiLCJQIjoiV2luMzIiLCJBTiI6Ik1haWwiLCJXVCI6Mn0%3D%7C3000%7C%7C%7C&sdata=9DLQZpEq1cN75%2FDf%2FniUOOUFS1ABX8FEUH02O6isGVQ%3D&reserved=0) A webinar on its application [https://www.nvidia.com/en-us/on-demand/session/other2022-drivetraining/](https://nam11.safelinks.protection.outlook.com/?url=https%3A%2F%2Fwww.nvidia.com%2Fen-us%2Fon-demand%2Fsession%2Fother2022-drivetraining%2F&data=05%7C01%7Cboyinz%40nvidia.com%7Cbf078d69821449d1f4c908dab5e8c7da%7C43083d15727340c1b7db39efd9ccc17a%7C0%7C0%7C638022308636438546%7CUnknown%7CTWFpbGZsb3d8eyJWIjoiMC4wLjAwMDAiLCJQIjoiV2luMzIiLCJBTiI6Ik1haWwiLCJXVCI6Mn0%3D%7C3000%7C%7C%7C&sdata=8jrBC%2Bp3jGxiaW4vtSfhh6GozC3tRqGNjNoALM%2FOYxs%3D&reserved=0) ## Modification Add backbone(smit.py) and head(dest_head.py) of DEST ## BC-breaking (Optional) N/A ## Use cases (Optional) N/A --------- Co-authored-by: MeowZheng <meowzheng@outlook.com>pull/2622/head
parent
486a40995e
commit
409caf8548
|
@ -2,6 +2,7 @@
|
|||
|
||||
In this file, we list the features with other licenses instead of Apache 2.0. Users should be careful about adopting these features in any commercial matters.
|
||||
|
||||
| Feature | Files | License |
|
||||
| :-------: | :-------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------: |
|
||||
| SegFormer | [mmseg/models/decode_heads/segformer_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py) | [NVIDIA License](https://github.com/NVlabs/SegFormer#license) |
|
||||
| Feature | Files | License |
|
||||
| :-------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------: |
|
||||
| SegFormer | [mmseg/models/decode_heads/segformer_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py) | [NVIDIA License](https://github.com/NVlabs/SegFormer#license) |
|
||||
| DEST | [mmseg/models/backbones/smit.py](https://github.com/open-mmlab/mmsegmentation/blob/master/projects/dest/models/smit.py) [mmseg/models/decode_heads/dest_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/projects/dest/models/dest_head.py) | [NVIDIA License](https://github.com/NVIDIA/DL4AGX/blob/master/DEST/LICENSE) |
|
||||
|
|
|
@ -182,6 +182,7 @@ Supported methods:
|
|||
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
|
||||
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
|
||||
- [x] [K-Net (NeurIPS'2021)](configs/knet)
|
||||
- [x] [DEST (CVPRW'2022)](projects/dest)
|
||||
|
||||
Supported datasets:
|
||||
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
# DEST
|
||||
|
||||
[DEST: Depth Estimation with Simplified Transformer](https://arxiv.org/abs/2204.13791)
|
||||
|
||||
## Description
|
||||
|
||||
Transformer and its variants have shown state-of-the-art results in many vision tasks recently, ranging from image classification to dense prediction. Despite of their success, limited work has been reported on improving the model efficiency for deployment in latency-critical applications, such as autonomous driving and robotic navigation. In this paper, we aim at improving upon the existing transformers in vision, and propose a method for Dense Estimation with Simplified Transformer (DEST), which is efficient and particularly suitable for deployment on GPU-based platforms. Through strategic design choices, our model leads to significant reduction in model size, complexity, as well as inference latency, while achieving superior accuracy as compared to state-of-the-art in the task of self-supervised monocular depth estimation. We also show that our design generalize well to other dense prediction task such as semantic segmentation without bells and whistles.
|
||||
|
||||
## Usage
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.8.12
|
||||
- PyTorch 1.11
|
||||
- mmcv v1.7.0
|
||||
- Install [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) from source
|
||||
|
||||
All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the mmsegmentaions directory so that Python can locate the configuration files in mmsegmentation.
|
||||
|
||||
### Dataset preparing
|
||||
|
||||
Preparing `cityscapes` dataset following this [Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#prepare-datasets)
|
||||
|
||||
### Training commands
|
||||
|
||||
```shell
|
||||
mim train mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest
|
||||
```
|
||||
|
||||
To train on multiple GPUs, e.g. 8 GPUs, run the following command:
|
||||
|
||||
```shell
|
||||
mim train mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest --launcher pytorch --gpus 8
|
||||
```
|
||||
|
||||
### Testing commands
|
||||
|
||||
```shell
|
||||
mim test mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest --checkpoint ${CHECKPOINT_PATH} --eval mIoU
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
### Cityscapes
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
|
||||
| ------ | -------- | --------- | ------: | -------: | -------------- | ----: | ------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| DEST | SMIT-B0 | 1024x1024 | 160000 | - | - | 64.34 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b0_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b0_1024x1024_160k_cityscapes_20230105_232025-11f73f34.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b0_1024x1024_160k_cityscapes_20230105_232025.log) |
|
||||
| DEST | SMIT-B1 | 1024x1024 | 160000 | - | - | 68.21 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358-0dd4e86e.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358.logmmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358.log) |
|
||||
| DEST | SMIT-B2 | 1024x1024 | 160000 | - | - | 71.89 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b2_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b2_1024x1024_160k_cityscapes_20230105_231943-b06319ae.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b2_1024x1024_160k_cityscapes_20230105_231943.log) |
|
||||
| DEST | SMIT-B3 | 1024x1024 | 160000 | - | - | 73.51 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b3_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b3_1024x1024_160k_cityscapes_20230105_231800-ee4cec5c.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b3_1024x1024_160k_cityscapes_20230105_231800.log) |
|
||||
| DEST | SMIT-B4 | 1024x1024 | 160000 | - | - | 73.99 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b4_1024x1024_160k_cityscapes_20230105_232155-3ca9f4fc.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b4_1024x1024_160k_cityscapes_20230105_232155.log) |
|
||||
| DEST | SMIT-B5 | 1024x1024 | 160000 | - | - | 75.28 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b5_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b5_1024x1024_160k_cityscapes_20230105_231411-e83819b5.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b5_1024x1024_160k_cityscapes_20230105_231411.log) |
|
||||
|
||||
Note:
|
||||
|
||||
- The above models are all training from scratch without pretrained backbones. Accuracy can be further enhanced by appropriate pretraining.
|
||||
- Training of DEST is not very stable, which is sensitive to random seeds.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{YangDEST,
|
||||
title={Depth Estimation with Simplified Transformer},
|
||||
author={Yang, John and An, Le and Dixit, Anurag and Koo, Jinkyu and Park, Su Inn},
|
||||
journal={arXiv preprint arXiv:2204.13791},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
|
||||
## Checklist
|
||||
|
||||
- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
|
||||
|
||||
- [x] Finish the code
|
||||
|
||||
- [x] Basic docstrings & proper citation
|
||||
|
||||
- [x] Test-time correctness
|
||||
|
||||
- [x] A full README
|
||||
|
||||
- [x] Milestone 2: Indicates a successful model implementation.
|
||||
|
||||
- [x] Training-time correctness
|
||||
|
||||
- [ ] Milestone 3: Good to be a part of our core package!
|
||||
|
||||
- [ ] Type hints and docstrings
|
||||
|
||||
- [ ] Unit tests
|
||||
|
||||
- [ ] Code polishing
|
||||
|
||||
- [ ] Metafile.yml
|
||||
|
||||
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
|
||||
|
||||
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
|
|
@ -0,0 +1,50 @@
|
|||
# DEST
|
||||
|
||||
[DEST: Depth Estimation with Simplified Transformer](https://arxiv.org/abs/2204.13791)
|
||||
|
||||
## Introduction
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
<a href="https://github.com/NVIDIA/DL4AGX/tree/master/DEST">Official Repo</a>
|
||||
|
||||
## Abstract
|
||||
|
||||
<!-- [ABSTRACT] -->
|
||||
|
||||
Transformer and its variants have shown state-of-the-art results in many vision tasks recently, ranging from image classification to dense prediction. Despite of their success, limited work has been reported on improving the model efficiency for deployment in latency-critical applications, such as autonomous driving and robotic navigation. In this paper, we aim at improving upon the existing transformers in vision, and propose a method for Dense Estimation with Simplified Transformer (DEST), which is efficient and particularly suitable for deployment on GPU-based platforms. Through strategic design choices, our model leads to significant reduction in model size, complexity, as well as inference latency, while achieving superior accuracy as compared to state-of-the-art in the task of self-supervised monocular depth estimation. We also show that our design generalize well to other dense prediction task such as semantic segmentation without bells and whistles.
|
||||
|
||||
<!-- [IMAGE] -->
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/76149310/219313665-49fa89ed-4973-4496-bb33-3256f107e82d.png" width="70%"/>
|
||||
</div>
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{YangDEST,
|
||||
title={Depth Estimation with Simplified Transformer},
|
||||
author={Yang, John and An, Le and Dixit, Anurag and Koo, Jinkyu and Park, Su Inn},
|
||||
journal={arXiv preprint arXiv:2204.13791},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
### Cityscapes
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
|
||||
| ------ | -------- | --------- | ------: | -------: | -------------- | ----: | ------------- | ---------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| DEST | SMIT-B0 | 1024x1024 | 160000 | - | - | 64.34 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b0_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b0_1024x1024_160k_cityscapes_20230105_232025-11f73f34.pth) |
|
||||
| DEST | SMIT-B1 | 1024x1024 | 160000 | - | - | 68.21 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358-0dd4e86e.pth) |
|
||||
| DEST | SMIT-B2 | 1024x1024 | 160000 | - | - | 71.89 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b2_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b2_1024x1024_160k_cityscapes_20230105_231943-b06319ae.pth) |
|
||||
| DEST | SMIT-B3 | 1024x1024 | 160000 | - | - | 73.51 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b3_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b3_1024x1024_160k_cityscapes_20230105_231800-ee4cec5c.pth) |
|
||||
| DEST | SMIT-B4 | 1024x1024 | 160000 | - | - | 73.99 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b4_1024x1024_160k_cityscapes_20230105_232155-3ca9f4fc.pth) |
|
||||
| DEST | SMIT-B5 | 1024x1024 | 160000 | - | - | 75.28 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b5_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b5_1024x1024_160k_cityscapes_20230105_231411-e83819b5.pth) |
|
||||
|
||||
Note:
|
||||
|
||||
- The above models are all training from scratch without pretrained backbones. Accuracy can be further enhanced by appropriate pretraining.
|
||||
- Training of DEST is not very stable, which is sensitive to random seeds.
|
|
@ -0,0 +1,37 @@
|
|||
# model settings
|
||||
embed_dims = [32, 64, 160, 256]
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='SimplifiedMixTransformer',
|
||||
in_channels=3,
|
||||
embed_dims=embed_dims,
|
||||
num_stages=4,
|
||||
num_layers=[2, 2, 2, 2],
|
||||
num_heads=[1, 2, 5, 8],
|
||||
patch_sizes=[7, 3, 3, 3],
|
||||
strides=[4, 2, 2, 2],
|
||||
sr_ratios=[8, 4, 2, 1],
|
||||
out_indices=(0, 1, 2, 3),
|
||||
mlp_ratios=[8, 8, 4, 4],
|
||||
qkv_bias=True,
|
||||
drop_rate=0.0,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.1,
|
||||
norm_cfg=norm_cfg),
|
||||
decode_head=dict(
|
||||
type='DESTHead',
|
||||
in_channels=[32, 64, 160, 256],
|
||||
in_index=[0, 1, 2, 3],
|
||||
channels=32,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
# model training and testing settings
|
||||
train_cfg=dict(),
|
||||
test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))
|
|
@ -0,0 +1,33 @@
|
|||
_base_ = [
|
||||
'./dest_simpatt-b0.py',
|
||||
'../../../configs/_base_/datasets/cityscapes_1024x1024.py',
|
||||
'../../../configs/_base_/default_runtime.py',
|
||||
'../../../configs/_base_/schedules/schedule_160k.py'
|
||||
]
|
||||
|
||||
custom_imports = dict(imports=['projects.dest.models'])
|
||||
|
||||
optimizer = dict(
|
||||
_delete_=True,
|
||||
type='AdamW',
|
||||
lr=0.00006,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=0.01,
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'pos_block': dict(decay_mult=0.),
|
||||
'norm': dict(decay_mult=0.),
|
||||
'head': dict(lr_mult=10.)
|
||||
}))
|
||||
|
||||
lr_config = dict(
|
||||
_delete_=True,
|
||||
policy='poly',
|
||||
warmup='linear',
|
||||
warmup_iters=1500,
|
||||
warmup_ratio=1e-6,
|
||||
power=1.0,
|
||||
min_lr=0.0,
|
||||
by_epoch=False)
|
||||
|
||||
data = dict(samples_per_gpu=1, workers_per_gpu=1)
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']
|
||||
|
||||
embed_dims = [64, 128, 250, 320]
|
||||
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained=None,
|
||||
backbone=dict(embed_dims=embed_dims),
|
||||
decode_head=dict(in_channels=embed_dims, channels=64))
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']
|
||||
|
||||
embed_dims = [64, 128, 250, 320]
|
||||
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained=None,
|
||||
backbone=dict(embed_dims=embed_dims, num_layers=[3, 3, 6, 3]),
|
||||
decode_head=dict(in_channels=embed_dims, channels=64))
|
|
@ -0,0 +1,22 @@
|
|||
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']
|
||||
|
||||
embed_dims = [64, 128, 250, 320]
|
||||
|
||||
optimizer = dict(
|
||||
_delete_=True,
|
||||
type='AdamW',
|
||||
lr=0.00006,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=0.01,
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'pos_block': dict(decay_mult=0.),
|
||||
'norm': dict(decay_mult=0.),
|
||||
'head': dict(lr_mult=1.)
|
||||
}))
|
||||
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained=None,
|
||||
backbone=dict(embed_dims=embed_dims, num_layers=[3, 6, 8, 3]),
|
||||
decode_head=dict(in_channels=embed_dims, channels=64))
|
|
@ -0,0 +1,22 @@
|
|||
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']
|
||||
|
||||
embed_dims = [64, 128, 250, 320]
|
||||
|
||||
optimizer = dict(
|
||||
_delete_=True,
|
||||
type='AdamW',
|
||||
lr=0.00006,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=0.01,
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'pos_block': dict(decay_mult=0.),
|
||||
'norm': dict(decay_mult=0.),
|
||||
'head': dict(lr_mult=1.)
|
||||
}))
|
||||
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained=None,
|
||||
backbone=dict(embed_dims=embed_dims, num_layers=[3, 8, 12, 5]),
|
||||
decode_head=dict(in_channels=embed_dims, channels=64))
|
|
@ -0,0 +1,22 @@
|
|||
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']
|
||||
|
||||
embed_dims = [64, 128, 250, 320]
|
||||
|
||||
optimizer = dict(
|
||||
_delete_=True,
|
||||
type='AdamW',
|
||||
lr=0.00006,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=0.01,
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'pos_block': dict(decay_mult=0.),
|
||||
'norm': dict(decay_mult=0.),
|
||||
'head': dict(lr_mult=1.)
|
||||
}))
|
||||
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained=None,
|
||||
backbone=dict(embed_dims=embed_dims, num_layers=[3, 10, 16, 5]),
|
||||
decode_head=dict(in_channels=embed_dims, channels=64))
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .dest_head import DESTHead
|
||||
from .smit import SimplifiedMixTransformer
|
||||
|
||||
__all__ = ['SimplifiedMixTransformer', 'DESTHead']
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.models import HEADS
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class DESTHead(BaseDecodeHead):
|
||||
|
||||
def __init__(self, interpolate_mode='bilinear', **kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
self.interpolate_mode = interpolate_mode
|
||||
num_inputs = len(self.in_channels)
|
||||
assert num_inputs == len(self.in_index)
|
||||
self.fuse_in_channels = self.in_channels.copy()
|
||||
for i in range(num_inputs - 1):
|
||||
self.fuse_in_channels[i] += self.fuse_in_channels[i + 1]
|
||||
self.convs = nn.ModuleList()
|
||||
for i in range(num_inputs):
|
||||
self.convs.append(
|
||||
ConvModule(
|
||||
in_channels=self.in_channels[i],
|
||||
out_channels=self.in_channels[i],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act_cfg=self.act_cfg))
|
||||
|
||||
self.fuse_convs = nn.ModuleList()
|
||||
for i in range(num_inputs):
|
||||
self.fuse_convs.append(
|
||||
ConvModule(
|
||||
in_channels=self.fuse_in_channels[i],
|
||||
out_channels=self.in_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act_cfg=self.act_cfg))
|
||||
|
||||
self.upsample = nn.ModuleList([
|
||||
nn.Sequential(nn.Upsample(scale_factor=2, mode=interpolate_mode))
|
||||
] * len(self.in_channels))
|
||||
|
||||
def forward(self, inputs):
|
||||
feat = None
|
||||
for idx in reversed(range(len(inputs))):
|
||||
x = self.convs[idx](inputs[idx])
|
||||
if idx != len(inputs) - 1:
|
||||
x = torch.concat([feat, x], dim=1)
|
||||
x = self.upsample[idx](x)
|
||||
feat = self.fuse_convs[idx](x)
|
||||
return self.cls_seg(feat)
|
|
@ -0,0 +1,557 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.utils.weight_init import trunc_normal_init
|
||||
from mmcv.runner import BaseModule, ModuleList, Sequential
|
||||
from mmcv.utils import to_2tuple
|
||||
|
||||
from mmseg.models import BACKBONES
|
||||
from mmseg.models.utils.embed import AdaptivePadding
|
||||
|
||||
|
||||
class SimplifiedPatchEmbed(BaseModule):
|
||||
"""Image to Patch Embedding.
|
||||
|
||||
We use a conv layer to implement SimplifiedPatchEmbed.
|
||||
|
||||
Args:
|
||||
in_channels (int): The num of input channels. Default: 3
|
||||
embed_dims (int): The dimensions of embedding. Default: 768
|
||||
conv_type (str): The config dict for embedding
|
||||
conv layer type selection. Default: "Conv2d".
|
||||
kernel_size (int): The kernel_size of embedding conv. Default: 16.
|
||||
stride (int, optional): The slide stride of embedding conv.
|
||||
Default: None (Would be set as `kernel_size`).
|
||||
padding (int | tuple | string ): The padding length of
|
||||
embedding conv. When it is a string, it means the mode
|
||||
of adaptive padding, support "same" and "corner" now.
|
||||
Default: "corner".
|
||||
dilation (int): The dilation rate of embedding conv. Default: 1.
|
||||
bias (bool): Bias of embed conv. Default: True.
|
||||
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||
Default: None.
|
||||
input_size (int | tuple | None): The size of input, which will be
|
||||
used to calculate the out size. Only work when `dynamic_size`
|
||||
is False. Default: None.
|
||||
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
embed_dims=768,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=16,
|
||||
stride=None,
|
||||
padding='corner',
|
||||
dilation=1,
|
||||
bias=True,
|
||||
norm_cfg=None,
|
||||
input_size=None,
|
||||
init_cfg=None):
|
||||
super(SimplifiedPatchEmbed, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
if stride is None:
|
||||
stride = kernel_size
|
||||
|
||||
kernel_size = to_2tuple(kernel_size)
|
||||
stride = to_2tuple(stride)
|
||||
dilation = to_2tuple(dilation)
|
||||
|
||||
if isinstance(padding, str):
|
||||
self.adap_padding = AdaptivePadding(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding)
|
||||
# disable the padding of conv
|
||||
padding = 0
|
||||
else:
|
||||
self.adap_padding = None
|
||||
padding = to_2tuple(padding)
|
||||
|
||||
self.projection = build_conv_layer(
|
||||
dict(type=conv_type),
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dims,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
|
||||
if norm_cfg is not None:
|
||||
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
if input_size:
|
||||
input_size = to_2tuple(input_size)
|
||||
# `init_out_size` would be used outside to
|
||||
# calculate the num_patches
|
||||
# when `use_abs_pos_embed` outside
|
||||
self.init_input_size = input_size
|
||||
if self.adap_padding:
|
||||
pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
|
||||
input_h, input_w = input_size
|
||||
input_h = input_h + pad_h
|
||||
input_w = input_w + pad_w
|
||||
input_size = (input_h, input_w)
|
||||
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
||||
h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
|
||||
(kernel_size[0] - 1) - 1) // stride[0] + 1
|
||||
w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
|
||||
(kernel_size[1] - 1) - 1) // stride[1] + 1
|
||||
self.init_out_size = (h_out, w_out)
|
||||
else:
|
||||
self.init_input_size = None
|
||||
self.init_out_size = None
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
|
||||
|
||||
Returns:
|
||||
tuple: Contains merged results and its spatial shape.
|
||||
|
||||
- x (Tensor): Has shape (B, embed_dims, out_h * out_w)
|
||||
- out_size (tuple[int]): Spatial shape of x, arrange as
|
||||
(out_h, out_w).
|
||||
"""
|
||||
|
||||
if self.adap_padding:
|
||||
x = self.adap_padding(x)
|
||||
|
||||
x = self.projection(x)
|
||||
out_size = (x.shape[2], x.shape[3])
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
x = x.flatten(2)
|
||||
return x, out_size
|
||||
|
||||
|
||||
class DWConv(nn.Module):
|
||||
|
||||
def __init__(self, dims):
|
||||
super(DWConv, self).__init__()
|
||||
self.dwconv = nn.Conv2d(dims, dims, 3, 1, 1, bias=True, groups=dims)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
B, C, N = x.shape
|
||||
x = x.reshape(B, C, H, W)
|
||||
x = self.dwconv(x)
|
||||
x = x.flatten(2)
|
||||
return x
|
||||
|
||||
|
||||
class MixFFN(nn.Module):
|
||||
"""An implementation of MixFFN of DEST.
|
||||
|
||||
The differences between MixFFN & FFN:
|
||||
1. Use 1X1 Conv to replace Linear layer.
|
||||
2. Introduce 3X3 Conv to encode positional information.
|
||||
Args:
|
||||
embed_dims (int): The feature dimension. Same as
|
||||
`MultiheadAttention`. Defaults: 256.
|
||||
feedforward_channels (int): The hidden dimension of FFNs.
|
||||
Defaults: 1024.
|
||||
act_cfg (dict, optional): The activation config for FFNs.
|
||||
Default: dict(type='ReLU')
|
||||
ffn_drop (float, optional): Probability of an element to be
|
||||
zeroed in FFN. Default 0.0.
|
||||
dropout_layer (obj:`ConfigDict`): The dropout_layer used
|
||||
when adding the shortcut.
|
||||
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
feedforward_channels,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
ffn_drop=0.,
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True),
|
||||
dropout_layer=None,
|
||||
init_cfg=None):
|
||||
super(MixFFN, self).__init__()
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.feedforward_channels = feedforward_channels
|
||||
self.act_cfg = act_cfg
|
||||
activate = build_activation_layer(act_cfg)
|
||||
|
||||
in_channels = embed_dims
|
||||
fc1 = nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=feedforward_channels,
|
||||
kernel_size=1,
|
||||
stride=1)
|
||||
norm1 = build_norm_layer(norm_cfg, feedforward_channels)[1]
|
||||
self.dwconv = DWConv(feedforward_channels)
|
||||
norm2 = build_norm_layer(norm_cfg, feedforward_channels)[1]
|
||||
fc2 = nn.Conv1d(
|
||||
in_channels=feedforward_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1)
|
||||
drop = nn.Dropout(ffn_drop)
|
||||
pre_layers = [fc1, norm1]
|
||||
post_layers = [norm2, activate, drop, fc2, drop]
|
||||
self.pre_layers = Sequential(*pre_layers)
|
||||
self.post_layers = Sequential(*post_layers)
|
||||
self.dropout_layer = build_dropout(
|
||||
dropout_layer) if dropout_layer else torch.nn.Identity()
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Conv1d):
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
|
||||
def forward(self, x, hw_shape, identity):
|
||||
out = self.pre_layers(x)
|
||||
out = self.dwconv(out, hw_shape[0], hw_shape[1])
|
||||
out = self.post_layers(out)
|
||||
return identity + self.dropout_layer(out)
|
||||
|
||||
|
||||
class SimplifiedAttention(nn.Module):
|
||||
"""An implementation of Simplified Multi-head Attention of DEST.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The embedding dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
attn_drop (float): A Dropout layer on attn_output_weights.
|
||||
Default: 0.0.
|
||||
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
|
||||
Default: 0.0.
|
||||
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
|
||||
Attention of Segformer. Default: 1.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default True.
|
||||
qk_scale (float, optional): scales for query and key. Default: None.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='SyncBN', requires_grad=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
sr_ratio=1,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
dropout_layer=None,
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True)):
|
||||
super().__init__()
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.num_heads = num_heads
|
||||
head_dim = embed_dims // num_heads
|
||||
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.q = nn.Conv1d(embed_dims, embed_dims, 1, bias=qkv_bias)
|
||||
self.k = nn.Conv1d(embed_dims, embed_dims, 1, bias=qkv_bias)
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Conv1d(embed_dims, embed_dims, 1)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
self.sr_ratio = sr_ratio
|
||||
if sr_ratio > 1:
|
||||
self.sr = nn.Conv2d(
|
||||
embed_dims, embed_dims, kernel_size=sr_ratio, stride=sr_ratio)
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.dropout_layer = build_dropout(
|
||||
dropout_layer) if dropout_layer else torch.nn.Identity()
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Conv1d):
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
|
||||
def forward(self, x, hw_shape, identity):
|
||||
H, W = hw_shape
|
||||
B, C, N = x.shape
|
||||
q = self.q(x)
|
||||
q = q.reshape(B, self.num_heads, C // self.num_heads, N)
|
||||
q = q.permute(0, 1, 3, 2)
|
||||
|
||||
if self.sr_ratio > 1:
|
||||
x_ = x.reshape(B, C, H, W)
|
||||
x_ = self.sr(x_).reshape(B, C, -1)
|
||||
x_ = self.norm1(x_)
|
||||
k = self.k(x_).reshape(B, self.num_heads, C // self.num_heads, -1)
|
||||
else:
|
||||
k = self.k(x).reshape(B, self.num_heads, C // self.num_heads, -1)
|
||||
|
||||
v = torch.mean(x, 2, True).repeat(1, 1,
|
||||
self.num_heads).transpose(-2, -1)
|
||||
attn = (q @ k) * self.scale
|
||||
attn, _ = torch.max(attn, -1)
|
||||
out = (attn.transpose(-2, -1) @ v)
|
||||
out = out.transpose(-2, -1)
|
||||
out = self.proj(out)
|
||||
return identity + self.dropout_layer(out)
|
||||
|
||||
|
||||
class SimpliefiedTransformerEncoderLayer(BaseModule):
|
||||
"""Implements one encoder layer in DEST.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
after the feed forward layer. Default 0.0.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0.
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0.
|
||||
qkv_bias (bool): enable bias for qkv if True.
|
||||
Default: True.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim)
|
||||
or (n, batch, embed_dim). Default: False.
|
||||
qk_scale (float, optional): scales for query and key. Default: None.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Default:None.
|
||||
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
|
||||
Attention of Segformer. Default: 1.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
qkv_bias=True,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
norm_cfg=dict(type='SyncBN'),
|
||||
batch_first=True,
|
||||
qk_scale=None,
|
||||
sr_ratio=1,
|
||||
with_cp=False):
|
||||
super(SimpliefiedTransformerEncoderLayer, self).__init__()
|
||||
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.attn = SimplifiedAttention(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
sr_ratio=sr_ratio,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate))
|
||||
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.ffn = MixFFN(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate))
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
x = self.attn(self.norm1(x), hw_shape, identity=x)
|
||||
x = self.ffn(self.norm2(x), hw_shape, identity=x)
|
||||
return x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class SimplifiedMixTransformer(BaseModule):
|
||||
"""The backbone of DEST.
|
||||
|
||||
This backbone is the implementation of `SegFormer: Simple and
|
||||
Efficient Design for Semantic Segmentation with
|
||||
Transformers <https://arxiv.org/abs/2105.15203>`_.
|
||||
Args:
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (Sequence[int]): Embedding dimensions of each transformer
|
||||
encode layer. Default: [32, 64, 160, 256].
|
||||
num_stags (int): The num of stages. Default: 4.
|
||||
num_layers (Sequence[int]): The layer number of each transformer encode
|
||||
layer. Default: [3, 4, 6, 3].
|
||||
num_heads (Sequence[int]): The attention heads of each transformer
|
||||
encode layer. Default: [1, 2, 4, 8].
|
||||
patch_sizes (Sequence[int]): The patch_size of each overlapped patch
|
||||
embedding. Default: [7, 3, 3, 3].
|
||||
strides (Sequence[int]): The stride of each overlapped patch embedding.
|
||||
Default: [4, 2, 2, 2].
|
||||
sr_ratios (Sequence[int]): The spatial reduction rate of each
|
||||
transformer encode layer. Default: [8, 4, 2, 1].
|
||||
out_indices (Sequence[int] | int): Output from which stages.
|
||||
Default: (0, 1, 2, 3).
|
||||
mlp_ratios (Sequence[int]): ratios of mlp hidden dim to embedding dim.
|
||||
Default: [8, 8, 4, 4].
|
||||
qkv_bias (bool): Enable bias for qkv if True. Default: True.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.0
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
embed_dims=[32, 64, 160, 256],
|
||||
num_stages=4,
|
||||
num_layers=[2, 2, 2, 2],
|
||||
num_heads=[1, 2, 4, 8],
|
||||
patch_sizes=[7, 3, 3, 3],
|
||||
strides=[4, 2, 2, 2],
|
||||
sr_ratios=[8, 4, 2, 1],
|
||||
out_indices=(0, 1, 2, 3),
|
||||
mlp_ratios=[8, 8, 4, 4],
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True),
|
||||
pretrained=None,
|
||||
init_cfg=None,
|
||||
with_cp=False):
|
||||
super(SimplifiedMixTransformer, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.num_stages = num_stages
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.patch_sizes = patch_sizes
|
||||
self.strides = strides
|
||||
self.sr_ratios = sr_ratios
|
||||
self.with_cp = with_cp
|
||||
assert num_stages == len(num_layers) == len(num_heads) \
|
||||
== len(patch_sizes) == len(strides) == len(sr_ratios)
|
||||
|
||||
self.out_indices = out_indices
|
||||
assert max(out_indices) < self.num_stages
|
||||
|
||||
# transformer encoder
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, drop_path_rate, sum(num_layers))
|
||||
] # stochastic num_layer decay rule
|
||||
|
||||
cur = 0
|
||||
self.layers = ModuleList()
|
||||
for i, num_layer in enumerate(num_layers):
|
||||
patch_embed = SimplifiedPatchEmbed(
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims[i],
|
||||
kernel_size=patch_sizes[i],
|
||||
stride=strides[i],
|
||||
padding=patch_sizes[i] // 2,
|
||||
norm_cfg=norm_cfg)
|
||||
layer = ModuleList([
|
||||
SimpliefiedTransformerEncoderLayer(
|
||||
embed_dims=embed_dims[i],
|
||||
num_heads=num_heads[i],
|
||||
feedforward_channels=mlp_ratios[i] * embed_dims[i],
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=dpr[cur + idx],
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
sr_ratio=sr_ratios[i]) for idx in range(num_layer)
|
||||
])
|
||||
in_channels = embed_dims[i]
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
|
||||
self.layers.append(ModuleList([patch_embed, layer, norm]))
|
||||
cur += num_layer
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
elif isinstance(m, nn.GroupNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x, (H, W) = layer[0](x)
|
||||
for block in layer[1]:
|
||||
x = block(x, (H, W))
|
||||
x = layer[2](x)
|
||||
N, C, L = x.shape
|
||||
x = x.reshape(N, C, H, W)
|
||||
outs.append(x)
|
||||
return outs
|
Loading…
Reference in New Issue