[Feature] Support ISPRS Potsdam Dataset. (#1097)

* add isprs potsdam dataset

* add isprs dataset configs

* fix lint error

* fix potsdam conversion bug

* fix error in potsdam class

* fix error in potsdam class

* add vaihingen dataset

* add vaihingen dataset

* add vaihingen dataset

* fix some description errors.

* fix some description errors.

* fix some description errors.

* upload models & logs of Potsdam

* remove vaihingen and add unit test

* add chinese readme

* add pseudodataset

* use mmcv and add class_names

* use f-string

* add new dataset unittest

* add docstring and remove global variables args

* fix metafile error in PSPNet

* fix pretrained value

* Add dataset info

* fix typo

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
pull/1218/head
Kingdrone 2022-01-18 14:15:15 +08:00 committed by GitHub
parent 0f48c7605d
commit b997a13e28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 616 additions and 3 deletions

View File

@ -0,0 +1,54 @@
# dataset settings
dataset_type = 'PotsdamDataset'
data_root = 'data/potsdam'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(512, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/train',
ann_dir='ann_dir/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline))

View File

@ -104,6 +104,14 @@ Spatial pyramid pooling module or encode-decoder structure are used in deep neur
| DeepLabV3+ | R-50-D8 | 512x512 | 80000 | 7.37 | 6.00 | 50.99 | 50.65 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_loveda/deeplabv3plus_r50-d8_512x512_80k_loveda_20211105_080442-f0720392.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_loveda/deeplabv3plus_r50-d8_512x512_80k_loveda_20211105_080442.log.json) |
| DeepLabV3+ | R-101-D8 | 512x512 | 80000 | 10.84 | 4.33 | 51.47 | 51.32 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_loveda/deeplabv3plus_r101-d8_512x512_80k_loveda_20211105_110759-4c1f297e.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_loveda/deeplabv3plus_r101-d8_512x512_80k_loveda_20211105_110759.log.json) |
### Potsdam
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| DeepLabV3+ | R-18-D8 | 512x512 | 80000 | 1.91 | 81.68 | 77.09 | 78.44 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam/deeplabv3plus_r18-d8_512x512_80k_potsdam_20211219_020601-75fd5bc3.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam/deeplabv3plus_r18-d8_512x512_80k_potsdam_20211219_020601.log.json) |
| DeepLabV3+ | R-50-D8 | 512x512 | 80000 | 7.36 | 26.44 | 78.33 | 79.27 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam/deeplabv3plus_r50-d8_512x512_80k_potsdam_20211219_031508-7e7a2b24.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam/deeplabv3plus_r50-d8_512x512_80k_potsdam_20211219_031508.log.json) |
| DeepLabV3+ | R-101-D8 | 512x512 | 80000 | 10.83 | 17.56 | 78.7 | 79.47 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam/deeplabv3plus_r101-d8_512x512_80k_potsdam_20211219_031508-8b112708.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam/deeplabv3plus_r101-d8_512x512_80k_potsdam_20211219_031508.log.json) |
Note:
- `FP16` means Mixed Precision (FP16) is adopted in training.

View File

@ -8,6 +8,7 @@ Collections:
- Pascal Context
- Pascal Context 59
- LoveDA
- Potsdam
Paper:
URL: https://arxiv.org/abs/1802.02611
Title: Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation
@ -669,3 +670,69 @@ Models:
mIoU(ms+flip): 51.32
Config: configs/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_loveda.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_loveda/deeplabv3plus_r101-d8_512x512_80k_loveda_20211105_110759-4c1f297e.pth
- Name: deeplabv3plus_r18-d8_512x512_80k_potsdam
In Collection: deeplabv3plus
Metadata:
backbone: R-18-D8
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 12.24
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 1.91
Results:
- Task: Semantic Segmentation
Dataset: Potsdam
Metrics:
mIoU: 77.09
mIoU(ms+flip): 78.44
Config: configs/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam/deeplabv3plus_r18-d8_512x512_80k_potsdam_20211219_020601-75fd5bc3.pth
- Name: deeplabv3plus_r50-d8_512x512_80k_potsdam
In Collection: deeplabv3plus
Metadata:
backbone: R-50-D8
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 37.82
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 7.36
Results:
- Task: Semantic Segmentation
Dataset: Potsdam
Metrics:
mIoU: 78.33
mIoU(ms+flip): 79.27
Config: configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam/deeplabv3plus_r50-d8_512x512_80k_potsdam_20211219_031508-7e7a2b24.pth
- Name: deeplabv3plus_r101-d8_512x512_80k_potsdam
In Collection: deeplabv3plus
Metadata:
backbone: R-101-D8
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 56.95
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 10.83
Results:
- Task: Semantic Segmentation
Dataset: Potsdam
Metrics:
mIoU: 78.7
mIoU(ms+flip): 79.47
Config: configs/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam/deeplabv3plus_r101-d8_512x512_80k_potsdam_20211219_031508-8b112708.pth

View File

@ -0,0 +1,2 @@
_base_ = './deeplabv3plus_r50-d8_512x512_80k_potsdam.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

View File

@ -0,0 +1,11 @@
_base_ = './deeplabv3plus_r50-d8_512x512_80k_potsdam.py'
model = dict(
pretrained='open-mmlab://resnet18_v1c',
backbone=dict(depth=18),
decode_head=dict(
c1_in_channels=64,
c1_channels=12,
in_channels=512,
channels=128,
),
auxiliary_head=dict(in_channels=256, channels=64))

View File

@ -0,0 +1,7 @@
_base_ = [
'../_base_/models/deeplabv3plus_r50-d8.py',
'../_base_/datasets/potsdam.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_80k.py'
]
model = dict(
decode_head=dict(num_classes=6), auxiliary_head=dict(num_classes=6))

View File

@ -92,3 +92,11 @@ High-resolution representations are essential for position-sensitive vision prob
| FCN | HRNetV2p-W18-Small | 512x512 | 80000 | 1.59 | 24.87 | 49.28 | 49.42 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18s_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_512x512_80k_loveda/fcn_hr18s_512x512_80k_loveda_20211210_203228-60a86a7a.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_512x512_80k_loveda/fcn_hr18s_512x512_80k_loveda_20211210_203228.log.json) |
| FCN | HRNetV2p-W18 | 512x512 | 80000 | 2.76 | 12.92 | 50.81 | 50.95 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_512x512_80k_loveda/fcn_hr18_512x512_80k_loveda_20211210_203952-93d9c3b3.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_512x512_80k_loveda/fcn_hr18_512x512_80k_loveda_20211210_203952.log.json) |
| FCN | HRNetV2p-W48 | 512x512 | 80000 | 6.20 | 9.61 | 51.42 | 51.64 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr48_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_512x512_80k_loveda/fcn_hr48_512x512_80k_loveda_20211211_044756-67072f55.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_512x512_80k_loveda/fcn_hr48_512x512_80k_loveda_20211211_044756.log.json) |
### Potsdam
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| FCN | HRNetV2p-W18-Small | 512x512 | 80000 | 1.58 | 36.00 | 77.64 | 78.8 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18s_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_512x512_80k_potsdam/fcn_hr18s_512x512_80k_potsdam_20211218_205517-ba32af63.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_512x512_80k_potsdam/fcn_hr18s_512x512_80k_potsdam_20211218_205517.log.json) |
| FCN | HRNetV2p-W18 | 512x512 | 80000 | 2.76 | 19.25 | 78.26 | 79.24 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_512x512_80k_potsdam/fcn_hr18_512x512_80k_potsdam_20211218_205517-5d0387ad.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_512x512_80k_potsdam/fcn_hr18_512x512_80k_potsdam_20211218_205517.log.json) |
| FCN | HRNetV2p-W48 | 512x512 | 80000 | 6.20 | 16.42 | 78.39 | 79.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr48_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_512x512_80k_potsdam/fcn_hr48_512x512_80k_potsdam_20211219_020601-97434c78.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_512x512_80k_potsdam/fcn_hr48_512x512_80k_potsdam_20211219_020601.log.json) |

View File

@ -0,0 +1,5 @@
_base_ = [
'../_base_/models/fcn_hr18.py', '../_base_/datasets/potsdam.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
model = dict(decode_head=dict(num_classes=6))

View File

@ -0,0 +1,9 @@
_base_ = './fcn_hr18_512x512_80k_potsdam.py'
model = dict(
pretrained='open-mmlab://msra/hrnetv2_w18_small',
backbone=dict(
extra=dict(
stage1=dict(num_blocks=(2, )),
stage2=dict(num_blocks=(2, 2)),
stage3=dict(num_modules=3, num_blocks=(2, 2, 2)),
stage4=dict(num_modules=2, num_blocks=(2, 2, 2, 2)))))

View File

@ -0,0 +1,10 @@
_base_ = './fcn_hr18_512x512_80k_potsdam.py'
model = dict(
pretrained='open-mmlab://msra/hrnetv2_w48',
backbone=dict(
extra=dict(
stage2=dict(num_channels=(48, 96)),
stage3=dict(num_channels=(48, 96, 192)),
stage4=dict(num_channels=(48, 96, 192, 384)))),
decode_head=dict(
in_channels=[48, 96, 192, 384], channels=sum([48, 96, 192, 384])))

View File

@ -8,6 +8,7 @@ Collections:
- Pascal Context
- Pascal Context 59
- LoveDA
- Potsdam
Paper:
URL: https://arxiv.org/abs/1908.07919
Title: Deep High-Resolution Representation Learning for Human Pose Estimation
@ -514,3 +515,69 @@ Models:
mIoU(ms+flip): 51.64
Config: configs/hrnet/fcn_hr48_512x512_80k_loveda.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_512x512_80k_loveda/fcn_hr48_512x512_80k_loveda_20211211_044756-67072f55.pth
- Name: fcn_hr18s_512x512_80k_potsdam
In Collection: hrnet
Metadata:
backbone: HRNetV2p-W18-Small
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 27.78
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 1.58
Results:
- Task: Semantic Segmentation
Dataset: Potsdam
Metrics:
mIoU: 77.64
mIoU(ms+flip): 78.8
Config: configs/hrnet/fcn_hr18s_512x512_80k_potsdam.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_512x512_80k_potsdam/fcn_hr18s_512x512_80k_potsdam_20211218_205517-ba32af63.pth
- Name: fcn_hr18_512x512_80k_potsdam
In Collection: hrnet
Metadata:
backbone: HRNetV2p-W18
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 51.95
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 2.76
Results:
- Task: Semantic Segmentation
Dataset: Potsdam
Metrics:
mIoU: 78.26
mIoU(ms+flip): 79.24
Config: configs/hrnet/fcn_hr18_512x512_80k_potsdam.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_512x512_80k_potsdam/fcn_hr18_512x512_80k_potsdam_20211218_205517-5d0387ad.pth
- Name: fcn_hr48_512x512_80k_potsdam
In Collection: hrnet
Metadata:
backbone: HRNetV2p-W48
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 60.9
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 6.2
Results:
- Task: Semantic Segmentation
Dataset: Potsdam
Metrics:
mIoU: 78.39
mIoU(ms+flip): 79.34
Config: configs/hrnet/fcn_hr48_512x512_80k_potsdam.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_512x512_80k_potsdam/fcn_hr48_512x512_80k_potsdam_20211219_020601-97434c78.pth

View File

@ -133,6 +133,14 @@ We support evaluation results on these two datasets using models above trained o
| PSPNet | R-50-D8 | 512x512 | 80000 | 6.14 | 6.60 | 50.46 | 50.19 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r50-d8_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x512_80k_loveda/pspnet_r50-d8_512x512_80k_loveda_20211104_155728-88610f9f.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x512_80k_loveda/pspnet_r50-d8_512x512_80k_loveda_20211104_155728.log.json) |
| PSPNet | R-101-D8 | 512x512 | 80000 | 9.61 | 4.58 | 51.86 | 51.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r101-d8_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_512x512_80k_loveda/pspnet_r101-d8_512x512_80k_loveda_20211104_153212-1c06c6a8.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_512x512_80k_loveda/pspnet_r101-d8_512x512_80k_loveda_20211104_153212.log.json) |
### Potsdam
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| PSPNet | R-18-D8 | 512x512 | 80000 | 1.50 | 85.12 | 77.09 | 78.30 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam/pspnet_r18-d8_4x4_512x512_80k_potsdam_20211220_125612-7cd046e1.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam/pspnet_r18-d8_4x4_512x512_80k_potsdam_20211220_125612.log.json) |
| PSPNet | R-50-D8 | 512x512 | 80000 | 6.14 | 30.21 | 78.12 | 78.98 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam/pspnet_r50-d8_4x4_512x512_80k_potsdam_20211219_043541-2dd5fe67.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam/pspnet_r50-d8_4x4_512x512_80k_potsdam_20211219_043541.log.json) |
| PSPNet | R-101-D8 | 512x512 | 80000 | 9.61 | 19.40 | 78.62 | 79.47 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam/pspnet_r101-d8_4x4_512x512_80k_potsdam_20211220_125612-aed036c4.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam/pspnet_r101-d8_4x4_512x512_80k_potsdam_20211220_125612.log.json) |
Note:
- `FP16` means Mixed Precision (FP16) is adopted in training.

View File

@ -11,6 +11,7 @@ Collections:
- COCO-Stuff 10k
- COCO-Stuff 164k
- LoveDA
- Potsdam
Paper:
URL: https://arxiv.org/abs/1612.01105
Title: Pyramid Scene Parsing Network
@ -808,3 +809,69 @@ Models:
mIoU(ms+flip): 51.34
Config: configs/pspnet/pspnet_r101-d8_512x512_80k_loveda.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_512x512_80k_loveda/pspnet_r101-d8_512x512_80k_loveda_20211104_153212-1c06c6a8.pth
- Name: pspnet_r18-d8_4x4_512x512_80k_potsdam
In Collection: pspnet
Metadata:
backbone: R-18-D8
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 11.75
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 1.5
Results:
- Task: Semantic Segmentation
Dataset: Potsdam
Metrics:
mIoU: 77.09
mIoU(ms+flip): 78.3
Config: configs/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam/pspnet_r18-d8_4x4_512x512_80k_potsdam_20211220_125612-7cd046e1.pth
- Name: pspnet_r50-d8_4x4_512x512_80k_potsdam
In Collection: pspnet
Metadata:
backbone: R-50-D8
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 33.1
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 6.14
Results:
- Task: Semantic Segmentation
Dataset: Potsdam
Metrics:
mIoU: 78.12
mIoU(ms+flip): 78.98
Config: configs/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam/pspnet_r50-d8_4x4_512x512_80k_potsdam_20211219_043541-2dd5fe67.pth
- Name: pspnet_r101-d8_4x4_512x512_80k_potsdam
In Collection: pspnet
Metadata:
backbone: R-101-D8
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 51.55
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 9.61
Results:
- Task: Semantic Segmentation
Dataset: Potsdam
Metrics:
mIoU: 78.62
mIoU(ms+flip): 79.47
Config: configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam/pspnet_r101-d8_4x4_512x512_80k_potsdam_20211220_125612-aed036c4.pth

View File

@ -0,0 +1,2 @@
_base_ = './pspnet_r50-d8_4x4_512x512_80k_potsdam.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

View File

@ -0,0 +1,9 @@
_base_ = './pspnet_r50-d8_4x4_512x512_80k_potsdam.py'
model = dict(
pretrained='open-mmlab://resnet18_v1c',
backbone=dict(depth=18),
decode_head=dict(
in_channels=512,
channels=128,
),
auxiliary_head=dict(in_channels=256, channels=64))

View File

@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/pspnet_r50-d8.py', '../_base_/datasets/potsdam.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
model = dict(
decode_head=dict(num_classes=6), auxiliary_head=dict(num_classes=6))

View File

@ -116,6 +116,13 @@ mmsegmentation
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
│ ├── potsdam
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ ├── val
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
```
### Cityscapes
@ -286,3 +293,19 @@ python tools/convert_datasets/loveda.py /path/to/loveDA
Using trained model to predict test set of LoveDA and submit it to server can be found [here](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/inference.md).
More details about LoveDA can be found [here](https://github.com/Junjue-Wang/LoveDA).
### ISPRS Potsdam
The [Potsdam](https://www2.isprs.org/commissions/comm2/wg4/benchmark/2d-sem-label-potsdam/)
dataset is for urban semantic segmentation used in the 2D Semantic Labeling Contest - Potsdam.
The dataset can be requested at the challenge [homepage](https://www2.isprs.org/commissions/comm2/wg4/benchmark/data-request-form/).
The '2_Ortho_RGB.zip' and '5_Labels_all_noBoundary.zip' are required.
For Potsdam dataset, please run the following command to download and re-organize the dataset.
```shell
python tools/convert_datasets/potsdam.py /path/to/potsdam
```
In our default setting, it will generate 3456 images for training and 2016 images for validation.

View File

@ -97,6 +97,13 @@ mmsegmentation
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
│ ├── potsdam
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ ├── val
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
```
### Cityscapes
@ -228,3 +235,18 @@ python tools/convert_datasets/loveda.py /path/to/loveDA
请参照 [这里](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/inference.md) 来使用训练好的模型去预测 LoveDA 测试集并且提交到官网。
关于 LoveDA 的更多细节可以在[这里](https://github.com/Junjue-Wang/LoveDA) 找到。
### ISPRS Potsdam
[Potsdam](https://www2.isprs.org/commissions/comm2/wg4/benchmark/2d-sem-label-potsdam/)
数据集是一个有着2D 语义分割内容标注的城市遥感数据集。
数据集可以从挑战[主页](https://www2.isprs.org/commissions/comm2/wg4/benchmark/data-request-form/) 获得。
需要其中的 '2_Ortho_RGB.zip' 和 '5_Labels_all_noBoundary.zip'。
对于 Potsdam 数据集,请运行以下命令下载并重新组织数据集
```shell
python tools/convert_datasets/potsdam.py /path/to/potsdam
```
使用我们默认的配置, 将生成 3456 张图片的训练集和 2016 张图片的验证集。

View File

@ -52,6 +52,22 @@ def voc_classes():
]
def loveda_classes():
"""LoveDA class names for external use."""
return [
'background', 'building', 'road', 'water', 'barren', 'forest',
'agricultural'
]
def potsdam_classes():
"""Potsdam class names for external use."""
return [
'impervious_surface', 'building', 'low_vegetation', 'tree', 'car',
'clutter'
]
def cityscapes_palette():
"""Cityscapes palette for external use."""
return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
@ -112,10 +128,24 @@ def voc_palette():
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
def loveda_palette():
"""LoveDA palette for external use."""
return [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
[159, 129, 183], [0, 255, 0], [255, 195, 128]]
def potsdam_palette():
"""Potsdam palette for external use."""
return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
[255, 255, 0], [255, 0, 0]]
dataset_aliases = {
'cityscapes': ['cityscapes'],
'ade': ['ade', 'ade20k'],
'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'],
'loveda': ['loveda'],
'potsdam': ['potsdam']
}

View File

@ -13,6 +13,7 @@ from .hrf import HRFDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .voc import PascalVOCDataset
@ -22,5 +23,6 @@ __all__ = [
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset'
'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
'PotsdamDataset'
]

View File

@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class PotsdamDataset(CustomDataset):
"""ISPRS Potsdam dataset.
In segmentation map annotation for Potsdam dataset, 0 is the ignore index.
``reduce_zero_label`` should be set to True. The ``img_suffix`` and
``seg_map_suffix`` are both fixed to '.png'.
"""
CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree',
'car', 'clutter')
PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
[255, 255, 0], [255, 0, 0]]
def __init__(self, **kwargs):
super(PotsdamDataset, self).__init__(
img_suffix='.png',
seg_map_suffix='.png',
reduce_zero_label=True,
**kwargs)

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 377 KiB

View File

@ -15,7 +15,7 @@ from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
ConcatDataset, CustomDataset, LoveDADataset,
MultiImageMixDataset, PascalVOCDataset,
RepeatDataset, build_dataset)
PotsdamDataset, RepeatDataset, build_dataset)
def test_classes():
@ -24,6 +24,8 @@ def test_classes():
'pascal_voc')
assert list(
ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k')
assert list(LoveDADataset.CLASSES) == get_classes('loveda')
assert list(PotsdamDataset.CLASSES) == get_classes('potsdam')
with pytest.raises(ValueError):
get_classes('unsupported')
@ -65,6 +67,8 @@ def test_palette():
assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(
'pascal_voc')
assert ADE20KDataset.PALETTE == get_palette('ade') == get_palette('ade20k')
assert LoveDADataset.PALETTE == get_palette('loveda')
assert PotsdamDataset.PALETTE == get_palette('potsdam')
with pytest.raises(ValueError):
get_palette('unsupported')
@ -709,6 +713,16 @@ def test_loveda():
shutil.rmtree('.format_loveda')
def test_potsdam():
test_dataset = PotsdamDataset(
pipeline=[],
img_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_potsdam_dataset/img_dir'),
ann_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_potsdam_dataset/ann_dir'))
assert len(test_dataset) == 1
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
MagicMock(side_effect=lambda idx: idx))

View File

@ -0,0 +1,157 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import math
import os
import os.path as osp
import tempfile
import zipfile
import mmcv
import numpy as np
def parse_args():
parser = argparse.ArgumentParser(
description='Convert potsdam dataset to mmsegmentation format')
parser.add_argument('dataset_path', help='potsdam folder path')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--clip_size',
type=int,
help='clipped size of image after preparation',
default=512)
parser.add_argument(
'--stride_size',
type=int,
help='stride of clipping original images',
default=256)
args = parser.parse_args()
return args
def clip_big_image(image_path, clip_save_dir, args, to_label=False):
# Original image of Potsdam dataset is very large, thus pre-processing
# of them is adopted. Given fixed clip size and stride size to generate
# clipped image, the intersection of width and height is determined.
# For example, given one 5120 x 5120 original image, the clip size is
# 512 and stride size is 256, thus it would generate 20x20 = 400 images
# whose size are all 512x512.
image = mmcv.imread(image_path)
h, w, c = image.shape
clip_size = args.clip_size
stride_size = args.stride_size
num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
(h - clip_size) /
stride_size) * stride_size + clip_size >= h else math.ceil(
(h - clip_size) / stride_size) + 1
num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
(w - clip_size) /
stride_size) * stride_size + clip_size >= w else math.ceil(
(w - clip_size) / stride_size) + 1
x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
xmin = x * clip_size
ymin = y * clip_size
xmin = xmin.ravel()
ymin = ymin.ravel()
xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
np.zeros_like(xmin))
ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
np.zeros_like(ymin))
boxes = np.stack([
xmin + xmin_offset, ymin + ymin_offset,
np.minimum(xmin + clip_size, w),
np.minimum(ymin + clip_size, h)
],
axis=1)
if to_label:
color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
[255, 255, 0], [0, 255, 0], [0, 255, 255],
[0, 0, 255]])
flatten_v = np.matmul(
image.reshape(-1, c),
np.array([2, 3, 4]).reshape(3, 1))
out = np.zeros_like(flatten_v)
for idx, class_color in enumerate(color_map):
value_idx = np.matmul(class_color,
np.array([2, 3, 4]).reshape(3, 1))
out[flatten_v == value_idx] = idx
image = out.reshape(h, w)
for box in boxes:
start_x, start_y, end_x, end_y = box
clipped_image = image[start_y:end_y,
start_x:end_x] if to_label else image[
start_y:end_y, start_x:end_x, :]
idx_i, idx_j = osp.basename(image_path).split('_')[2:4]
mmcv.imwrite(
clipped_image.astype(np.uint8),
osp.join(
clip_save_dir,
f'{idx_i}_{idx_j}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
def main():
args = parse_args()
splits = {
'train': [
'2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11',
'4_12', '5_10', '5_11', '5_12', '6_10', '6_11', '6_12', '6_7',
'6_8', '6_9', '7_10', '7_11', '7_12', '7_7', '7_8', '7_9'
],
'val': [
'5_15', '6_15', '6_13', '3_13', '4_14', '6_14', '5_14', '2_13',
'4_15', '2_14', '5_13', '4_13', '3_14', '7_13'
]
}
dataset_path = args.dataset_path
if args.out_dir is None:
out_dir = osp.join('data', 'potsdam')
else:
out_dir = args.out_dir
print('Making directories...')
mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
print('Find the data', zipp_list)
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
for zipp in zipp_list:
zip_file = zipfile.ZipFile(zipp)
zip_file.extractall(tmp_dir)
src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
if not len(src_path_list):
sub_tmp_dir = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])
src_path_list = glob.glob(os.path.join(sub_tmp_dir, '*.tif'))
prog_bar = mmcv.ProgressBar(len(src_path_list))
for i, src_path in enumerate(src_path_list):
idx_i, idx_j = osp.basename(src_path).split('_')[2:4]
data_type = 'train' if f'{idx_i}_{idx_j}' in splits[
'train'] else 'val'
if 'label' in src_path:
dst_dir = osp.join(out_dir, 'ann_dir', data_type)
clip_big_image(src_path, dst_dir, args, to_label=True)
else:
dst_dir = osp.join(out_dir, 'img_dir', data_type)
clip_big_image(src_path, dst_dir, args, to_label=False)
prog_bar.update()
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
main()