Merge branch 'master' of https://github.com/open-mmlab/mmsegmentation into open-mmlab-master

# Conflicts:
#   docs/model_zoo.md
#   mmseg/models/backbones/__init__.py
#   tests/test_models/test_backbone.py

commit 3cbfbf6434
@@ -47,6 +47,9 @@ jobs:
       - torch: 1.5.0+cu101
         torchvision: 0.6.0+cu101
         python-version: 3.7
+      - torch: 1.6.0+cu101
+        torchvision: 0.7.0+cu101
+        python-version: 3.7
     steps:
       - uses: actions/checkout@v2
.gitignore
@@ -103,7 +103,6 @@ venv.bak/
 # mypy
 .mypy_cache/
 
-mmseg/version.py
 data
 .vscode
 .idea
README.md
@@ -44,7 +44,8 @@ This project is released under the [Apache 2.0 license](LICENSE).
 
 ## Changelog
 
-v0.5.0 was released in 10/7/2020.
+v0.5.1 was released in 11/08/2020.
+Please refer to [changelog.md](docs/changelog.md) for details and release history.
 
 ## Benchmark and model zoo
 
@@ -53,7 +54,8 @@ Results and models are available in the [model zoo](docs/model_zoo.md).
 Supported backbones:
 - [x] ResNet
 - [x] ResNeXt
-- [x] HRNet
+- [x] [HRNet](configs/hrnet/README.md)
+- [x] [ResNeSt](configs/resnest/README.md)
 
 Supported methods:
 - [x] [FCN](configs/fcn)
configs/_base_/models/ocrnet_r50-d8.py (new file)
@@ -0,0 +1,47 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='CascadeEncoderDecoder',
    num_stages=2,
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=[
        dict(
            type='FCNHead',
            in_channels=1024,
            in_index=2,
            channels=256,
            num_convs=1,
            concat_input=False,
            drop_out_ratio=0.1,
            num_classes=19,
            norm_cfg=norm_cfg,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
        dict(
            type='OCRHead',
            in_channels=2048,
            in_index=3,
            channels=512,
            ocr_channels=256,
            drop_out_ratio=0.1,
            num_classes=19,
            norm_cfg=norm_cfg,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
    ])
# model training and testing settings
train_cfg = dict()
test_cfg = dict(mode='whole')
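To see how a cascade config like this is consumed, here is a minimal sketch (not part of the commit; it assumes the v0.5 `build_segmentor` call used later in this diff by `tools/pytorch2onnx.py`):

```python
# Minimal sketch: build the cascade OCRNet segmentor from the base config.
import mmcv
from mmseg.models import build_segmentor

cfg = mmcv.Config.fromfile('configs/_base_/models/ocrnet_r50-d8.py')
# num_stages=2: the FCNHead runs first and its coarse prediction is fed to
# the OCRHead as soft object regions (loss weights 0.4 and 1.0 above).
model = build_segmentor(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
```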
configs/ocrnet/README.md
@@ -1,18 +1,28 @@
 # Object-Contextual Representations for Semantic Segmentation
 
 ## Introduction
 
 ```
-@article{yuan2019ocr,
+@article{YuanW18,
+  title={Ocnet: Object context network for scene parsing},
+  author={Yuhui Yuan and Jingdong Wang},
+  booktitle={arXiv preprint arXiv:1809.00916},
+  year={2018}
+}
+
+@article{YuanCW20,
   title={Object-Contextual Representations for Semantic Segmentation},
-  author={Yuan Yuhui and Chen Xilin and Wang Jingdong},
-  journal={arXiv preprint arXiv:1909.11065},
-  year={2019}
+  author={Yuhui Yuan and Xilin Chen and Jingdong Wang},
+  booktitle={ECCV},
+  year={2020}
 }
 ```
 
 ## Results and models
 
 ### Cityscapes
 
+#### HRNet backbone
 | Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
 |--------|--------------------|-----------|--------:|----------|----------------|------:|--------------:|----------|
 | OCRNet | HRNetV2p-W18-Small | 512x1024 | 40000 | 3.5 | 10.45 | 74.30 | 75.95 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/ocrnet/ocrnet_hr18s_512x1024_40k_cityscapes/ocrnet_hr18s_512x1024_40k_cityscapes_20200601_033304-fa2436c2.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/ocrnet/ocrnet_hr18s_512x1024_40k_cityscapes/ocrnet_hr18s_512x1024_40k_cityscapes_20200601_033304.log.json) |
@@ -25,6 +35,16 @@
 | OCRNet | HRNetV2p-W18 | 512x1024 | 160000 | - | - | 79.47 | 80.91 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/ocrnet/ocrnet_hr18_512x1024_160k_cityscapes/ocrnet_hr18_512x1024_160k_cityscapes_20200602_191001-b9172d0c.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/ocrnet/ocrnet_hr18_512x1024_160k_cityscapes/ocrnet_hr18_512x1024_160k_cityscapes_20200602_191001.log.json) |
 | OCRNet | HRNetV2p-W48 | 512x1024 | 160000 | - | - | 81.35 | 82.70 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/ocrnet/ocrnet_hr48_512x1024_160k_cityscapes/ocrnet_hr48_512x1024_160k_cityscapes_20200602_191037-dfbf1b0c.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/ocrnet/ocrnet_hr48_512x1024_160k_cityscapes/ocrnet_hr48_512x1024_160k_cityscapes_20200602_191037.log.json) |
 
+#### ResNet backbone
+
+| Method | Backbone | Crop Size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
+|--------|----------|-----------|-----------:|--------:|----------|----------------|------:|--------------:|----------|
+| OCRNet | R-101-D8 | 512x1024 | 8 | 40000 | - | - | 80.09 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/ocrnet/ocrnet_r101-d8_512x1024_40k_b8_cityscapes/ocrnet_r101-d8_512x1024_40k_b8_cityscapes-02ac0f13.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/ocrnet/ocrnet_r101-d8_512x1024_40k_b8_cityscapes/ocrnet_r101-d8_512x1024_40k_b8_cityscapes_20200717_110721.log.json) |
+| OCRNet | R-101-D8 | 512x1024 | 16 | 40000 | 8.8 | 3.02 | 80.30 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/ocrnet/ocrnet_r101-d8_512x1024_40k_b16_cityscapes/ocrnet_r101-d8_512x1024_40k_b16_cityscapes-db500f80.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/ocrnet/ocrnet_r101-d8_512x1024_40k_b16_cityscapes/ocrnet_r101-d8_512x1024_40k_b16_cityscapes_20200723_193726.log.json) |
+| OCRNet | R-101-D8 | 512x1024 | 16 | 80000 | 8.8 | 3.02 | 80.81 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/ocrnet/ocrnet_r101-d8_512x1024_80k_b16_cityscapes/ocrnet_r101-d8_512x1024_80k_b16_cityscapes-78688424.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/ocrnet/ocrnet_r101-d8_512x1024_80k_b16_cityscapes/ocrnet_r101-d8_512x1024_80k_b16_cityscapes_20200723_192421.log.json) |
+
 ### ADE20K
 | Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
 |--------|--------------------|-----------|--------:|----------|----------------|------:|--------------:|----------|
New OCRNet R-101-D8 config (40k schedule, scaled learning rate):
@@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/ocrnet_r50-d8.py',
    '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40k.py'
]
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
optimizer = dict(lr=0.02)
lr_config = dict(min_lr=2e-4)
New OCRNet R-101-D8 config (40k schedule, default learning rate):
@@ -0,0 +1,7 @@
_base_ = [
    '../_base_/models/ocrnet_r50-d8.py',
    '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40k.py'
]
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
New OCRNet R-101-D8 config (80k schedule, scaled learning rate):
@@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/ocrnet_r50-d8.py',
    '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py'
]
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
optimizer = dict(lr=0.02)
lr_config = dict(min_lr=2e-4)
configs/resnest/README.md (new file)
@@ -0,0 +1,30 @@
# ResNeSt: Split-Attention Networks

## Introduction

```
@article{zhang2020resnest,
  title={ResNeSt: Split-Attention Networks},
  author={Zhang, Hang and Wu, Chongruo and Zhang, Zhongyue and Zhu, Yi and Zhang, Zhi and Lin, Haibin and Sun, Yue and He, Tong and Muller, Jonas and Manmatha, R. and Li, Mu and Smola, Alexander},
  journal={arXiv preprint arXiv:2004.08955},
  year={2020}
}
```

## Results and models

### Cityscapes
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|------------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------|
| FCN | S-101-D8 | 512x1024 | 80000 | 11.4 | 2.39 | 77.56 | 78.98 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/fcn_s101-d8_512x1024_80k_cityscapes/fcn_s101-d8_512x1024_80k_cityscapes_20200807_140631-f8d155b3.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/fcn_s101-d8_512x1024_80k_cityscapes/fcn_s101-d8_512x1024_80k_cityscapes-20200807_140631.log.json) |
| PSPNet | S-101-D8 | 512x1024 | 80000 | 11.8 | 2.52 | 78.57 | 79.19 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/pspnet_s101-d8_512x1024_80k_cityscapes/pspnet_s101-d8_512x1024_80k_cityscapes_20200807_140631-c75f3b99.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/pspnet_s101-d8_512x1024_80k_cityscapes/pspnet_s101-d8_512x1024_80k_cityscapes-20200807_140631.log.json) |
| DeepLabV3 | S-101-D8 | 512x1024 | 80000 | 11.9 | 1.88 | 79.67 | 80.51 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3_s101-d8_512x1024_80k_cityscapes/deeplabv3_s101-d8_512x1024_80k_cityscapes_20200807_144429-b73c4270.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3_s101-d8_512x1024_80k_cityscapes/deeplabv3_s101-d8_512x1024_80k_cityscapes-20200807_144429.log.json) |
| DeepLabV3+ | S-101-D8 | 512x1024 | 80000 | 13.2 | 2.36 | 79.62 | 80.27 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x1024_80k_cityscapes/deeplabv3plus_s101-d8_512x1024_80k_cityscapes_20200807_144429-1239eb43.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x1024_80k_cityscapes/deeplabv3plus_s101-d8_512x1024_80k_cityscapes-20200807_144429.log.json) |

### ADE20k
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|------------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------|
| FCN | S-101-D8 | 512x512 | 160000 | 14.2 | 12.86 | 45.62 | 46.16 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/fcn_s101-d8_512x512_160k_ade20k/fcn_s101-d8_512x512_160k_ade20k_20200807_145416-d3160329.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/fcn_s101-d8_512x512_160k_ade20k/fcn_s101-d8_512x512_160k_ade20k-20200807_145416.log.json) |
| PSPNet | S-101-D8 | 512x512 | 160000 | 14.2 | 13.02 | 45.44 | 46.28 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/pspnet_s101-d8_512x512_160k_ade20k/pspnet_s101-d8_512x512_160k_ade20k_20200807_145416-a6daa92a.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/pspnet_s101-d8_512x512_160k_ade20k/pspnet_s101-d8_512x512_160k_ade20k-20200807_145416.log.json) |
| DeepLabV3 | S-101-D8 | 512x512 | 160000 | 14.6 | 9.28 | 45.71 | 46.59 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3_s101-d8_512x512_160k_ade20k/deeplabv3_s101-d8_512x512_160k_ade20k_20200807_144503-17ecabe5.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3_s101-d8_512x512_160k_ade20k/deeplabv3_s101-d8_512x512_160k_ade20k-20200807_144503.log.json) |
| DeepLabV3+ | S-101-D8 | 512x512 | 160000 | 16.2 | 11.96 | 46.47 | 47.27 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x512_160k_ade20k/deeplabv3plus_s101-d8_512x512_160k_ade20k_20200807_144503-27b26226.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/resnest/deeplabv3plus_s101-d8_512x512_160k_ade20k/deeplabv3plus_s101-d8_512x512_160k_ade20k-20200807_144503.log.json) |
New ResNeSt config:
@@ -0,0 +1,9 @@
_base_ = '../deeplabv3/deeplabv3_r101-d8_512x1024_80k_cityscapes.py'
model = dict(
    pretrained='open-mmlab://resnest101',
    backbone=dict(
        type='ResNeSt',
        stem_channels=128,
        radix=2,
        reduction_factor=4,
        avg_down_stride=True))
New ResNeSt config:
@@ -0,0 +1,9 @@
_base_ = '../deeplabv3/deeplabv3_r101-d8_512x512_160k_ade20k.py'
model = dict(
    pretrained='open-mmlab://resnest101',
    backbone=dict(
        type='ResNeSt',
        stem_channels=128,
        radix=2,
        reduction_factor=4,
        avg_down_stride=True))
New ResNeSt config:
@@ -0,0 +1,9 @@
_base_ = '../deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py'
model = dict(
    pretrained='open-mmlab://resnest101',
    backbone=dict(
        type='ResNeSt',
        stem_channels=128,
        radix=2,
        reduction_factor=4,
        avg_down_stride=True))
New ResNeSt config:
@@ -0,0 +1,9 @@
_base_ = '../deeplabv3plus/deeplabv3plus_r101-d8_512x512_160k_ade20k.py'
model = dict(
    pretrained='open-mmlab://resnest101',
    backbone=dict(
        type='ResNeSt',
        stem_channels=128,
        radix=2,
        reduction_factor=4,
        avg_down_stride=True))
New ResNeSt config:
@@ -0,0 +1,9 @@
_base_ = '../fcn/fcn_r101-d8_512x1024_80k_cityscapes.py'
model = dict(
    pretrained='open-mmlab://resnest101',
    backbone=dict(
        type='ResNeSt',
        stem_channels=128,
        radix=2,
        reduction_factor=4,
        avg_down_stride=True))
New ResNeSt config:
@@ -0,0 +1,9 @@
_base_ = '../fcn/fcn_r101-d8_512x512_160k_ade20k.py'
model = dict(
    pretrained='open-mmlab://resnest101',
    backbone=dict(
        type='ResNeSt',
        stem_channels=128,
        radix=2,
        reduction_factor=4,
        avg_down_stride=True))
New ResNeSt config:
@@ -0,0 +1,9 @@
_base_ = '../pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py'
model = dict(
    pretrained='open-mmlab://resnest101',
    backbone=dict(
        type='ResNeSt',
        stem_channels=128,
        radix=2,
        reduction_factor=4,
        avg_down_stride=True))
New ResNeSt config:
@@ -0,0 +1,9 @@
_base_ = '../pspnet/pspnet_r101-d8_512x512_160k_ade20k.py'
model = dict(
    pretrained='open-mmlab://resnest101',
    backbone=dict(
        type='ResNeSt',
        stem_channels=128,
        radix=2,
        reduction_factor=4,
        avg_down_stride=True))
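All eight of these configs rely on mmcv's `_base_` inheritance: nested dicts are merged key by key, so each file only restates what differs from its DeepLabV3/FCN/PSPNet base (the pretrained weights and the backbone type plus its split-attention parameters). A rough sketch of that merge behavior, simplified from mmcv and not part of the commit:

```python
# Simplified sketch of how a child config overrides its _base_ config:
# nested dicts merge recursively, so base keys such as depth=101 survive
# while the child swaps the backbone type to ResNeSt and adds new fields.
base = dict(model=dict(pretrained='open-mmlab://resnet101_v1c',
                       backbone=dict(type='ResNetV1c', depth=101)))
child = dict(model=dict(pretrained='open-mmlab://resnest101',
                        backbone=dict(type='ResNeSt', stem_channels=128,
                                      radix=2, reduction_factor=4,
                                      avg_down_stride=True)))

def merge(a, b):
    """Recursively override dict a with dict b (simplified mmcv behavior)."""
    out = dict(a)
    for k, v in b.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = merge(out[k], v)
        else:
            out[k] = v
    return out

merged = merge(base, child)
print(merged['model']['backbone']['type'])   # ResNeSt
print(merged['model']['backbone']['depth'])  # 101, inherited from the base
```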
docs/changelog.md (new file)
@@ -0,0 +1,15 @@
## Changelog

### v0.5.1 (11/08/2020)

**Highlights**
- Support FP16 and more generalized OHEM

**Bug Fixes**
- Fixed Pascal VOC conversion script (#19)
- Fixed OHEM weight assign bug (#54)
- Fixed palette type when palette is not given (#27)

**New Features**
- Support FP16 (#21)
- Generalized OHEM (#54)

**Improvements**
- Add load-from flag (#33)
- Fixed training tricks doc about different learning rates of model (#26)
docs/conf.py
@@ -20,10 +20,17 @@ sys.path.insert(0, os.path.abspath('..'))
 project = 'MMSegmentation'
 copyright = '2020-2020, OpenMMLab'
 author = 'MMSegmentation Authors'
+version_file = '../mmseg/version.py'
+
+
+def get_version():
+    with open(version_file, 'r') as f:
+        exec(compile(f.read(), version_file, 'exec'))
+    return locals()['__version__']
+
+
 # The full version, including alpha/beta/rc tags
-with open('../mmseg/VERSION', 'r') as f:
-    release = f.read().strip()
+release = get_version()
 
 # -- General configuration ---------------------------------------------------
docs/getting_started.md
@@ -332,3 +332,18 @@ python tools/publish_model.py work_dirs/pspnet/latest.pth psp_r50_hszhao_200ep.p
 ```
 
 The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pth`.
+
+### Convert to ONNX (experimental)
+
+We provide a script to convert the model to [ONNX](https://github.com/onnx/onnx) format. The converted model can be visualized by tools like [Netron](https://github.com/lutzroeder/netron). We also support comparing the output results between the PyTorch and ONNX models.
+
+```shell
+python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output_file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
+```
+
+**Note**: This tool is still experimental. Some customized operators are not supported for now.
+
+## Tutorials
+
+Currently, we provide four tutorials for users: [add new dataset](tutorials/new_dataset.md), [design data pipeline](tutorials/data_pipeline.md), [add new modules](tutorials/new_modules.md), and [use training tricks](tutorials/training_tricks.md).
+We also provide a full description of the [config system](config.md).
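As a companion to the docs above, a minimal sketch of loading the exported file back with onnxruntime (the file name and input shape are illustrative, matching the script's defaults, and are not part of the commit):

```python
# Minimal sketch: run the exported ONNX model on a dummy input.
import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession('tmp.onnx')       # default output of pytorch2onnx.py
input_name = sess.get_inputs()[0].name
dummy = np.random.rand(1, 3, 256, 256).astype(np.float32)  # matches --shape 256 256
outputs = sess.run(None, {input_name: dummy})
print(outputs[0].shape)  # segmentation prediction for the dummy image
```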
docs/install.md
@@ -54,10 +54,9 @@ pip install -e .  # or "python setup.py develop"
 
 Note:
 
-1. In `dev` mode, the git commit id will be written to the version number with step *d*, e.g. 0.5.0+c415a2e. The version will also be saved in trained models.
-   It is recommended that you run step *d* each time you pull some updates from github. If C++/CUDA codes are modified, then this step is compulsory.
+1. The `version+git_hash` will also be saved in trained models meta, e.g. 0.5.0+c415a2e.
 
-2. When MMsegmentation is installed in `dev` mode, any local modifications made to the code will take effect without the need to reinstall it (unless you submit some commits and want to update the version number).
+2. When MMsegmentation is installed in `dev` mode, any local modifications made to the code will take effect without the need to reinstall it.
 
 3. If you would like to use `opencv-python-headless` instead of `opencv-python`,
    you can install it before installing MMCV.
docs/model_zoo.json (2,724-line diff suppressed because it is too large)
docs/model_zoo.md
@@ -85,6 +85,10 @@ Please refer to [OCRNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/ocrnet) for details.
 
 Please refer to [Fast-SCNN](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fastscnn) for details.
 
+### ResNeSt
+
+Please refer to [ResNeSt](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/resnest) for details.
+
 ### Mixed Precision (FP16) Training
 
 Please refer to [Mixed Precision (FP16) Training](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fp16/README.md) for details.
mmseg/VERSION (deleted file)
@@ -1 +0,0 @@
-0.5.0
|
@ -1,3 +1,30 @@
|
||||||
from .version import __version__, short_version, version_info
|
import mmcv
|
||||||
|
|
||||||
__all__ = ['__version__', 'short_version', 'version_info']
|
from .version import __version__, version_info
|
||||||
|
|
||||||
|
MMCV_MIN = '1.0.5'
|
||||||
|
MMCV_MAX = '1.0.5'
|
||||||
|
|
||||||
|
|
||||||
|
def digit_version(version_str):
|
||||||
|
digit_version = []
|
||||||
|
for x in version_str.split('.'):
|
||||||
|
if x.isdigit():
|
||||||
|
digit_version.append(int(x))
|
||||||
|
elif x.find('rc') != -1:
|
||||||
|
patch_version = x.split('rc')
|
||||||
|
digit_version.append(int(patch_version[0]) - 1)
|
||||||
|
digit_version.append(int(patch_version[1]))
|
||||||
|
return digit_version
|
||||||
|
|
||||||
|
|
||||||
|
mmcv_min_version = digit_version(MMCV_MIN)
|
||||||
|
mmcv_max_version = digit_version(MMCV_MAX)
|
||||||
|
mmcv_version = digit_version(mmcv.__version__)
|
||||||
|
|
||||||
|
|
||||||
|
assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \
|
||||||
|
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
|
||||||
|
f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.'
|
||||||
|
|
||||||
|
__all__ = ['__version__', 'version_info']
|
||||||
|
|
|
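For intuition, a small worked example of the comparison (not part of the diff; assumes the `digit_version` above is in scope):

```python
# Release candidates sort below the final release because the patch
# component is decremented before the rc number is appended.
assert digit_version('1.0.5') == [1, 0, 5]
assert digit_version('1.0rc1') == [1, -1, 1]   # 0 - 1 = -1, then the rc number
assert digit_version('1.0rc1') < digit_version('1.0.5')
```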
mmseg/models/backbones/__init__.py
@@ -1,6 +1,10 @@
 from .fast_scnn import FastSCNN
 from .hrnet import HRNet
+from .resnest import ResNeSt
 from .resnet import ResNet, ResNetV1c, ResNetV1d
 from .resnext import ResNeXt
 
-__all__ = ['ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN']
+__all__ = [
+    'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
+    'ResNeSt'
+]
mmseg/models/backbones/resnest.py (new file)
@@ -0,0 +1,314 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer

from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNetV1d


class RSoftmax(nn.Module):
    """Radix Softmax module in ``SplitAttentionConv2d``.

    Args:
        radix (int): Radix of input.
        groups (int): Groups of input.
    """

    def __init__(self, radix, groups):
        super().__init__()
        self.radix = radix
        self.groups = groups

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x


class SplitAttentionConv2d(nn.Module):
    """Split-Attention Conv2d in ResNeSt.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int | tuple[int]): Same as nn.Conv2d.
        stride (int | tuple[int]): Same as nn.Conv2d.
        padding (int | tuple[int]): Same as nn.Conv2d.
        dilation (int | tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels. Default: 4.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        dcn (dict): Config dict for DCN. Default: None.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 radix=2,
                 reduction_factor=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None):
        super(SplitAttentionConv2d, self).__init__()
        inter_channels = max(in_channels * radix // reduction_factor, 32)
        self.radix = radix
        self.groups = groups
        self.channels = channels
        self.with_dcn = dcn is not None
        self.dcn = dcn
        fallback_on_stride = False
        if self.with_dcn:
            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
        if self.with_dcn and not fallback_on_stride:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            conv_cfg = dcn
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            channels * radix,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups * radix,
            bias=False)
        self.norm0_name, norm0 = build_norm_layer(
            norm_cfg, channels * radix, postfix=0)
        self.add_module(self.norm0_name, norm0)
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = build_conv_layer(
            None, channels, inter_channels, 1, groups=self.groups)
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, inter_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.fc2 = build_conv_layer(
            None, inter_channels, channels * radix, 1, groups=self.groups)
        self.rsoftmax = RSoftmax(radix, groups)

    @property
    def norm0(self):
        """nn.Module: the normalization layer named "norm0" """
        return getattr(self, self.norm0_name)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm0(x)
        x = self.relu(x)

        batch, rchannel = x.shape[:2]
        batch = x.size(0)
        if self.radix > 1:
            splits = x.view(batch, self.radix, -1, *x.shape[2:])
            gap = splits.sum(dim=1)
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        gap = self.norm1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
            out = torch.sum(attens * splits, dim=1)
        else:
            out = atten * x
        return out.contiguous()


class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeSt.

    Args:
        inplane (int): Input planes of this block.
        planes (int): Middle planes of this block.
        groups (int): Groups of conv2.
        width_per_group (int): Width per group of conv2. 64x4d indicates
            ``groups=64, width_per_group=4`` and 32x8d indicates
            ``groups=32, width_per_group=8``.
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Key word arguments for base class.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        """Bottleneck block for ResNeSt."""
        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)

        if groups == 1:
            width = self.planes
        else:
            width = math.floor(self.planes *
                               (base_width / base_channels)) * groups

        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width, postfix=1)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.with_modulated_dcn = False
        self.conv2 = SplitAttentionConv2d(
            width,
            width,
            kernel_size=3,
            stride=1 if self.avg_down_stride else self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            radix=radix,
            reduction_factor=reduction_factor,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dcn=self.dcn)
        delattr(self, self.norm2_name)

        if self.avg_down_stride:
            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)

        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width,
            self.planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)

            out = self.conv2(out)

            if self.avg_down_stride:
                out = self.avd_layer(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


@BACKBONES.register_module()
class ResNeSt(ResNetV1d):
    """ResNeSt backbone.

    Args:
        groups (int): Number of groups of Bottleneck. Default: 1
        base_width (int): Base width of Bottleneck. Default: 4
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Keyword arguments for ResNet.
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
        200: (Bottleneck, (3, 24, 36, 3))
    }

    def __init__(self,
                 groups=1,
                 base_width=4,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        self.groups = groups
        self.base_width = base_width
        self.radix = radix
        self.reduction_factor = reduction_factor
        self.avg_down_stride = avg_down_stride
        super(ResNeSt, self).__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        return ResLayer(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels,
            radix=self.radix,
            reduction_factor=self.reduction_factor,
            avg_down_stride=self.avg_down_stride,
            **kwargs)
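A minimal smoke test of the new backbone, mirroring the unit test added later in this commit (the feature shapes assume depth=50 and a 224x224 input; not part of the diff):

```python
import torch
from mmseg.models.backbones import ResNeSt

model = ResNeSt(depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3))
model.init_weights()
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # strides 4/8/16/32 -> spatial sizes 56, 28, 14, 7
```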
mmseg/models/segmentors (EncoderDecoder)
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -171,6 +172,8 @@ class EncoderDecoder(BaseSegmentor):
         h_stride, w_stride = self.test_cfg.stride
         h_crop, w_crop = self.test_cfg.crop_size
         batch_size, _, h_img, w_img = img.size()
+        assert h_crop <= h_img and w_crop <= w_img, (
+            'crop size should not be greater than image size')
         num_classes = self.num_classes
         h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
         w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
@@ -185,14 +188,15 @@ class EncoderDecoder(BaseSegmentor):
                 y1 = max(y2 - h_crop, 0)
                 x1 = max(x2 - w_crop, 0)
                 crop_img = img[:, :, y1:y2, x1:x2]
-                pad_img = crop_img.new_zeros(
-                    (crop_img.size(0), crop_img.size(1), h_crop, w_crop))
-                pad_img[:, :, :y2 - y1, :x2 - x1] = crop_img
-                pad_seg_logit = self.encode_decode(pad_img, img_meta)
-                preds[:, :, y1:y2,
-                      x1:x2] += pad_seg_logit[:, :, :y2 - y1, :x2 - x1]
+                crop_seg_logit = self.encode_decode(crop_img, img_meta)
+                preds += F.pad(crop_seg_logit,
+                               (int(x1), int(preds.shape[3] - x2), int(y1),
+                                int(preds.shape[2] - y2)))
                 count_mat[:, :, y1:y2, x1:x2] += 1
         assert (count_mat == 0).sum() == 0
+        # We want to regard count_mat as a constant while exporting to ONNX
+        count_mat = torch.from_numpy(count_mat.detach().numpy())
         preds = preds / count_mat
         if rescale:
             preds = resize(
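The rewritten slide inference accumulates each crop's logits into the full-size `preds` tensor via `F.pad` instead of slice assignment, which keeps the graph ONNX-exportable. A small sketch of the padding arithmetic (not part of the diff):

```python
# Sketch: F.pad places a (y1:y2, x1:x2) crop back onto an (H, W) canvas.
# The pad tuple is (left, right, top, bottom) for the last two dims.
import torch
import torch.nn.functional as F

H, W, y1, y2, x1, x2 = 8, 8, 2, 6, 1, 5
crop = torch.ones(1, 3, y2 - y1, x2 - x1)
canvas = F.pad(crop, (x1, W - x2, y1, H - y2))
assert canvas.shape[-2:] == (H, W)
# The crop values land exactly in the target window; everything else is zero.
assert canvas[:, :, y1:y2, x1:x2].eq(1).all() and canvas.sum() == crop.sum()
```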
@@ -201,7 +205,6 @@ class EncoderDecoder(BaseSegmentor):
                 mode='bilinear',
                 align_corners=self.align_corners,
                 warning=False)
-
         return preds
 
     def whole_inference(self, img, img_meta, rescale):
@@ -243,8 +246,8 @@ class EncoderDecoder(BaseSegmentor):
             seg_logit = self.whole_inference(img, img_meta, rescale)
         output = F.softmax(seg_logit, dim=1)
         flip = img_meta[0]['flip']
-        flip_direction = img_meta[0]['flip_direction']
         if flip:
+            flip_direction = img_meta[0]['flip_direction']
             assert flip_direction in ['horizontal', 'vertical']
             if flip_direction == 'horizontal':
                 output = output.flip(dims=(3, ))
@@ -257,6 +260,8 @@ class EncoderDecoder(BaseSegmentor):
         """Simple test with single image."""
         seg_logit = self.inference(img, img_meta, rescale)
         seg_pred = seg_logit.argmax(dim=1)
+        if torch.onnx.is_in_onnx_export():
+            return seg_pred
         seg_pred = seg_pred.cpu().numpy()
         # unravel batch dim
         seg_pred = list(seg_pred)
mmseg/models/utils (ResLayer)
@@ -42,8 +42,7 @@ class ResLayer(nn.Sequential):
         if stride != 1 or inplanes != planes * block.expansion:
             downsample = []
             conv_stride = stride
-            # check dilation for dilated ResNet
-            if avg_down and (stride != 1 or dilation != 1):
+            if avg_down:
                 conv_stride = 1
                 downsample.append(
                     nn.AvgPool2d(
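The `avg_down` branch builds a ResNet-D style shortcut: an average pool absorbs the stride so the 1x1 projection conv can run at stride 1. A standalone sketch of what the simplified condition now produces (pooling kwargs are an assumption based on the common ResNet-D recipe; the real code uses mmcv's conv/norm builders):

```python
import torch.nn as nn

def make_downsample(inplanes, planes, expansion, stride, avg_down):
    """Sketch of ResLayer's shortcut branch after this change."""
    layers = []
    conv_stride = stride
    if avg_down:
        conv_stride = 1  # the AvgPool2d handles the spatial stride instead
        layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
                                   ceil_mode=True, count_include_pad=False))
    layers += [
        nn.Conv2d(inplanes, planes * expansion, kernel_size=1,
                  stride=conv_stride, bias=False),
        nn.BatchNorm2d(planes * expansion),
    ]
    return nn.Sequential(*layers)
```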
mmseg/utils (collect_env)
@@ -7,7 +7,7 @@ import cv2
 import mmcv
 import torch
 import torchvision
-from mmcv.utils.parrots_wrapper import get_build_config
+from mmcv.utils import get_build_config, get_git_hash
 
 import mmseg
 
@@ -53,7 +53,7 @@ def collect_env():
     env_info['OpenCV'] = cv2.__version__
 
     env_info['MMCV'] = mmcv.__version__
-    env_info['MMSegmentation'] = mmseg.__version__
+    env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
     try:
         from mmcv.ops import get_compiler_version, get_compiling_cuda_version
         env_info['MMCV Compiler'] = get_compiler_version()
mmseg/version.py (new file)
@@ -0,0 +1,18 @@
# Copyright (c) Open-MMLab. All rights reserved.

__version__ = '0.5.1'


def parse_version_info(version_str):
    version_info = []
    for x in version_str.split('.'):
        if x.isdigit():
            version_info.append(int(x))
        elif x.find('rc') != -1:
            patch_version = x.split('rc')
            version_info.append(int(patch_version[0]))
            version_info.append(f'rc{patch_version[1]}')
    return tuple(version_info)


version_info = parse_version_info(__version__)
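Two worked examples of the parser (not part of the diff):

```python
from mmseg.version import parse_version_info

assert parse_version_info('0.5.1') == (0, 5, 1)
assert parse_version_info('0.6rc1') == (0, 6, 'rc1')  # rc kept as a string tag
```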
setup.cfg (isort section)
@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,matplotlib,mmcv,numpy,pytablewriter,pytest,scipy,torch,torchvision
+known_third_party = PIL,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnxruntime,pytablewriter,pytest,scipy,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
setup.py
@@ -1,81 +1,19 @@
-import os
-import subprocess
-import time
 from setuptools import find_packages, setup
 
 
+def readme():
+    with open('README.md', encoding='utf-8') as f:
+        content = f.read()
+    return content
+
+
 version_file = 'mmseg/version.py'
 
 
-def get_git_hash():
-
-    def _minimal_ext_cmd(cmd):
-        # construct minimal environment
-        env = {}
-        for k in ['SYSTEMROOT', 'PATH', 'HOME']:
-            v = os.environ.get(k)
-            if v is not None:
-                env[k] = v
-        # LANGUAGE is used on win32
-        env['LANGUAGE'] = 'C'
-        env['LANG'] = 'C'
-        env['LC_ALL'] = 'C'
-        out = subprocess.Popen(
-            cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
-        return out
-
-    try:
-        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
-        sha = out.strip().decode('ascii')
-    except OSError:
-        sha = 'unknown'
-
-    return sha
-
-
-def get_hash():
-    if os.path.exists('.git'):
-        sha = get_git_hash()[:7]
-    elif os.path.exists(version_file):
-        try:
-            from mmseg.version import __version__
-            sha = __version__.split('+')[-1]
-        except ImportError:
-            raise ImportError('Unable to get git version')
-    else:
-        sha = 'unknown'
-
-    return sha
-
-
-def write_version_py():
-    content = """# GENERATED VERSION FILE
-# TIME: {}
-
-__version__ = '{}'
-short_version = '{}'
-version_info = ({})
-"""
-    sha = get_hash()
-    with open('mmseg/VERSION', 'r') as f:
-        SHORT_VERSION = f.read().strip()
-    VERSION_INFO = ', '.join(SHORT_VERSION.split('.'))
-    VERSION = SHORT_VERSION + '+' + sha
-
-    version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION,
-                                      VERSION_INFO)
-    with open(version_file, 'w') as f:
-        f.write(version_file_str)
-
-
 def get_version():
     with open(version_file, 'r') as f:
         exec(compile(f.read(), version_file, 'exec'))
-    import sys
-    # return short version for sdist
-    if 'sdist' in sys.argv or 'bdist_wheel' in sys.argv:
-        return locals()['short_version']
-    else:
-        return locals()['__version__']
+    return locals()['__version__']
 
 
 def parse_requirements(fname='requirements.txt', with_version=True):
 
@@ -155,11 +93,12 @@ def parse_requirements(fname='requirements.txt', with_version=True):
 
 
 if __name__ == '__main__':
-    write_version_py()
     setup(
         name='mmsegmentation',
         version=get_version(),
         description='Open MMLab Semantic Segmentation Toolbox and Benchmark',
+        long_description=readme(),
+        long_description_content_type='text/markdown',
         author='MMSegmentation Authors',
         author_email='openmmlab@gmail.com',
         keywords='computer vision, semantic segmentation',
tests/test_models/test_backbone.py
@@ -4,7 +4,9 @@ from mmcv.ops import DeformConv2dPack
 from mmcv.utils.parrots_wrapper import _BatchNorm
 from torch.nn.modules import AvgPool2d, GroupNorm
 
-from mmseg.models.backbones import FastSCNN, ResNet, ResNetV1d, ResNeXt
+from mmseg.models.backbones import (FastSCNN, ResNeSt, ResNet, ResNetV1d,
+                                    ResNeXt)
+from mmseg.models.backbones.resnest import Bottleneck as BottleneckS
 from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
 from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
 from mmseg.models.utils import ResLayer
@@ -689,3 +691,41 @@ def test_fastscnn_backbone():
     assert feat[1].shape == torch.Size([batch_size, 128, 16, 32])
     # FFM output
     assert feat[2].shape == torch.Size([batch_size, 128, 64, 128])
+
+
+def test_resnest_bottleneck():
+    with pytest.raises(AssertionError):
+        # Style must be in ['pytorch', 'caffe']
+        BottleneckS(64, 64, radix=2, reduction_factor=4, style='tensorflow')
+
+    # Test ResNeSt Bottleneck structure
+    block = BottleneckS(
+        64, 256, radix=2, reduction_factor=4, stride=2, style='pytorch')
+    assert block.avd_layer.stride == 2
+    assert block.conv2.channels == 256
+
+    # Test ResNeSt Bottleneck forward
+    block = BottleneckS(64, 16, radix=2, reduction_factor=4)
+    x = torch.randn(2, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([2, 64, 56, 56])
+
+
+def test_resnest_backbone():
+    with pytest.raises(KeyError):
+        # ResNeSt depth should be in [50, 101, 152, 200]
+        ResNeSt(depth=18)
+
+    # Test ResNeSt with radix 2, reduction_factor 4
+    model = ResNeSt(
+        depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(2, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([2, 256, 56, 56])
+    assert feat[1].shape == torch.Size([2, 512, 28, 28])
+    assert feat[2].shape == torch.Size([2, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([2, 2048, 7, 7])
tools/pytorch2onnx.py (new file)
@@ -0,0 +1,198 @@
import argparse
from functools import partial

import mmcv
import numpy as np
import onnxruntime as rt
import torch
import torch._C
import torch.serialization
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint

from mmseg.models import build_segmentor

torch.manual_seed(3)


def _convert_batchnorm(module):
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple):
            input batch dimensions
        num_classes (int):
            number of semantic classes
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    segs = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
    } for _ in range(N)]
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export a PyTorch model to ONNX and optionally verify that the outputs
    of the two models match.

    Args:
        model (nn.Module): PyTorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (string): The path to where we store the output ONNX
            model. Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs between PyTorch and
            ONNX. Default: False.
    """
    model.cpu().eval()

    num_classes = model.decode_head.num_classes

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]

    # replace original forward function
    origin_forward = model.forward
    model.forward = partial(
        model.forward, img_metas=img_meta_list, return_loss=False)

    register_extra_symbolics(opset_version)
    with torch.no_grad():
        torch.onnx.export(
            model, (img_list, ),
            output_file,
            export_params=True,
            keep_initializers_as_inputs=True,
            verbose=show,
            opset_version=opset_version)
    print(f'Successfully exported ONNX model: {output_file}')
    model.forward = origin_forward

    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # check the numerical value
        # get pytorch output
        pytorch_result = model(img_list, img_meta_list, return_loss=False)[0]

        # get onnx output
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        sess = rt.InferenceSession(output_file)
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
        if not np.allclose(pytorch_result, onnx_result):
            raise ValueError(
                'The outputs are different between Pytorch and ONNX')
        print('The outputs are the same between Pytorch and ONNX')


def parse_args():
    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument('--show', action='store_true', help='show onnx graph')
    parser.add_argument(
        '--verify', action='store_true', help='verify the onnx model')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[256, 256],
        help='input image size')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (
            1,
            3,
        ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # build the model and load checkpoint
    segmentor = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    # convert SyncBN to BN
    segmentor = _convert_batchnorm(segmentor)

    num_classes = segmentor.decode_head.num_classes

    if args.checkpoint:
        checkpoint = load_checkpoint(
            segmentor, args.checkpoint, map_location='cpu')

    # convert the model to an onnx file
    pytorch2onnx(
        segmentor,
        input_shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify)
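For use outside the CLI, the same functions can be driven from Python, mirroring the `__main__` block above (a sketch, assuming this script's functions are importable and using a hypothetical config path):

```python
import mmcv
from mmseg.models import build_segmentor
# Hypothetical import path; in practice run the script or put tools/ on sys.path.
from pytorch2onnx import _convert_batchnorm, pytorch2onnx

cfg = mmcv.Config.fromfile('configs/ocrnet/ocrnet_r101-d8_512x1024_40k_b16_cityscapes.py')
cfg.model.pretrained = None
segmentor = build_segmentor(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
segmentor = _convert_batchnorm(segmentor)  # SyncBN cannot run on CPU during export
pytorch2onnx(segmentor, input_shape=(1, 3, 256, 256),
             output_file='ocrnet.onnx', verify=True)
```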
tools (train script)
@@ -7,7 +7,7 @@ import time
 import mmcv
 import torch
 from mmcv.runner import init_dist
-from mmcv.utils import Config, DictAction
+from mmcv.utils import Config, DictAction, get_git_hash
 
 from mmseg import __version__
 from mmseg.apis import set_random_seed, train_segmentor
@@ -141,7 +141,7 @@ def main():
         # save mmseg version, config file content and class names in
         # checkpoints as meta data
         cfg.checkpoint_config.meta = dict(
-            mmseg_version=__version__,
+            mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
             config=cfg.pretty_text,
             CLASSES=datasets[0].CLASSES,
             PALETTE=datasets[0].PALETTE)
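With this change the git hash now lives only in the checkpoint meta rather than in the package version. A sketch of inspecting it after training (the checkpoint path is hypothetical):

```python
# mmcv checkpoints store the meta dict set via cfg.checkpoint_config.meta.
import torch

ckpt = torch.load('work_dirs/ocrnet/latest.pth', map_location='cpu')
print(ckpt['meta']['mmseg_version'])  # e.g. '0.5.1+3cbfbf6'
print(ckpt['meta']['CLASSES'][:3])    # first few dataset class names
```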