diff --git a/README.md b/README.md
index 378b2ab08..1748fc1eb 100644
--- a/README.md
+++ b/README.md
@@ -140,6 +140,7 @@ Supported methods:
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
- [x] [K-Net (NeurIPS'2021)](configs/knet)
- [x] [MaskFormer (NeurIPS'2021)](configs/maskformer)
+- [x] [Mask2Former (CVPR'2022)](configs/mask2former)
Supported datasets:
diff --git a/README_zh-CN.md b/README_zh-CN.md
index b66977035..fcb8dcf78 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -135,6 +135,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
- [x] [K-Net (NeurIPS'2021)](configs/knet)
- [x] [MaskFormer (NeurIPS'2021)](configs/maskformer)
+- [x] [Mask2Former (CVPR'2022)](configs/mask2former)
已支持的数据集:
diff --git a/configs/mask2former/README.md b/configs/mask2former/README.md
new file mode 100644
index 000000000..8881b0d66
--- /dev/null
+++ b/configs/mask2former/README.md
@@ -0,0 +1,72 @@
+# Mask2Former
+
+[Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527)
+
+## Introduction
+
+
+
+[Official Repo](https://github.com/facebookresearch/Mask2Former)
+
+[Code Snippet](https://github.com/open-mmlab/mmdetection/blob/3.x/mmdet/models/dense_heads/mask2former_head.py)
+
+## Abstract
+
+
+
+Image segmentation is about grouping pixels with different semantics, e.g., category or instance membership, where each choice of semantics defines a task. While only the semantics of each task differ, current research focuses on designing specialized architectures for each task. We present Masked-attention Mask Transformer (Mask2Former), a new architecture capable of addressing any image segmentation task (panoptic, instance or semantic). Its key components include masked attention, which extracts localized features by constraining cross-attention within predicted mask regions. In addition to reducing the research effort by at least three times, it outperforms the best specialized architectures by a significant margin on four popular datasets. Most notably, Mask2Former sets a new state-of-the-art for panoptic segmentation (57.8 PQ on COCO), instance segmentation (50.1 AP on COCO) and semantic segmentation (57.7 mIoU on ADE20K).
+
+```bibtex
+@inproceedings{cheng2021mask2former,
+ title={Masked-attention Mask Transformer for Universal Image Segmentation},
+ author={Bowen Cheng and Ishan Misra and Alexander G. Schwing and Alexander Kirillov and Rohit Girdhar},
+  booktitle={CVPR},
+ year={2022}
+}
+@inproceedings{cheng2021maskformer,
+ title={Per-Pixel Classification is Not All You Need for Semantic Segmentation},
+ author={Bowen Cheng and Alexander G. Schwing and Alexander Kirillov},
+  booktitle={NeurIPS},
+ year={2021}
+}
+```
+
+### Usage
+
+- Mask2Former requires [MMDetection](https://github.com/open-mmlab/mmdetection) to be installed first:
+
+```shell
+pip install "mmdet>=3.0.0rc4"
+```
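+
+After installing, you can optionally verify that the MMDetection dependency is importable (a minimal sanity check, not part of the official installation steps):
+
+```python
+# Minimal sanity check: the Mask2Former head in MMSegmentation wraps the
+# MMDetection implementation, so `mmdet` must be importable.
+import mmdet
+
+print(mmdet.__version__)  # expected to be >= 3.0.0rc4
+```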
+
+## Results and models
+
+### Cityscapes
+
+| Method | Backbone | Crop Size | Lr schd | Mem (MB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
+| ----------- | -------------- | --------- | ------- | -------: | -------------- | ----- | ------------: | -----------------------------------------------------------------------------------------------------------------------------------------------------------: | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| Mask2Former | R-50-D32 | 512x1024 | 90000 | 5806 | 9.17 | 80.44 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_r50_8xb2-90k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_r50_8xb2-90k_cityscapes-512x1024/mask2former_r50_8xb2-90k_cityscapes-512x1024_20221202_140802-2ff5ffa0.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_r50_8xb2-90k_cityscapes-512x1024/mask2former_r50_8xb2-90k_cityscapes-512x1024_20221202_140802.json) |
+| Mask2Former | R-101-D32 | 512x1024 | 90000 | 6971 | 7.11 | 80.80 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_r101_8xb2-90k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_r101_8xb2-90k_cityscapes-512x1024/mask2former_r101_8xb2-90k_cityscapes-512x1024_20221130_031628-8ad528ea.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_r101_8xb2-90k_cityscapes-512x1024/mask2former_r101_8xb2-90k_cityscapes-512x1024_20221130_031628.json)) |
+| Mask2Former | Swin-T | 512x1024 | 90000 | 6511 | 7.18 | 81.71 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_swin-t_8xb2-90k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-t_8xb2-90k_cityscapes-512x1024/mask2former_swin-t_8xb2-90k_cityscapes-512x1024_20221127_144501-290b34af.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-t_8xb2-90k_cityscapes-512x1024/mask2former_swin-t_8xb2-90k_cityscapes-512x1024_20221127_144501.json)) |
+| Mask2Former | Swin-S | 512x1024 | 90000 | 8282 | 5.57 | 82.57 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_swin-s_8xb2-90k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-s_8xb2-90k_cityscapes-512x1024/mask2former_swin-s_8xb2-90k_cityscapes-512x1024_20221127_143802-7c98854a.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-s_8xb2-90k_cityscapes-512x1024/mask2former_swin-s_8xb2-90k_cityscapes-512x1024_20221127_143802.json)) |
+| Mask2Former | Swin-B (in22k) | 512x1024 | 90000 | 11152 | 4.32 | 83.52 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221203_045030-59a4379a.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221203_045030.json)) |
+| Mask2Former | Swin-L (in22k) | 512x1024 | 90000 | 16207 | 2.86 | 83.65 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-dc2c2ddd.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901.json)) |
+
+### ADE20K
+
+| Method | Backbone | Crop Size | Lr schd | Mem (MB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
+| ----------- | -------------- | --------- | ------- | -------: | -------------- | ----- | ------------: | -------------------------------------------------------------------------------------------------------------------------------------------------------: | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| Mask2Former | R-50-D32 | 512x512 | 160000 | 3385 | 26.59 | 47.87 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_r50_8xb2-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_r50_8xb2-160k_ade20k-512x512/mask2former_r50_8xb2-160k_ade20k-512x512_20221204_000055-4c62652d.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_r50_8xb2-160k_ade20k-512x512/mask2former_r50_8xb2-160k_ade20k-512x512_20221204_000055.json)) |
+| Mask2Former | R-101-D32 | 512x512 | 160000 | 4190 | 22.97 | 48.60 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_r101_8xb2-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_r101_8xb2-160k_ade20k-512x512/mask2former_r101_8xb2-160k_ade20k-512x512_20221203_233905-b1169bc0.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_r101_8xb2-160k_ade20k-512x512/mask2former_r101_8xb2-160k_ade20k-512x512_20221203_233905.json)) |
+| Mask2Former | Swin-T | 512x512 | 160000 | 3826 | 23.82 | 48.66 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_swin-t_8xb2-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-t_8xb2-160k_ade20k-512x512/mask2former_swin-t_8xb2-160k_ade20k-512x512_20221203_234230-4341520b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-t_8xb2-160k_ade20k-512x512/mask2former_swin-t_8xb2-160k_ade20k-512x512_20221203_234230.json)) |
+| Mask2Former | Swin-S | 512x512 | 160000 | 5034 | 19.69 | 51.24 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_swin-s_8xb2-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-s_8xb2-160k_ade20k-512x512/mask2former_swin-s_8xb2-160k_ade20k-512x512_20221204_143905-ab263c11.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-s_8xb2-160k_ade20k-512x512/mask2former_swin-s_8xb2-160k_ade20k-512x512_20221204_143905.json)) |
+| Mask2Former | Swin-B | 640x640 | 160000 | 5795 | 12.48 | 52.44 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640/mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640_20221129_125118-35e3a2c7.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640/mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640_20221129_125118.json)) |
+| Mask2Former | Swin-B (in22k) | 640x640 | 160000 | 5795 | 12.43 | 53.90 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640_20221203_235230-622e093b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640_20221203_235230.json)) |
+| Mask2Former | Swin-L (in22k) | 640x640 | 160000 | 9077 | 8.81 | 56.01 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640_20221203_235933-5cc76a78.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640_20221203_235933.json)) |
+
+Note:
+
+- All Mask2Former experiments are trained on 8 A100 GPUs with 2 samples per GPU.
+- As mentioned in [the official repo](https://github.com/facebookresearch/Mask2Former/issues/5), the results of Mask2Former are relatively unstable; the Mask2Former (Swin-S) result on ADE20K in the table is the median of 5 training runs, following the authors' suggestion.
+- The ResNet backbones used in the Mask2Former models are the standard `ResNet` rather than `ResNetV1c`.
+- Test-time augmentation is not supported in MMSegmentation 1.x yet; "ms+flip" results will be added as soon as possible.
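+
+For reference, below is a minimal Python sketch of running inference with one of the checkpoints above, using the MMSegmentation 1.x inference APIs (`init_model` / `inference_model`); the local file paths are placeholders:
+
+```python
+from mmseg.apis import inference_model, init_model
+
+# Placeholder paths: use a config from this folder and download the
+# corresponding checkpoint from the table above.
+config_file = 'configs/mask2former/mask2former_r50_8xb2-90k_cityscapes-512x1024.py'
+checkpoint_file = 'mask2former_r50_8xb2-90k_cityscapes-512x1024_20221202_140802-2ff5ffa0.pth'
+
+model = init_model(config_file, checkpoint_file, device='cuda:0')
+# `result` is a `SegDataSample`; `result.pred_sem_seg` holds the per-pixel class ids.
+result = inference_model(model, 'demo/demo.png')
+```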
diff --git a/configs/mask2former/mask2former.yml b/configs/mask2former/mask2former.yml
new file mode 100644
index 000000000..78655fc52
--- /dev/null
+++ b/configs/mask2former/mask2former.yml
@@ -0,0 +1,290 @@
+Collections:
+- Name: Mask2Former
+ Metadata:
+ Training Data:
+ - Cityscapes
+ - ADE20K
+ Paper:
+ URL: https://arxiv.org/abs/2112.01527
+ Title: Masked-attention Mask Transformer for Universal Image Segmentation
+ README: configs/mask2former/README.md
+ Code:
+ URL: https://github.com/open-mmlab/mmdetection/blob/3.x/mmdet/models/dense_heads/mask2former_head.py
+ Version: 3.x
+ Converted From:
+ Code: https://github.com/facebookresearch/Mask2Former
+Models:
+- Name: mask2former_r50_8xb2-90k_cityscapes-512x1024
+ In Collection: Mask2Former
+ Metadata:
+ backbone: R-50-D32
+ crop size: (512,1024)
+ lr schd: 90000
+ inference time (ms/im):
+ - value: 109.05
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,1024)
+ Training Memory (GB): 5806.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: Cityscapes
+ Metrics:
+ mIoU: 80.44
+ Config: configs/mask2former/mask2former_r50_8xb2-90k_cityscapes-512x1024.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_r50_8xb2-90k_cityscapes-512x1024/mask2former_r50_8xb2-90k_cityscapes-512x1024_20221202_140802-2ff5ffa0.pth
+- Name: mask2former_r101_8xb2-90k_cityscapes-512x1024
+ In Collection: Mask2Former
+ Metadata:
+ backbone: R-101-D32
+ crop size: (512,1024)
+ lr schd: 90000
+ inference time (ms/im):
+ - value: 140.65
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,1024)
+ Training Memory (GB): 6971.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: Cityscapes
+ Metrics:
+ mIoU: 80.8
+ Config: configs/mask2former/mask2former_r101_8xb2-90k_cityscapes-512x1024.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_r101_8xb2-90k_cityscapes-512x1024/mask2former_r101_8xb2-90k_cityscapes-512x1024_20221130_031628-8ad528ea.pth
+- Name: mask2former_swin-t_8xb2-90k_cityscapes-512x1024
+ In Collection: Mask2Former
+ Metadata:
+ backbone: Swin-T
+ crop size: (512,1024)
+ lr schd: 90000
+ inference time (ms/im):
+ - value: 139.28
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,1024)
+ Training Memory (GB): 6511.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: Cityscapes
+ Metrics:
+ mIoU: 81.71
+ Config: configs/mask2former/mask2former_swin-t_8xb2-90k_cityscapes-512x1024.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-t_8xb2-90k_cityscapes-512x1024/mask2former_swin-t_8xb2-90k_cityscapes-512x1024_20221127_144501-290b34af.pth
+- Name: mask2former_swin-s_8xb2-90k_cityscapes-512x1024
+ In Collection: Mask2Former
+ Metadata:
+ backbone: Swin-S
+ crop size: (512,1024)
+ lr schd: 90000
+ inference time (ms/im):
+ - value: 179.53
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,1024)
+ Training Memory (GB): 8282.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: Cityscapes
+ Metrics:
+ mIoU: 82.57
+ Config: configs/mask2former/mask2former_swin-s_8xb2-90k_cityscapes-512x1024.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-s_8xb2-90k_cityscapes-512x1024/mask2former_swin-s_8xb2-90k_cityscapes-512x1024_20221127_143802-7c98854a.pth
+- Name: mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024
+ In Collection: Mask2Former
+ Metadata:
+ backbone: Swin-B (in22k)
+ crop size: (512,1024)
+ lr schd: 90000
+ inference time (ms/im):
+ - value: 231.48
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,1024)
+ Training Memory (GB): 11152.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: Cityscapes
+ Metrics:
+ mIoU: 83.52
+ Config: configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221203_045030-59a4379a.pth
+- Name: mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024
+ In Collection: Mask2Former
+ Metadata:
+ backbone: Swin-L (in22k)
+ crop size: (512,1024)
+ lr schd: 90000
+ inference time (ms/im):
+ - value: 349.65
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,1024)
+ Training Memory (GB): 16207.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: Cityscapes
+ Metrics:
+ mIoU: 83.65
+ Config: configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-dc2c2ddd.pth
+- Name: mask2former_r50_8xb2-160k_ade20k-512x512
+ In Collection: Mask2Former
+ Metadata:
+ backbone: R-50-D32
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 37.61
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 3385.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 47.87
+ Config: configs/mask2former/mask2former_r50_8xb2-160k_ade20k-512x512.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_r50_8xb2-160k_ade20k-512x512/mask2former_r50_8xb2-160k_ade20k-512x512_20221204_000055-4c62652d.pth
+- Name: mask2former_r101_8xb2-160k_ade20k-512x512
+ In Collection: Mask2Former
+ Metadata:
+ backbone: R-101-D32
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 43.54
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 4190.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 48.6
+ Config: configs/mask2former/mask2former_r101_8xb2-160k_ade20k-512x512.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_r101_8xb2-160k_ade20k-512x512/mask2former_r101_8xb2-160k_ade20k-512x512_20221203_233905-b1169bc0.pth
+- Name: mask2former_swin-t_8xb2-160k_ade20k-512x512
+ In Collection: Mask2Former
+ Metadata:
+ backbone: Swin-T
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 41.98
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 3826.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 48.66
+ Config: configs/mask2former/mask2former_swin-t_8xb2-160k_ade20k-512x512.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-t_8xb2-160k_ade20k-512x512/mask2former_swin-t_8xb2-160k_ade20k-512x512_20221203_234230-4341520b.pth
+- Name: mask2former_swin-s_8xb2-160k_ade20k-512x512
+ In Collection: Mask2Former
+ Metadata:
+ backbone: Swin-S
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 50.79
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 5034.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 51.24
+ Config: configs/mask2former/mask2former_swin-s_8xb2-160k_ade20k-512x512.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-s_8xb2-160k_ade20k-512x512/mask2former_swin-s_8xb2-160k_ade20k-512x512_20221204_143905-ab263c11.pth
+- Name: mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640
+ In Collection: Mask2Former
+ Metadata:
+ backbone: Swin-B
+ crop size: (640,640)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 80.13
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (640,640)
+ Training Memory (GB): 5795.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 52.44
+ Config: configs/mask2former/mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640/mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640_20221129_125118-35e3a2c7.pth
+- Name: mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640
+ In Collection: Mask2Former
+ Metadata:
+ backbone: Swin-B (in22k)
+ crop size: (640,640)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 80.45
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (640,640)
+ Training Memory (GB): 5795.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 53.9
+ Config: configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640_20221203_235230-622e093b.pth
+- Name: mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640
+ In Collection: Mask2Former
+ Metadata:
+ backbone: Swin-L (in22k)
+ crop size: (640,640)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 113.51
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (640,640)
+ Training Memory (GB): 9077.0
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 56.01
+ Config: configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640_20221203_235933-5cc76a78.pth
diff --git a/configs/mask2former/mask2former_r101_8xb2-160k_ade20k-512x512.py b/configs/mask2former/mask2former_r101_8xb2-160k_ade20k-512x512.py
new file mode 100644
index 000000000..48f6c12d1
--- /dev/null
+++ b/configs/mask2former/mask2former_r101_8xb2-160k_ade20k-512x512.py
@@ -0,0 +1,7 @@
+_base_ = ['./mask2former_r50_8xb2-160k_ade20k-512x512.py']
+
+model = dict(
+ backbone=dict(
+ depth=101,
+ init_cfg=dict(type='Pretrained',
+ checkpoint='torchvision://resnet101')))
diff --git a/configs/mask2former/mask2former_r101_8xb2-90k_cityscapes-512x1024.py b/configs/mask2former/mask2former_r101_8xb2-90k_cityscapes-512x1024.py
new file mode 100644
index 000000000..275a7dab5
--- /dev/null
+++ b/configs/mask2former/mask2former_r101_8xb2-90k_cityscapes-512x1024.py
@@ -0,0 +1,7 @@
+_base_ = ['./mask2former_r50_8xb2-90k_cityscapes-512x1024.py']
+
+model = dict(
+ backbone=dict(
+ depth=101,
+ init_cfg=dict(type='Pretrained',
+ checkpoint='torchvision://resnet101')))
diff --git a/configs/mask2former/mask2former_r50_8xb2-160k_ade20k-512x512.py b/configs/mask2former/mask2former_r50_8xb2-160k_ade20k-512x512.py
new file mode 100644
index 000000000..598cabfb6
--- /dev/null
+++ b/configs/mask2former/mask2former_r50_8xb2-160k_ade20k-512x512.py
@@ -0,0 +1,207 @@
+_base_ = ['../_base_/default_runtime.py', '../_base_/datasets/ade20k.py']
+
+custom_imports = dict(imports='mmdet.models', allow_failed_imports=False)
+
+crop_size = (512, 512)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255,
+ size=crop_size,
+ test_cfg=dict(size_divisor=32))
+num_classes = 150
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ deep_stem=False,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='SyncBN', requires_grad=False),
+ style='pytorch',
+ init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+ decode_head=dict(
+ type='Mask2FormerHead',
+ in_channels=[256, 512, 1024, 2048],
+ strides=[4, 8, 16, 32],
+ feat_channels=256,
+ out_channels=256,
+ num_classes=num_classes,
+ num_queries=100,
+ num_transformer_feat_level=3,
+ align_corners=False,
+ pixel_decoder=dict(
+ type='mmdet.MSDeformAttnPixelDecoder',
+ num_outs=3,
+ norm_cfg=dict(type='GN', num_groups=32),
+ act_cfg=dict(type='ReLU'),
+ encoder=dict(
+ type='mmdet.DetrTransformerEncoder',
+ num_layers=6,
+ transformerlayers=dict(
+ type='mmdet.BaseTransformerLayer',
+ attn_cfgs=dict(
+ type='mmdet.MultiScaleDeformableAttention',
+ embed_dims=256,
+ num_heads=8,
+ num_levels=3,
+ num_points=4,
+ im2col_step=64,
+ dropout=0.0,
+ batch_first=False,
+ norm_cfg=None,
+ init_cfg=None),
+ ffn_cfgs=dict(
+ type='FFN',
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.0,
+ act_cfg=dict(type='ReLU', inplace=True)),
+ operation_order=('self_attn', 'norm', 'ffn', 'norm')),
+ init_cfg=None),
+ positional_encoding=dict(
+ type='mmdet.SinePositionalEncoding',
+ num_feats=128,
+ normalize=True),
+ init_cfg=None),
+ enforce_decoder_input_project=False,
+ positional_encoding=dict(
+ type='mmdet.SinePositionalEncoding', num_feats=128,
+ normalize=True),
+ transformer_decoder=dict(
+ type='mmdet.DetrTransformerDecoder',
+ return_intermediate=True,
+ num_layers=9,
+ transformerlayers=dict(
+ type='mmdet.DetrTransformerDecoderLayer',
+ attn_cfgs=dict(
+ type='mmdet.MultiheadAttention',
+ embed_dims=256,
+ num_heads=8,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ dropout_layer=None,
+ batch_first=False),
+ ffn_cfgs=dict(
+ embed_dims=256,
+ feedforward_channels=2048,
+ num_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_drop=0.0,
+ dropout_layer=None,
+ add_identity=True),
+ feedforward_channels=2048,
+ operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
+ 'ffn', 'norm')),
+ init_cfg=None),
+ loss_cls=dict(
+ type='mmdet.CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=2.0,
+ reduction='mean',
+ class_weight=[1.0] * num_classes + [0.1]),
+ loss_mask=dict(
+ type='mmdet.CrossEntropyLoss',
+ use_sigmoid=True,
+ reduction='mean',
+ loss_weight=5.0),
+ loss_dice=dict(
+ type='mmdet.DiceLoss',
+ use_sigmoid=True,
+ activate=True,
+ reduction='mean',
+ naive_dice=True,
+ eps=1.0,
+ loss_weight=5.0),
+ train_cfg=dict(
+ num_points=12544,
+ oversample_ratio=3.0,
+ importance_sample_ratio=0.75,
+ assigner=dict(
+ type='mmdet.HungarianAssigner',
+ match_costs=[
+ dict(type='mmdet.ClassificationCost', weight=2.0),
+ dict(
+ type='mmdet.CrossEntropyLossCost',
+ weight=5.0,
+ use_sigmoid=True),
+ dict(
+ type='mmdet.DiceCost',
+ weight=5.0,
+ pred_act=True,
+ eps=1.0)
+ ]),
+ sampler=dict(type='mmdet.MaskPseudoSampler'))),
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
+
+# dataset config
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
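+    # multi-scale training: randomly resize the short edge to 0.5x-2.0x of 512
+    # (i.e. 256-1024 pixels), capping the long edge at 2048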
+ dict(
+ type='RandomChoiceResize',
+ scales=[int(512 * x * 0.1) for x in range(5, 21)],
+ resize_type='ResizeShortestEdge',
+ max_size=2048),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(batch_size=2, dataset=dict(pipeline=train_pipeline))
+
+# optimizer
+embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
+optimizer = dict(
+ type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=optimizer,
+ clip_grad=dict(max_norm=0.01, norm_type=2),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'backbone': dict(lr_mult=0.1, decay_mult=1.0),
+ 'query_embed': embed_multi,
+ 'query_feat': embed_multi,
+ 'level_embed': embed_multi,
+ },
+ norm_decay_mult=0.0))
+# learning policy
+param_scheduler = [
+ dict(
+ type='PolyLR',
+ eta_min=0,
+ power=0.9,
+ begin=0,
+ end=160000,
+ by_epoch=False)
+]
+
+# training schedule for 160k
+train_cfg = dict(
+ type='IterBasedTrainLoop', max_iters=160000, val_interval=5000)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(
+ type='CheckpointHook', by_epoch=False, interval=5000,
+ save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'))
+
+# Default setting for scaling LR automatically
+# - `enable` means enable scaling LR automatically
+# or not by default.
+# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
+auto_scale_lr = dict(enable=False, base_batch_size=16)
diff --git a/configs/mask2former/mask2former_r50_8xb2-90k_cityscapes-512x1024.py b/configs/mask2former/mask2former_r50_8xb2-90k_cityscapes-512x1024.py
new file mode 100644
index 000000000..f92dda98a
--- /dev/null
+++ b/configs/mask2former/mask2former_r50_8xb2-90k_cityscapes-512x1024.py
@@ -0,0 +1,206 @@
+_base_ = ['../_base_/default_runtime.py', '../_base_/datasets/cityscapes.py']
+
+custom_imports = dict(imports='mmdet.models', allow_failed_imports=False)
+
+crop_size = (512, 1024)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255,
+ size=crop_size,
+ test_cfg=dict(size_divisor=32))
+num_classes = 19
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ deep_stem=False,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='SyncBN', requires_grad=False),
+ style='pytorch',
+ init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+ decode_head=dict(
+ type='Mask2FormerHead',
+ in_channels=[256, 512, 1024, 2048],
+ strides=[4, 8, 16, 32],
+ feat_channels=256,
+ out_channels=256,
+ num_classes=num_classes,
+ num_queries=100,
+ num_transformer_feat_level=3,
+ align_corners=False,
+ pixel_decoder=dict(
+ type='mmdet.MSDeformAttnPixelDecoder',
+ num_outs=3,
+ norm_cfg=dict(type='GN', num_groups=32),
+ act_cfg=dict(type='ReLU'),
+ encoder=dict(
+ type='mmdet.DetrTransformerEncoder',
+ num_layers=6,
+ transformerlayers=dict(
+ type='mmdet.BaseTransformerLayer',
+ attn_cfgs=dict(
+ type='mmdet.MultiScaleDeformableAttention',
+ embed_dims=256,
+ num_heads=8,
+ num_levels=3,
+ num_points=4,
+ im2col_step=64,
+ dropout=0.0,
+ batch_first=False,
+ norm_cfg=None,
+ init_cfg=None),
+ ffn_cfgs=dict(
+ type='FFN',
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.0,
+ act_cfg=dict(type='ReLU', inplace=True)),
+ operation_order=('self_attn', 'norm', 'ffn', 'norm')),
+ init_cfg=None),
+ positional_encoding=dict(
+ type='mmdet.SinePositionalEncoding',
+ num_feats=128,
+ normalize=True),
+ init_cfg=None),
+ enforce_decoder_input_project=False,
+ positional_encoding=dict(
+ type='mmdet.SinePositionalEncoding', num_feats=128,
+ normalize=True),
+ transformer_decoder=dict(
+ type='mmdet.DetrTransformerDecoder',
+ return_intermediate=True,
+ num_layers=9,
+ transformerlayers=dict(
+ type='mmdet.DetrTransformerDecoderLayer',
+ attn_cfgs=dict(
+ type='mmdet.MultiheadAttention',
+ embed_dims=256,
+ num_heads=8,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ dropout_layer=None,
+ batch_first=False),
+ ffn_cfgs=dict(
+ embed_dims=256,
+ feedforward_channels=2048,
+ num_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_drop=0.0,
+ dropout_layer=None,
+ add_identity=True),
+ feedforward_channels=2048,
+ operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
+ 'ffn', 'norm')),
+ init_cfg=None),
+ loss_cls=dict(
+ type='mmdet.CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=2.0,
+ reduction='mean',
+ class_weight=[1.0] * num_classes + [0.1]),
+ loss_mask=dict(
+ type='mmdet.CrossEntropyLoss',
+ use_sigmoid=True,
+ reduction='mean',
+ loss_weight=5.0),
+ loss_dice=dict(
+ type='mmdet.DiceLoss',
+ use_sigmoid=True,
+ activate=True,
+ reduction='mean',
+ naive_dice=True,
+ eps=1.0,
+ loss_weight=5.0),
+ train_cfg=dict(
+ num_points=12544,
+ oversample_ratio=3.0,
+ importance_sample_ratio=0.75,
+ assigner=dict(
+ type='mmdet.HungarianAssigner',
+ match_costs=[
+ dict(type='mmdet.ClassificationCost', weight=2.0),
+ dict(
+ type='mmdet.CrossEntropyLossCost',
+ weight=5.0,
+ use_sigmoid=True),
+ dict(
+ type='mmdet.DiceCost',
+ weight=5.0,
+ pred_act=True,
+ eps=1.0)
+ ]),
+ sampler=dict(type='mmdet.MaskPseudoSampler'))),
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
+
+# dataset config
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
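+    # multi-scale training: randomly resize the short edge to 0.5x-2.0x of 1024
+    # (i.e. 512-2048 pixels), capping the long edge at 4096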
+ dict(
+ type='RandomChoiceResize',
+ scales=[int(1024 * x * 0.1) for x in range(5, 21)],
+ resize_type='ResizeShortestEdge',
+ max_size=4096),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+
+# optimizer
+embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
+optimizer = dict(
+ type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=optimizer,
+ clip_grad=dict(max_norm=0.01, norm_type=2),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'backbone': dict(lr_mult=0.1, decay_mult=1.0),
+ 'query_embed': embed_multi,
+ 'query_feat': embed_multi,
+ 'level_embed': embed_multi,
+ },
+ norm_decay_mult=0.0))
+# learning policy
+param_scheduler = [
+ dict(
+ type='PolyLR',
+ eta_min=0,
+ power=0.9,
+ begin=0,
+ end=90000,
+ by_epoch=False)
+]
+
+# training schedule for 90k
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=90000, val_interval=5000)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(
+ type='CheckpointHook', by_epoch=False, interval=5000,
+ save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'))
+
+# Default setting for scaling LR automatically
+# - `enable` means enable scaling LR automatically
+# or not by default.
+# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
+auto_scale_lr = dict(enable=False, base_batch_size=16)
diff --git a/configs/mask2former/mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640.py b/configs/mask2former/mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640.py
new file mode 100644
index 000000000..56112dfa3
--- /dev/null
+++ b/configs/mask2former/mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640.py
@@ -0,0 +1,237 @@
+_base_ = [
+ '../_base_/default_runtime.py', '../_base_/datasets/ade20k_640x640.py'
+]
+
+pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth' # noqa
+custom_imports = dict(imports='mmdet.models', allow_failed_imports=False)
+
+crop_size = (640, 640)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255,
+ size=crop_size)
+num_classes = 150
+
+depths = [2, 2, 18, 2]
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='SwinTransformer',
+ pretrain_img_size=384,
+ embed_dims=128,
+ depths=depths,
+ num_heads=[4, 8, 16, 32],
+ window_size=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.3,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ with_cp=False,
+ frozen_stages=-1,
+ init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
+ decode_head=dict(
+ type='Mask2FormerHead',
+ in_channels=[128, 256, 512, 1024],
+ strides=[4, 8, 16, 32],
+ feat_channels=256,
+ out_channels=256,
+ num_classes=num_classes,
+ num_queries=100,
+ num_transformer_feat_level=3,
+ align_corners=False,
+ pixel_decoder=dict(
+ type='mmdet.MSDeformAttnPixelDecoder',
+ num_outs=3,
+ norm_cfg=dict(type='GN', num_groups=32),
+ act_cfg=dict(type='ReLU'),
+ encoder=dict(
+ type='mmdet.DetrTransformerEncoder',
+ num_layers=6,
+ transformerlayers=dict(
+ type='mmdet.BaseTransformerLayer',
+ attn_cfgs=dict(
+ type='mmdet.MultiScaleDeformableAttention',
+ embed_dims=256,
+ num_heads=8,
+ num_levels=3,
+ num_points=4,
+ im2col_step=64,
+ dropout=0.0,
+ batch_first=False,
+ norm_cfg=None,
+ init_cfg=None),
+ ffn_cfgs=dict(
+ type='FFN',
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.0,
+ act_cfg=dict(type='ReLU', inplace=True)),
+ operation_order=('self_attn', 'norm', 'ffn', 'norm')),
+ init_cfg=None),
+ positional_encoding=dict(
+ type='mmdet.SinePositionalEncoding',
+ num_feats=128,
+ normalize=True),
+ init_cfg=None),
+ enforce_decoder_input_project=False,
+ positional_encoding=dict(
+ type='mmdet.SinePositionalEncoding', num_feats=128,
+ normalize=True),
+ transformer_decoder=dict(
+ type='mmdet.DetrTransformerDecoder',
+ return_intermediate=True,
+ num_layers=9,
+ transformerlayers=dict(
+ type='mmdet.DetrTransformerDecoderLayer',
+ attn_cfgs=dict(
+ type='mmdet.MultiheadAttention',
+ embed_dims=256,
+ num_heads=8,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ dropout_layer=None,
+ batch_first=False),
+ ffn_cfgs=dict(
+ embed_dims=256,
+ feedforward_channels=2048,
+ num_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_drop=0.0,
+ dropout_layer=None,
+ add_identity=True),
+ feedforward_channels=2048,
+ operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
+ 'ffn', 'norm')),
+ init_cfg=None),
+ loss_cls=dict(
+ type='mmdet.CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=2.0,
+ reduction='mean',
+ class_weight=[1.0] * num_classes + [0.1]),
+ loss_mask=dict(
+ type='mmdet.CrossEntropyLoss',
+ use_sigmoid=True,
+ reduction='mean',
+ loss_weight=5.0),
+ loss_dice=dict(
+ type='mmdet.DiceLoss',
+ use_sigmoid=True,
+ activate=True,
+ reduction='mean',
+ naive_dice=True,
+ eps=1.0,
+ loss_weight=5.0),
+ train_cfg=dict(
+ num_points=12544,
+ oversample_ratio=3.0,
+ importance_sample_ratio=0.75,
+ assigner=dict(
+ type='mmdet.HungarianAssigner',
+ match_costs=[
+ dict(type='mmdet.ClassificationCost', weight=2.0),
+ dict(
+ type='mmdet.CrossEntropyLossCost',
+ weight=5.0,
+ use_sigmoid=True),
+ dict(
+ type='mmdet.DiceCost',
+ weight=5.0,
+ pred_act=True,
+ eps=1.0)
+ ]),
+ sampler=dict(type='mmdet.MaskPseudoSampler'))),
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
+
+# dataset config
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
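+    # multi-scale training: randomly resize the short edge to 0.5x-2.0x of 640
+    # (i.e. 320-1280 pixels), capping the long edge at 2560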
+ dict(
+ type='RandomChoiceResize',
+ scales=[int(x * 0.1 * 640) for x in range(5, 21)],
+ resize_type='ResizeShortestEdge',
+ max_size=2560),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(batch_size=2, dataset=dict(pipeline=train_pipeline))
+
+# set all layers in backbone to lr_mult=0.1
+# set all norm layers, position_embedding,
+# query_embedding, level_embedding to decay_mult=0.0
+backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
+backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
+embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
+custom_keys = {
+ 'backbone': dict(lr_mult=0.1, decay_mult=1.0),
+ 'backbone.patch_embed.norm': backbone_norm_multi,
+ 'backbone.norm': backbone_norm_multi,
+ 'absolute_pos_embed': backbone_embed_multi,
+ 'relative_position_bias_table': backbone_embed_multi,
+ 'query_embed': embed_multi,
+ 'query_feat': embed_multi,
+ 'level_embed': embed_multi
+}
+custom_keys.update({
+ f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
+ for stage_id, num_blocks in enumerate(depths)
+ for block_id in range(num_blocks)
+})
+custom_keys.update({
+ f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
+ for stage_id in range(len(depths) - 1)
+})
+# optimizer
+optimizer = dict(
+ type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=optimizer,
+ clip_grad=dict(max_norm=0.01, norm_type=2),
+ paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))
+
+# learning policy
+param_scheduler = [
+ dict(
+ type='PolyLR',
+ eta_min=0,
+ power=0.9,
+ begin=0,
+ end=160000,
+ by_epoch=False)
+]
+
+# training schedule for 160k
+train_cfg = dict(
+ type='IterBasedTrainLoop', max_iters=160000, val_interval=5000)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(
+ type='CheckpointHook', by_epoch=False, interval=5000,
+ save_best='mIoU'),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'))
+
+# Default setting for scaling LR automatically
+# - `enable` means enable scaling LR automatically
+# or not by default.
+# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
+auto_scale_lr = dict(enable=False, base_batch_size=16)
diff --git a/configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py b/configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py
new file mode 100644
index 000000000..f39a3c590
--- /dev/null
+++ b/configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py
@@ -0,0 +1,5 @@
+_base_ = ['./mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640.py']
+
+pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_22k_20220317-e5c09f74.pth' # noqa
+model = dict(
+ backbone=dict(init_cfg=dict(type='Pretrained', checkpoint=pretrained)))
diff --git a/configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py b/configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py
new file mode 100644
index 000000000..0c229c145
--- /dev/null
+++ b/configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py
@@ -0,0 +1,42 @@
+_base_ = ['./mask2former_swin-t_8xb2-90k_cityscapes-512x1024.py']
+pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_22k_20220317-e5c09f74.pth' # noqa
+
+depths = [2, 2, 18, 2]
+model = dict(
+ backbone=dict(
+ pretrain_img_size=384,
+ embed_dims=128,
+ depths=depths,
+ num_heads=[4, 8, 16, 32],
+ window_size=12,
+ init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
+ decode_head=dict(in_channels=[128, 256, 512, 1024]))
+
+# set all layers in backbone to lr_mult=0.1
+# set all norm layers, position_embedding,
+# query_embedding, level_embedding to decay_mult=0.0
+backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
+backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
+embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
+custom_keys = {
+ 'backbone': dict(lr_mult=0.1, decay_mult=1.0),
+ 'backbone.patch_embed.norm': backbone_norm_multi,
+ 'backbone.norm': backbone_norm_multi,
+ 'absolute_pos_embed': backbone_embed_multi,
+ 'relative_position_bias_table': backbone_embed_multi,
+ 'query_embed': embed_multi,
+ 'query_feat': embed_multi,
+ 'level_embed': embed_multi
+}
+custom_keys.update({
+ f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
+ for stage_id, num_blocks in enumerate(depths)
+ for block_id in range(num_blocks)
+})
+custom_keys.update({
+ f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
+ for stage_id in range(len(depths) - 1)
+})
+# optimizer
+optim_wrapper = dict(
+ paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))
diff --git a/configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py b/configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py
new file mode 100644
index 000000000..f2657e884
--- /dev/null
+++ b/configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py
@@ -0,0 +1,9 @@
+_base_ = ['./mask2former_swin-b-in1k-384x384-pre_8xb2-160k_ade20k-640x640.py']
+pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth' # noqa
+
+model = dict(
+ backbone=dict(
+ embed_dims=192,
+ num_heads=[6, 12, 24, 48],
+ init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
+ decode_head=dict(num_queries=100, in_channels=[192, 384, 768, 1536]))
diff --git a/configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py b/configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py
new file mode 100644
index 000000000..01a7b9988
--- /dev/null
+++ b/configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py
@@ -0,0 +1,42 @@
+_base_ = ['./mask2former_swin-t_8xb2-90k_cityscapes-512x1024.py']
+pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth' # noqa
+
+depths = [2, 2, 18, 2]
+model = dict(
+ backbone=dict(
+ pretrain_img_size=384,
+ embed_dims=192,
+ depths=depths,
+ num_heads=[6, 12, 24, 48],
+ window_size=12,
+ init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
+ decode_head=dict(in_channels=[192, 384, 768, 1536]))
+
+# set all layers in backbone to lr_mult=0.1
+# set all norm layers, position_embedding,
+# query_embedding, level_embedding to decay_mult=0.0
+backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
+backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
+embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
+custom_keys = {
+ 'backbone': dict(lr_mult=0.1, decay_mult=1.0),
+ 'backbone.patch_embed.norm': backbone_norm_multi,
+ 'backbone.norm': backbone_norm_multi,
+ 'absolute_pos_embed': backbone_embed_multi,
+ 'relative_position_bias_table': backbone_embed_multi,
+ 'query_embed': embed_multi,
+ 'query_feat': embed_multi,
+ 'level_embed': embed_multi
+}
+custom_keys.update({
+ f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
+ for stage_id, num_blocks in enumerate(depths)
+ for block_id in range(num_blocks)
+})
+custom_keys.update({
+ f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
+ for stage_id in range(len(depths) - 1)
+})
+# optimizer
+optim_wrapper = dict(
+ paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))
diff --git a/configs/mask2former/mask2former_swin-s_8xb2-160k_ade20k-512x512.py b/configs/mask2former/mask2former_swin-s_8xb2-160k_ade20k-512x512.py
new file mode 100644
index 000000000..a7796d569
--- /dev/null
+++ b/configs/mask2former/mask2former_swin-s_8xb2-160k_ade20k-512x512.py
@@ -0,0 +1,37 @@
+_base_ = ['./mask2former_swin-t_8xb2-160k_ade20k-512x512.py']
+pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth' # noqa
+
+depths = [2, 2, 18, 2]
+model = dict(
+ backbone=dict(
+ depths=depths, init_cfg=dict(type='Pretrained',
+ checkpoint=pretrained)))
+
+# set all layers in backbone to lr_mult=0.1
+# set all norm layers, position_embedding,
+# query_embedding, level_embedding to decay_mult=0.0
+backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
+backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
+embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
+custom_keys = {
+ 'backbone': dict(lr_mult=0.1, decay_mult=1.0),
+ 'backbone.patch_embed.norm': backbone_norm_multi,
+ 'backbone.norm': backbone_norm_multi,
+ 'absolute_pos_embed': backbone_embed_multi,
+ 'relative_position_bias_table': backbone_embed_multi,
+ 'query_embed': embed_multi,
+ 'query_feat': embed_multi,
+ 'level_embed': embed_multi
+}
+custom_keys.update({
+ f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
+ for stage_id, num_blocks in enumerate(depths)
+ for block_id in range(num_blocks)
+})
+custom_keys.update({
+ f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
+ for stage_id in range(len(depths) - 1)
+})
+# optimizer
+optim_wrapper = dict(
+ paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))
diff --git a/configs/mask2former/mask2former_swin-s_8xb2-90k_cityscapes-512x1024.py b/configs/mask2former/mask2former_swin-s_8xb2-90k_cityscapes-512x1024.py
new file mode 100644
index 000000000..5f75544b1
--- /dev/null
+++ b/configs/mask2former/mask2former_swin-s_8xb2-90k_cityscapes-512x1024.py
@@ -0,0 +1,37 @@
+_base_ = ['./mask2former_swin-t_8xb2-90k_cityscapes-512x1024.py']
+pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth' # noqa
+
+depths = [2, 2, 18, 2]
+model = dict(
+ backbone=dict(
+ depths=depths, init_cfg=dict(type='Pretrained',
+ checkpoint=pretrained)))
+
+# set all layers in backbone to lr_mult=0.1
+# set all norm layers, position_embedding,
+# query_embedding, level_embedding to decay_mult=0.0
+backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
+backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
+embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
+custom_keys = {
+ 'backbone': dict(lr_mult=0.1, decay_mult=1.0),
+ 'backbone.patch_embed.norm': backbone_norm_multi,
+ 'backbone.norm': backbone_norm_multi,
+ 'absolute_pos_embed': backbone_embed_multi,
+ 'relative_position_bias_table': backbone_embed_multi,
+ 'query_embed': embed_multi,
+ 'query_feat': embed_multi,
+ 'level_embed': embed_multi
+}
+custom_keys.update({
+ f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
+ for stage_id, num_blocks in enumerate(depths)
+ for block_id in range(num_blocks)
+})
+custom_keys.update({
+ f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
+ for stage_id in range(len(depths) - 1)
+})
+# optimizer
+optim_wrapper = dict(
+ paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))
diff --git a/configs/mask2former/mask2former_swin-t_8xb2-160k_ade20k-512x512.py b/configs/mask2former/mask2former_swin-t_8xb2-160k_ade20k-512x512.py
new file mode 100644
index 000000000..9de3d242e
--- /dev/null
+++ b/configs/mask2former/mask2former_swin-t_8xb2-160k_ade20k-512x512.py
@@ -0,0 +1,52 @@
+_base_ = ['./mask2former_r50_8xb2-160k_ade20k-512x512.py']
+pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth' # noqa
+depths = [2, 2, 6, 2]
+model = dict(
+ backbone=dict(
+ _delete_=True,
+ type='SwinTransformer',
+ embed_dims=96,
+ depths=depths,
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.3,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ with_cp=False,
+ frozen_stages=-1,
+ init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
+ decode_head=dict(in_channels=[96, 192, 384, 768]))
+
+# set all layers in backbone to lr_mult=0.1
+# set all norm layers, position_embedding,
+# query_embedding, level_embedding to decay_mult=0.0
+backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
+backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
+embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
+custom_keys = {
+ 'backbone': dict(lr_mult=0.1, decay_mult=1.0),
+ 'backbone.patch_embed.norm': backbone_norm_multi,
+ 'backbone.norm': backbone_norm_multi,
+ 'absolute_pos_embed': backbone_embed_multi,
+ 'relative_position_bias_table': backbone_embed_multi,
+ 'query_embed': embed_multi,
+ 'query_feat': embed_multi,
+ 'level_embed': embed_multi
+}
+custom_keys.update({
+ f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
+ for stage_id, num_blocks in enumerate(depths)
+ for block_id in range(num_blocks)
+})
+custom_keys.update({
+ f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
+ for stage_id in range(len(depths) - 1)
+})
+# optimizer
+optim_wrapper = dict(
+ paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))
diff --git a/configs/mask2former/mask2former_swin-t_8xb2-90k_cityscapes-512x1024.py b/configs/mask2former/mask2former_swin-t_8xb2-90k_cityscapes-512x1024.py
new file mode 100644
index 000000000..0abda6430
--- /dev/null
+++ b/configs/mask2former/mask2former_swin-t_8xb2-90k_cityscapes-512x1024.py
@@ -0,0 +1,52 @@
+_base_ = ['./mask2former_r50_8xb2-90k_cityscapes-512x1024.py']
+pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth' # noqa
+depths = [2, 2, 6, 2]
+model = dict(
+ backbone=dict(
+ _delete_=True,
+ type='SwinTransformer',
+ embed_dims=96,
+ depths=depths,
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.3,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ with_cp=False,
+ frozen_stages=-1,
+ init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
+ decode_head=dict(in_channels=[96, 192, 384, 768]))
+
+# set all layers in backbone to lr_mult=0.1
+# set all norm layers, position_embedding,
+# query_embedding, level_embedding to decay_mult=0.0
+backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
+backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
+embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
+custom_keys = {
+ 'backbone': dict(lr_mult=0.1, decay_mult=1.0),
+ 'backbone.patch_embed.norm': backbone_norm_multi,
+ 'backbone.norm': backbone_norm_multi,
+ 'absolute_pos_embed': backbone_embed_multi,
+ 'relative_position_bias_table': backbone_embed_multi,
+ 'query_embed': embed_multi,
+ 'query_feat': embed_multi,
+ 'level_embed': embed_multi
+}
+custom_keys.update({
+ f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
+ for stage_id, num_blocks in enumerate(depths)
+ for block_id in range(num_blocks)
+})
+custom_keys.update({
+ f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
+ for stage_id in range(len(depths) - 1)
+})
+# optimizer
+optim_wrapper = dict(
+ paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))
diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py
index c6976652d..b18152d7d 100644
--- a/mmseg/models/decode_heads/__init__.py
+++ b/mmseg/models/decode_heads/__init__.py
@@ -15,6 +15,7 @@ from .gc_head import GCHead
from .isa_head import ISAHead
from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
from .lraspp_head import LRASPPHead
+from .mask2former_head import Mask2FormerHead
from .maskformer_head import MaskFormerHead
from .nl_head import NLHead
from .ocr_head import OCRHead
@@ -37,5 +38,5 @@ __all__ = [
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
- 'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead'
+ 'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead'
]
diff --git a/mmseg/models/decode_heads/mask2former_head.py b/mmseg/models/decode_heads/mask2former_head.py
new file mode 100644
index 000000000..0ea742430
--- /dev/null
+++ b/mmseg/models/decode_heads/mask2former_head.py
@@ -0,0 +1,162 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+ from mmdet.models.dense_heads import \
+ Mask2FormerHead as MMDET_Mask2FormerHead
+except ModuleNotFoundError:
+ MMDET_Mask2FormerHead = None
+
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from mmseg.structures.seg_data_sample import SegDataSample
+from mmseg.utils import ConfigType, SampleList
+
+
+@MODELS.register_module()
+class Mask2FormerHead(MMDET_Mask2FormerHead):
+ """Implements the Mask2Former head.
+
+ See `Mask2Former: Masked-attention Mask Transformer for Universal Image
+    Segmentation <https://arxiv.org/abs/2112.01527>`_ for details.
+
+ Args:
+ num_classes (int): Number of classes. Default: 150.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ ignore_index (int): The label index to be ignored. Default: 255.
+ """
+
+ def __init__(self,
+ num_classes,
+ align_corners=False,
+ ignore_index=255,
+ **kwargs):
+ super().__init__(**kwargs)
+
+ self.num_classes = num_classes
+ self.align_corners = align_corners
+ self.out_channels = num_classes
+ self.ignore_index = ignore_index
+
+ feat_channels = kwargs['feat_channels']
+ self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
+
+ def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
+ """Perform forward propagation to convert paradigm from MMSegmentation
+ to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called
+ normally. Specifically, ``batch_gt_instances`` would be added.
+
+ Args:
+ batch_data_samples (List[:obj:`SegDataSample`]): The Data
+ Samples. It usually includes information such as
+ `gt_sem_seg`.
+
+ Returns:
+            tuple: A tuple containing two lists.
+
+                - batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                    gt_instances. Each usually includes ``labels``, the
+                    unique ground truth label ids of an image, with
+                    shape (num_gt, ), and ``masks``, the ground truth
+                    masks of each instance in an image, with
+                    shape (num_gt, h, w).
+                - batch_img_metas (list[dict]): List of image meta information.
+ """
+ batch_img_metas = []
+ batch_gt_instances = []
+
+ for data_sample in batch_data_samples:
+ batch_img_metas.append(data_sample.metainfo)
+ gt_sem_seg = data_sample.gt_sem_seg.data
+ classes = torch.unique(
+ gt_sem_seg,
+ sorted=False,
+ return_inverse=False,
+ return_counts=False)
+
+ # remove ignored region
+ gt_labels = classes[classes != self.ignore_index]
+
+ masks = []
+ for class_id in gt_labels:
+ masks.append(gt_sem_seg == class_id)
+
+ if len(masks) == 0:
+ gt_masks = torch.zeros(
+ (0, gt_sem_seg.shape[-2],
+ gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
+ else:
+ gt_masks = torch.stack(masks).squeeze(1).long()
+
+ instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
+ batch_gt_instances.append(instance_data)
+ return batch_gt_instances, batch_img_metas
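The conversion above reduces to `torch.unique` plus per-class equality masks. A stand-alone miniature of the same logic (plain tensors instead of `SegDataSample` / `InstanceData`; the label values are arbitrary) looks like this:

```python
# Stand-alone illustration of the semantic-to-instance conversion above.
import torch

ignore_index = 255
# (1, H, W) semantic map with three classes and an ignored region
gt_sem_seg = torch.tensor([[[0, 0, 1, 1],
                            [0, 0, 1, 1],
                            [2, 2, 255, 255],
                            [2, 2, 255, 255]]])

classes = torch.unique(gt_sem_seg)
gt_labels = classes[classes != ignore_index]       # the class ids present

masks = [gt_sem_seg == class_id for class_id in gt_labels]
gt_masks = torch.stack(masks).squeeze(1).long()    # (num_gt, H, W) binary masks

assert gt_masks.shape == (len(gt_labels), 4, 4)
assert gt_masks.sum(dim=(1, 2)).tolist() == [4, 4, 4]  # 4 pixels per class
```

Pixels labelled `ignore_index` do not produce a ground-truth mask of their own; only the remaining classes are matched against the queries.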
+
+ def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
+ train_cfg: ConfigType) -> dict:
+ """Perform forward propagation and loss calculation of the decoder head
+ on the features of the upstream network.
+
+ Args:
+ x (tuple[Tensor]): Multi-level features from the upstream
+ network, each is a 4D-tensor.
+ batch_data_samples (List[:obj:`SegDataSample`]): The Data
+ Samples. It usually includes information such as
+ `gt_sem_seg`.
+ train_cfg (ConfigType): Training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components.
+ """
+        # convert batch of SegDataSample to InstanceData
+ batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
+ batch_data_samples)
+
+ # forward
+ all_cls_scores, all_mask_preds = self(x, batch_data_samples)
+
+ # loss
+ losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
+ batch_gt_instances, batch_img_metas)
+
+ return losses
+
+    def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
+                test_cfg: ConfigType) -> Tensor:
+        """Test without augmentation.
+
+ Args:
+ x (tuple[Tensor]): Multi-level features from the
+ upstream network, each is a 4D-tensor.
+            batch_img_metas (List[dict]): List of image meta information,
+                e.g. `img_shape` and `pad_shape`.
+ test_cfg (ConfigType): Test config.
+
+ Returns:
+            Tensor: Segmentation logits of shape (batch_size, num_classes, H, W).
+ """
+ batch_data_samples = [
+ SegDataSample(metainfo=metainfo) for metainfo in batch_img_metas
+ ]
+
+ all_cls_scores, all_mask_preds = self(x, batch_data_samples)
+ mask_cls_results = all_cls_scores[-1]
+ mask_pred_results = all_mask_preds[-1]
+ if 'pad_shape' in batch_img_metas[0]:
+ size = batch_img_metas[0]['pad_shape']
+ else:
+ size = batch_img_metas[0]['img_shape']
+ # upsample mask
+ mask_pred_results = F.interpolate(
+ mask_pred_results, size=size, mode='bilinear', align_corners=False)
+ cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
+ mask_pred = mask_pred_results.sigmoid()
+ seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred)
+ return seg_logits
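The tail of `predict` fuses per-query class probabilities with per-query mask probabilities into dense per-class logits. The toy reproduction below (random data, arbitrary sizes) makes the shapes of that softmax / sigmoid / `einsum` combination explicit:

```python
# Toy reproduction of the class-score / mask fusion at the end of `predict`.
import torch
import torch.nn.functional as F

batch, num_queries, num_classes, h, w = 2, 100, 19, 32, 32

mask_cls_results = torch.randn(batch, num_queries, num_classes + 1)  # +1 "no object"
mask_pred_results = torch.randn(batch, num_queries, h, w)

cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]  # drop "no object"
mask_pred = mask_pred_results.sigmoid()

# every pixel's class logit is a sum over queries of
# P(class | query) * P(pixel belongs to query's mask)
seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
assert seg_logits.shape == (batch, num_classes, h, w)
```

Because the result is already shaped like an ordinary semantic-segmentation logit map, the rest of the MMSegmentation pipeline can post-process it like any other decode head's output.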
diff --git a/model-index.yml b/model-index.yml
index 6aacf72b0..ae96bd30f 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -25,6 +25,7 @@ Import:
- configs/isanet/isanet.yml
- configs/knet/knet.yml
- configs/mae/mae.yml
+- configs/mask2former/mask2former.yml
- configs/maskformer/maskformer.yml
- configs/mobilenet_v2/mobilenet_v2.yml
- configs/mobilenet_v3/mobilenet_v3.yml
diff --git a/tests/test_models/test_heads/test_mask2former_head.py b/tests/test_models/test_heads/test_mask2former_head.py
new file mode 100644
index 000000000..079e94ed9
--- /dev/null
+++ b/tests/test_models/test_heads/test_mask2former_head.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmengine import Config
+from mmengine.structures import PixelData
+
+from mmseg.models.decode_heads import Mask2FormerHead
+from mmseg.structures import SegDataSample
+from mmseg.utils import SampleList
+from .utils import to_cuda
+
+
+def test_mask2former_head():
+ num_classes = 19
+ cfg = dict(
+ in_channels=[96, 192, 384, 768],
+ strides=[4, 8, 16, 32],
+ feat_channels=256,
+ out_channels=256,
+ num_classes=num_classes,
+ num_queries=100,
+ num_transformer_feat_level=3,
+ align_corners=False,
+ pixel_decoder=dict(
+ type='mmdet.MSDeformAttnPixelDecoder',
+ num_outs=3,
+ norm_cfg=dict(type='GN', num_groups=32),
+ act_cfg=dict(type='ReLU'),
+ encoder=dict(
+ type='mmdet.DetrTransformerEncoder',
+ num_layers=6,
+ transformerlayers=dict(
+ type='mmdet.BaseTransformerLayer',
+ attn_cfgs=dict(
+ type='mmdet.MultiScaleDeformableAttention',
+ embed_dims=256,
+ num_heads=8,
+ num_levels=3,
+ num_points=4,
+ im2col_step=64,
+ dropout=0.0,
+ batch_first=False,
+ norm_cfg=None,
+ init_cfg=None),
+ ffn_cfgs=dict(
+ type='FFN',
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.0,
+ act_cfg=dict(type='ReLU', inplace=True)),
+ operation_order=('self_attn', 'norm', 'ffn', 'norm')),
+ init_cfg=None),
+ positional_encoding=dict(
+ type='mmdet.SinePositionalEncoding',
+ num_feats=128,
+ normalize=True),
+ init_cfg=None),
+ enforce_decoder_input_project=False,
+ positional_encoding=dict(
+ type='mmdet.SinePositionalEncoding', num_feats=128,
+ normalize=True),
+ transformer_decoder=dict(
+ type='mmdet.DetrTransformerDecoder',
+ return_intermediate=True,
+ num_layers=9,
+ transformerlayers=dict(
+ type='mmdet.DetrTransformerDecoderLayer',
+ attn_cfgs=dict(
+ type='mmdet.MultiheadAttention',
+ embed_dims=256,
+ num_heads=8,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ dropout_layer=None,
+ batch_first=False),
+ ffn_cfgs=dict(
+ embed_dims=256,
+ feedforward_channels=2048,
+ num_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_drop=0.0,
+ dropout_layer=None,
+ add_identity=True),
+ feedforward_channels=2048,
+ operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
+ 'ffn', 'norm')),
+ init_cfg=None),
+ loss_cls=dict(
+ type='mmdet.CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=2.0,
+ reduction='mean',
+ class_weight=[1.0] * num_classes + [0.1]),
+ loss_mask=dict(
+ type='mmdet.CrossEntropyLoss',
+ use_sigmoid=True,
+ reduction='mean',
+ loss_weight=5.0),
+ loss_dice=dict(
+ type='mmdet.DiceLoss',
+ use_sigmoid=True,
+ activate=True,
+ reduction='mean',
+ naive_dice=True,
+ eps=1.0,
+ loss_weight=5.0),
+ train_cfg=dict(
+ num_points=12544,
+ oversample_ratio=3.0,
+ importance_sample_ratio=0.75,
+ assigner=dict(
+ type='mmdet.HungarianAssigner',
+ match_costs=[
+ dict(type='mmdet.ClassificationCost', weight=2.0),
+ dict(
+ type='mmdet.CrossEntropyLossCost',
+ weight=5.0,
+ use_sigmoid=True),
+ dict(
+ type='mmdet.DiceCost',
+ weight=5.0,
+ pred_act=True,
+ eps=1.0)
+ ]),
+ sampler=dict(type='mmdet.MaskPseudoSampler')))
+ cfg = Config(cfg)
+ head = Mask2FormerHead(**cfg)
+
+ inputs = [
+ torch.rand((2, 96, 8, 8)),
+ torch.rand((2, 192, 4, 4)),
+ torch.rand((2, 384, 2, 2)),
+ torch.rand((2, 768, 1, 1))
+ ]
+
+ data_samples: SampleList = []
+ for i in range(2):
+ data_sample = SegDataSample()
+ img_meta = {}
+ img_meta['img_shape'] = (32, 32)
+ img_meta['ori_shape'] = (32, 32)
+ data_sample.gt_sem_seg = PixelData(
+ data=torch.randint(0, num_classes, (1, 32, 32)))
+ data_sample.set_metainfo(img_meta)
+ data_samples.append(data_sample)
+
+ if torch.cuda.is_available():
+ head, inputs = to_cuda(head, inputs)
+ for data_sample in data_samples:
+ data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.data.cuda()
+
+ loss_dict = head.loss(inputs, data_samples, None)
+ assert isinstance(loss_dict, dict)
+
+ batch_img_metas = []
+ for data_sample in data_samples:
+ batch_img_metas.append(data_sample.metainfo)
+
+ seg_logits = head.predict(inputs, batch_img_metas, None)
+ assert seg_logits.shape == torch.Size((2, num_classes, 32, 32))