[Feature] Support iSAID aerial dataset. (#1115)

* support iSAID aerial dataset

* Update and rename docs/dataset_prepare.md to 博士/dataset_prepare.md

* Update dataset_prepare.md

* fix typo

* fix typo

* fix typo

* remove imgviz

* fix wrong order in annotation name

* upload models&logs

* upload models&logs

* add load_annotations

* fix unittest coverage

* fix unittest coverage

* fix correct crop size in config

* fix iSAID unit test

* fix iSAID unit test

* fix typos

* fix wrong crop size in readme

* use smaller figure as test data

* add smaller dataset in test data

* add blank in docs

* use 0 bytes pseudo data

* add footnote and comments for crop size

* change iSAID to isaid and add default value in it

* change iSAID to isaid in _base_

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
This commit is contained in:
Waterman0524 2022-02-17 19:07:32 +08:00 committed by GitHub
parent 9522b4fc97
commit 4f4e7728b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 783 additions and 6 deletions

View File

@ -138,6 +138,7 @@ Supported datasets:
- [x] [LoveDA](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#loveda) - [x] [LoveDA](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#loveda)
- [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isprs-potsdam) - [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isprs-potsdam)
- [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isprs-vaihingen) - [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isprs-vaihingen)
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isaid)
## Installation ## Installation

View File

@ -137,6 +137,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [LoveDA](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#loveda) - [x] [LoveDA](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#loveda)
- [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isprs-potsdam) - [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isprs-potsdam)
- [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isprs-vaihingen) - [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isprs-vaihingen)
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isaid)
## 安装 ## 安装

View File

@ -0,0 +1,62 @@
# dataset settings
dataset_type = 'iSAIDDataset'
data_root = 'data/iSAID'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
"""
This crop_size setting is followed by the implementation of
`PointFlow: Flowing Semantics Through Points for Aerial Image
Segmentation <https://arxiv.org/pdf/2103.06564.pdf>`_.
"""
crop_size = (896, 896)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(896, 896), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(896, 896),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/train',
ann_dir='ann_dir/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline))

View File

@ -114,8 +114,16 @@ Spatial pyramid pooling module or encode-decoder structure are used in deep neur
| DeepLabV3+ | R-50-D8 | 512x512 | 80000 | 7.36 | 26.91 | 73.97 | 75.05 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen_20211231_230816-5040938d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen_20211231_230816.log.json) | | DeepLabV3+ | R-50-D8 | 512x512 | 80000 | 7.36 | 26.91 | 73.97 | 75.05 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen_20211231_230816-5040938d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen_20211231_230816.log.json) |
| DeepLabV3+ | R-101-D8 | 512x512 | 80000 | 10.83 | 18.59 | 73.06 | 74.14 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816-8a095afa.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816.log.json) | | DeepLabV3+ | R-101-D8 | 512x512 | 80000 | 10.83 | 18.59 | 73.06 | 74.14 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816-8a095afa.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816.log.json) |
### iSAID
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| DeepLabV3+ | R-18-D8 | 896x896 | 80000 | 6.19 | 24.81 | 61.35 | 62.61 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526-7059991d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) |
| DeepLabV3+ | R-50-D8 | 896x896 | 80000 | 21.45 | 8.42 | 67.06 | 68.02 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526-598be439.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) |
Note: Note:
- `D-8`/`D-16` here corresponding to the output stride 8/16 setting for DeepLab series. - `D-8`/`D-16` here corresponding to the output stride 8/16 setting for DeepLab series.
- `MG-124` stands for multi-grid dilation in the last stage of ResNet. - `MG-124` stands for multi-grid dilation in the last stage of ResNet.
- `FP16` means Mixed Precision (FP16) is adopted in training. - `FP16` means Mixed Precision (FP16) is adopted in training.
- `896x896` is the Crop Size of iSAID dataset, which is followed by the implementation of [PointFlow: Flowing Semantics Through Points for Aerial Image Segmentation](https://arxiv.org/pdf/2103.06564.pdf)

View File

@ -10,6 +10,7 @@ Collections:
- LoveDA - LoveDA
- Potsdam - Potsdam
- Vaihingen - Vaihingen
- iSAID
Paper: Paper:
URL: https://arxiv.org/abs/1802.02611 URL: https://arxiv.org/abs/1802.02611
Title: Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation Title: Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation
@ -803,3 +804,47 @@ Models:
mIoU(ms+flip): 74.14 mIoU(ms+flip): 74.14
Config: configs/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen.py Config: configs/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816-8a095afa.pth Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816-8a095afa.pth
- Name: deeplabv3plus_r18-d8_4x4_896x896_80k_isaid
In Collection: deeplabv3plus
Metadata:
backbone: R-18-D8
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 40.31
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 6.19
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 61.35
mIoU(ms+flip): 62.61
Config: configs/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526-7059991d.pth
- Name: deeplabv3plus_r50-d8_4x4_896x896_80k_isaid
In Collection: deeplabv3plus
Metadata:
backbone: R-50-D8
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 118.76
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 21.45
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 67.06
mIoU(ms+flip): 68.02
Config: configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526-598be439.pth

View File

@ -0,0 +1,11 @@
_base_ = './deeplabv3plus_r50-d8_4x4_896x896_80k_isaid.py'
model = dict(
pretrained='open-mmlab://resnet18_v1c',
backbone=dict(depth=18),
decode_head=dict(
c1_in_channels=64,
c1_channels=12,
in_channels=512,
channels=128,
),
auxiliary_head=dict(in_channels=256, channels=64))

View File

@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/deeplabv3plus_r50-d8.py', '../_base_/datasets/isaid.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
model = dict(
decode_head=dict(num_classes=16), auxiliary_head=dict(num_classes=16))

View File

@ -107,3 +107,15 @@ High-resolution representations are essential for position-sensitive vision prob
| FCN | HRNetV2p-W18-Small | 512x512 | 80000 | 1.58 | 38.11 | 71.81 | 73.1 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18s_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_512x512_80k_vaihingen/fcn_hr18s_4x4_512x512_80k_vaihingen_20211231_230909-b23aae02.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_512x512_80k_vaihingen/fcn_hr18s_4x4_512x512_80k_vaihingen_20211231_230909.log.json) | | FCN | HRNetV2p-W18-Small | 512x512 | 80000 | 1.58 | 38.11 | 71.81 | 73.1 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18s_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_512x512_80k_vaihingen/fcn_hr18s_4x4_512x512_80k_vaihingen_20211231_230909-b23aae02.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_512x512_80k_vaihingen/fcn_hr18s_4x4_512x512_80k_vaihingen_20211231_230909.log.json) |
| FCN | HRNetV2p-W18 | 512x512 | 80000 | 2.76 | 19.55 | 72.57 | 74.09 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_512x512_80k_vaihingen/fcn_hr18_4x4_512x512_80k_vaihingen_20211231_231216-2ec3ae8a.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_512x512_80k_vaihingen/fcn_hr18_4x4_512x512_80k_vaihingen_20211231_231216.log.json) | | FCN | HRNetV2p-W18 | 512x512 | 80000 | 2.76 | 19.55 | 72.57 | 74.09 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_512x512_80k_vaihingen/fcn_hr18_4x4_512x512_80k_vaihingen_20211231_231216-2ec3ae8a.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_512x512_80k_vaihingen/fcn_hr18_4x4_512x512_80k_vaihingen_20211231_231216.log.json) |
| FCN | HRNetV2p-W48 | 512x512 | 80000 | 6.20 | 17.25 | 72.50 | 73.52 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen/fcn_hr48_4x4_512x512_80k_vaihingen_20211231_231244-7133cb22.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen/fcn_hr48_4x4_512x512_80k_vaihingen_20211231_231244.log.json) | | FCN | HRNetV2p-W48 | 512x512 | 80000 | 6.20 | 17.25 | 72.50 | 73.52 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen/fcn_hr48_4x4_512x512_80k_vaihingen_20211231_231244-7133cb22.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen/fcn_hr48_4x4_512x512_80k_vaihingen_20211231_231244.log.json) |
### iSAID
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| FCN | HRNetV2p-W18-Small | 896x896 | 80000 | 4.95 | 13.84 | 62.30 | 62.97 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18s_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_896x896_80k_isaid/fcn_hr18s_4x4_896x896_80k_isaid_20220118_001603-3cc0769b.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_896x896_80k_isaid/fcn_hr18s_4x4_896x896_80k_isaid_20220118_001603.log.json) |
| FCN | HRNetV2p-W18 | 896x896 | 80000 | 8.30 | 7.71 | 65.06 | 65.60 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_896x896_80k_isaid/fcn_hr18_4x4_896x896_80k_isaid_20220110_182230-49bf752e.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_896x896_80k_isaid/fcn_hr18_4x4_896x896_80k_isaid_20220110_182230.log.json) |
| FCN | HRNetV2p-W48 | 896x896 | 80000 | 16.89 | 7.34 | 67.80 | 68.53 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr48_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_896x896_80k_isaid/fcn_hr48_4x4_896x896_80k_isaid_20220114_174643-547fc420.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_896x896_80k_isaid/fcn_hr48_4x4_896x896_80k_isaid_20220114_174643.log.json) |
Note:
- `896x896` is the Crop Size of iSAID dataset, which is followed by the implementation of [PointFlow: Flowing Semantics Through Points for Aerial Image Segmentation](https://arxiv.org/pdf/2103.06564.pdf)

View File

@ -0,0 +1,5 @@
_base_ = [
'../_base_/models/fcn_hr18.py', '../_base_/datasets/isaid.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
model = dict(decode_head=dict(num_classes=16))

View File

@ -0,0 +1,9 @@
_base_ = './fcn_hr18_4x4_896x896_80k_isaid.py'
model = dict(
pretrained='open-mmlab://msra/hrnetv2_w18_small',
backbone=dict(
extra=dict(
stage1=dict(num_blocks=(2, )),
stage2=dict(num_blocks=(2, 2)),
stage3=dict(num_modules=3, num_blocks=(2, 2, 2)),
stage4=dict(num_modules=2, num_blocks=(2, 2, 2, 2)))))

View File

@ -0,0 +1,10 @@
_base_ = './fcn_hr18_4x4_896x896_80k_isaid.py'
model = dict(
pretrained='open-mmlab://msra/hrnetv2_w48',
backbone=dict(
extra=dict(
stage2=dict(num_channels=(48, 96)),
stage3=dict(num_channels=(48, 96, 192)),
stage4=dict(num_channels=(48, 96, 192, 384)))),
decode_head=dict(
in_channels=[48, 96, 192, 384], channels=sum([48, 96, 192, 384])))

View File

@ -10,6 +10,7 @@ Collections:
- LoveDA - LoveDA
- Potsdam - Potsdam
- Vaihingen - Vaihingen
- iSAID
Paper: Paper:
URL: https://arxiv.org/abs/1908.07919 URL: https://arxiv.org/abs/1908.07919
Title: Deep High-Resolution Representation Learning for Human Pose Estimation Title: Deep High-Resolution Representation Learning for Human Pose Estimation
@ -648,3 +649,69 @@ Models:
mIoU(ms+flip): 73.52 mIoU(ms+flip): 73.52
Config: configs/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen.py Config: configs/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen/fcn_hr48_4x4_512x512_80k_vaihingen_20211231_231244-7133cb22.pth Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen/fcn_hr48_4x4_512x512_80k_vaihingen_20211231_231244-7133cb22.pth
- Name: fcn_hr18s_4x4_896x896_80k_isaid
In Collection: hrnet
Metadata:
backbone: HRNetV2p-W18-Small
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 72.25
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 4.95
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 62.3
mIoU(ms+flip): 62.97
Config: configs/hrnet/fcn_hr18s_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_896x896_80k_isaid/fcn_hr18s_4x4_896x896_80k_isaid_20220118_001603-3cc0769b.pth
- Name: fcn_hr18_4x4_896x896_80k_isaid
In Collection: hrnet
Metadata:
backbone: HRNetV2p-W18
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 129.7
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 8.3
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 65.06
mIoU(ms+flip): 65.6
Config: configs/hrnet/fcn_hr18_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_896x896_80k_isaid/fcn_hr18_4x4_896x896_80k_isaid_20220110_182230-49bf752e.pth
- Name: fcn_hr48_4x4_896x896_80k_isaid
In Collection: hrnet
Metadata:
backbone: HRNetV2p-W48
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 136.24
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 16.89
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 67.8
mIoU(ms+flip): 68.53
Config: configs/hrnet/fcn_hr48_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_896x896_80k_isaid/fcn_hr48_4x4_896x896_80k_isaid_20220114_174643-547fc420.pth

View File

@ -148,6 +148,14 @@ We support evaluation results on these two datasets using models above trained o
| PSPNet | R-50-D8 | 512x512 | 80000 | 6.14 | 30.29 | 72.36 | 73.75 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r50-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_vaihingen/pspnet_r50-d8_4x4_512x512_80k_vaihingen_20211228_160355-382f8f5b.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_vaihingen/pspnet_r50-d8_4x4_512x512_80k_vaihingen_20211228_160355.log.json) | | PSPNet | R-50-D8 | 512x512 | 80000 | 6.14 | 30.29 | 72.36 | 73.75 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r50-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_vaihingen/pspnet_r50-d8_4x4_512x512_80k_vaihingen_20211228_160355-382f8f5b.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_vaihingen/pspnet_r50-d8_4x4_512x512_80k_vaihingen_20211228_160355.log.json) |
| PSPNet | R-101-D8 | 512x512 | 80000 | 9.61 | 19.97 | 72.61 | 74.18 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen/pspnet_r101-d8_4x4_512x512_80k_vaihingen_20211231_230806-8eba0a09.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen/pspnet_r101-d8_4x4_512x512_80k_vaihingen_20211231_230806.log.json) | | PSPNet | R-101-D8 | 512x512 | 80000 | 9.61 | 19.97 | 72.61 | 74.18 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen/pspnet_r101-d8_4x4_512x512_80k_vaihingen_20211231_230806-8eba0a09.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen/pspnet_r101-d8_4x4_512x512_80k_vaihingen_20211231_230806.log.json) |
### iSAID
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| PSPNet | R-18-D8 | 896x896 | 80000 | 4.52 | 26.91 | 60.22 | 61.25 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid/pspnet_r18-d8_4x4_896x896_80k_isaid_20220110_180526-e84c0b6a.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid/pspnet_r18-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) |
| PSPNet | R-50-D8 | 896x896 | 80000 | 16.58 | 8.88 | 65.36 | 66.48 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid/pspnet_r50-d8_4x4_896x896_80k_isaid_20220110_180629-1f21dc32.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid/pspnet_r50-d8_4x4_896x896_80k_isaid_20220110_180629.log.json) |
Note: Note:
- `FP16` means Mixed Precision (FP16) is adopted in training. - `FP16` means Mixed Precision (FP16) is adopted in training.
- `896x896` is the Crop Size of iSAID dataset, which is followed by the implementation of [PointFlow: Flowing Semantics Through Points for Aerial Image Segmentation](https://arxiv.org/pdf/2103.06564.pdf)

View File

@ -13,6 +13,7 @@ Collections:
- LoveDA - LoveDA
- Potsdam - Potsdam
- Vaihingen - Vaihingen
- iSAID
Paper: Paper:
URL: https://arxiv.org/abs/1612.01105 URL: https://arxiv.org/abs/1612.01105
Title: Pyramid Scene Parsing Network Title: Pyramid Scene Parsing Network
@ -942,3 +943,47 @@ Models:
mIoU(ms+flip): 74.18 mIoU(ms+flip): 74.18
Config: configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen.py Config: configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen/pspnet_r101-d8_4x4_512x512_80k_vaihingen_20211231_230806-8eba0a09.pth Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen/pspnet_r101-d8_4x4_512x512_80k_vaihingen_20211231_230806-8eba0a09.pth
- Name: pspnet_r18-d8_4x4_896x896_80k_isaid
In Collection: pspnet
Metadata:
backbone: R-18-D8
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 37.16
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 4.52
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 60.22
mIoU(ms+flip): 61.25
Config: configs/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid/pspnet_r18-d8_4x4_896x896_80k_isaid_20220110_180526-e84c0b6a.pth
- Name: pspnet_r50-d8_4x4_896x896_80k_isaid
In Collection: pspnet
Metadata:
backbone: R-50-D8
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 112.61
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 16.58
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 65.36
mIoU(ms+flip): 66.48
Config: configs/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid/pspnet_r50-d8_4x4_896x896_80k_isaid_20220110_180629-1f21dc32.pth

View File

@ -0,0 +1,9 @@
_base_ = './pspnet_r50-d8_4x4_896x896_80k_isaid.py'
model = dict(
pretrained='open-mmlab://resnet18_v1c',
backbone=dict(depth=18),
decode_head=dict(
in_channels=512,
channels=128,
),
auxiliary_head=dict(in_channels=256, channels=64))

View File

@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/pspnet_r50-d8.py', '../_base_/datasets/isaid.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
model = dict(
decode_head=dict(num_classes=16), auxiliary_head=dict(num_classes=16))

View File

@ -123,6 +123,21 @@ mmsegmentation
│ │ ├── ann_dir │ │ ├── ann_dir
│ │ │ ├── train │ │ │ ├── train
│ │ │ ├── val │ │ │ ├── val
│ ├── vaihingen
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ ├── val
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
│ ├── iSAID
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ ├── val
│ │ │ ├── test
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
``` ```
### Cityscapes ### Cityscapes
@ -325,3 +340,38 @@ python tools/convert_datasets/vaihingen.py /path/to/vaihingen
``` ```
In our default setting (`clip_size` =512, `stride_size`=256), it will generate 344 images for training and 398 images for validation. In our default setting (`clip_size` =512, `stride_size`=256), it will generate 344 images for training and 398 images for validation.
### iSAID
The data images could be download from [DOTA-v1.0](https://captain-whu.github.io/DOTA/dataset.html) (train/val/test)
The data annotations could be download from [iSAID](https://captain-whu.github.io/iSAID/dataset.html) (train/val)
The dataset is a Large-scale Dataset for Instance Segmentation (also have segmantic segmentation) in Aerial Images.
You may need to follow the following structure for dataset preparation after downloading iSAID dataset.
```
│ ├── iSAID
│ │ ├── train
│ │ │ ├── images
│ │ │ │ ├── part1.zip
│ │ │ │ ├── part2.zip
│ │ │ │ ├── part3.zip
│ │ │ ├── Semantic_masks
│ │ │ │ ├── images.zip
│ │ ├── val
│ │ │ ├── images
│ │ │ │ ├── part1.zip
│ │ │ ├── Semantic_masks
│ │ │ │ ├── images.zip
│ │ ├── test
│ │ │ ├── images
│ │ │ │ ├── part1.zip
│ │ │ │ ├── part2.zip
```
```shell
python tools/convert_datasets/isaid.py /path/to/iSAID
```
In our default setting (`clip_size` =512, `stride_size`=256), it will generate 33978 images for training and 11644 images for validation.

View File

@ -104,6 +104,21 @@ mmsegmentation
│ │ ├── ann_dir │ │ ├── ann_dir
│ │ │ ├── train │ │ │ ├── train
│ │ │ ├── val │ │ │ ├── val
│ ├── vaihingen
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ ├── val
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
│ ├── iSAID
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ ├── val
│ │ │ ├── test
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
``` ```
### Cityscapes ### Cityscapes
@ -266,3 +281,38 @@ python tools/convert_datasets/vaihingen.py /path/to/vaihingen
``` ```
使用我们默认的配置 (`clip_size`=512, `stride_size`=256) 将生成 344 张图片的训练集和 398 张图片的验证集。 使用我们默认的配置 (`clip_size`=512, `stride_size`=256) 将生成 344 张图片的训练集和 398 张图片的验证集。
### iSAID
iSAID 数据集(训练集/验证集/测试集)的图像可以从 [DOTA-v1.0](https://captain-whu.github.io/DOTA/dataset.html) 下载.
iSAID 数据集(训练集/验证集)的注释可以从 [iSAID](https://captain-whu.github.io/iSAID/dataset.html) 下载.
该数据集是一个大规模的实例分割(也可以用于语义分割)的遥感数据集.
下载后,在数据集转换前,您需要将数据集文件夹调整成如下格式.
```
│ ├── iSAID
│ │ ├── train
│ │ │ ├── images
│ │ │ │ ├── part1.zip
│ │ │ │ ├── part2.zip
│ │ │ │ ├── part3.zip
│ │ │ ├── Semantic_masks
│ │ │ │ ├── images.zip
│ │ ├── val
│ │ │ ├── images
│ │ │ │ ├── part1.zip
│ │ │ ├── Semantic_masks
│ │ │ │ ├── images.zip
│ │ ├── test
│ │ │ ├── images
│ │ │ │ ├── part1.zip
│ │ │ │ ├── part2.zip
```
```shell
python tools/convert_datasets/isaid.py /path/to/iSAID
```
使用我们默认的配置 (`patch_width`=896, `patch_height`=896, `overlap_area`=384) 将生成 33978 张图片的训练集和 11644 张图片的验证集。

View File

@ -111,6 +111,16 @@ def vaihingen_classes():
] ]
def isaid_classes():
"""iSAID class names for external use."""
return [
'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court',
'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle',
'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout',
'Soccer_ball_field', 'plane', 'Harbor'
]
def cityscapes_palette(): def cityscapes_palette():
"""Cityscapes palette for external use.""" """Cityscapes palette for external use."""
return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
@ -236,6 +246,15 @@ def vaihingen_palette():
[255, 255, 0], [255, 0, 0]] [255, 255, 0], [255, 0, 0]]
def isaid_palette():
"""iSAID palette for external use."""
return [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
[0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127,
127], [0, 0, 127],
[0, 0, 191], [0, 0, 255], [0, 191, 127], [0, 127, 191],
[0, 127, 255], [0, 100, 155]]
dataset_aliases = { dataset_aliases = {
'cityscapes': ['cityscapes'], 'cityscapes': ['cityscapes'],
'ade': ['ade', 'ade20k'], 'ade': ['ade', 'ade20k'],
@ -247,7 +266,8 @@ dataset_aliases = {
'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff', 'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff',
'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k', 'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k',
'coco_stuff164k' 'coco_stuff164k'
] ],
'isaid': ['isaid', 'iSAID']
} }

View File

@ -10,6 +10,7 @@ from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
RepeatDataset) RepeatDataset)
from .drive import DRIVEDataset from .drive import DRIVEDataset
from .hrf import HRFDataset from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset from .isprs import ISPRSDataset
from .loveda import LoveDADataset from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset from .night_driving import NightDrivingDataset
@ -25,5 +26,5 @@ __all__ = [
'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset', 'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset', 'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
'ISPRSDataset', 'PotsdamDataset' 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset'
] ]

82
mmseg/datasets/isaid.py Normal file
View File

@ -0,0 +1,82 @@
import os.path as osp
import mmcv
from mmcv.utils import print_log
from ..utils import get_root_logger
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class iSAIDDataset(CustomDataset):
""" iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images
In segmentation map annotation for iSAID dataset, which is included
in 16 categories. ``reduce_zero_label`` is fixed to False. The
``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
'_manual1.png'.
"""
CLASSES = ('background', 'ship', 'store_tank', 'baseball_diamond',
'tennis_court', 'basketball_court', 'Ground_Track_Field',
'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
'Harbor')
PALETTE = [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
[0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
[0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127],
[0, 127, 191], [0, 127, 255], [0, 100, 155]]
def __init__(self, **kwargs):
super(iSAIDDataset, self).__init__(
img_suffix='.png',
seg_map_suffix='.png',
ignore_index=255,
**kwargs)
assert osp.exists(self.img_dir)
def load_annotations(self,
img_dir,
img_suffix,
ann_dir,
seg_map_suffix=None,
split=None):
"""Load annotation from directory.
Args:
img_dir (str): Path to image directory
img_suffix (str): Suffix of images.
ann_dir (str|None): Path to annotation directory.
seg_map_suffix (str|None): Suffix of segmentation maps.
split (str|None): Split txt file. If split is specified, only file
with suffix in the splits will be loaded. Otherwise, all images
in img_dir/ann_dir will be loaded. Default: None
Returns:
list[dict]: All image info of dataset.
"""
img_infos = []
if split is not None:
with open(split) as f:
for line in f:
name = line.strip()
img_info = dict(filename=name + img_suffix)
if ann_dir is not None:
ann_name = name + '_instance_color_RGB'
seg_map = ann_name + seg_map_suffix
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
else:
for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
img_info = dict(filename=img)
if ann_dir is not None:
seg_img = img
seg_map = seg_img.replace(
img_suffix, '_instance_color_RGB' + seg_map_suffix)
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
return img_infos

View File

@ -16,4 +16,4 @@ default_section = THIRDPARTY
skip = *.po,*.ts,*.ipynb skip = *.po,*.ts,*.ipynb
count = count =
quiet-level = 3 quiet-level = 3
ignore-words-list = formating,sur,hist ignore-words-list = formating,sur,hist,dota

View File

@ -0,0 +1 @@
P0000_0_896_1536_2432

View File

@ -0,0 +1 @@
P0000_0_896_1024_1920

View File

@ -16,7 +16,7 @@ from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
COCOStuffDataset, ConcatDataset, CustomDataset, COCOStuffDataset, ConcatDataset, CustomDataset,
ISPRSDataset, LoveDADataset, MultiImageMixDataset, ISPRSDataset, LoveDADataset, MultiImageMixDataset,
PascalVOCDataset, PotsdamDataset, RepeatDataset, PascalVOCDataset, PotsdamDataset, RepeatDataset,
build_dataset) build_dataset, iSAIDDataset)
def test_classes(): def test_classes():
@ -25,10 +25,11 @@ def test_classes():
'pascal_voc') 'pascal_voc')
assert list( assert list(
ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k') ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k')
assert list(COCOStuffDataset.CLASSES) == get_classes('cocostuff')
assert list(LoveDADataset.CLASSES) == get_classes('loveda') assert list(LoveDADataset.CLASSES) == get_classes('loveda')
assert list(PotsdamDataset.CLASSES) == get_classes('potsdam') assert list(PotsdamDataset.CLASSES) == get_classes('potsdam')
assert list(ISPRSDataset.CLASSES) == get_classes('vaihingen') assert list(ISPRSDataset.CLASSES) == get_classes('vaihingen')
assert list(COCOStuffDataset.CLASSES) == get_classes('cocostuff') assert list(iSAIDDataset.CLASSES) == get_classes('isaid')
with pytest.raises(ValueError): with pytest.raises(ValueError):
get_classes('unsupported') get_classes('unsupported')
@ -73,6 +74,7 @@ def test_palette():
assert LoveDADataset.PALETTE == get_palette('loveda') assert LoveDADataset.PALETTE == get_palette('loveda')
assert PotsdamDataset.PALETTE == get_palette('potsdam') assert PotsdamDataset.PALETTE == get_palette('potsdam')
assert COCOStuffDataset.PALETTE == get_palette('cocostuff') assert COCOStuffDataset.PALETTE == get_palette('cocostuff')
assert iSAIDDataset.PALETTE == get_palette('isaid')
with pytest.raises(ValueError): with pytest.raises(ValueError):
get_palette('unsupported') get_palette('unsupported')
@ -730,6 +732,27 @@ def test_vaihingen():
assert len(test_dataset) == 1 assert len(test_dataset) == 1
def test_isaid():
test_dataset = iSAIDDataset(
pipeline=[],
img_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
ann_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_isaid_dataset/ann_dir'))
assert len(test_dataset) == 2
isaid_info = test_dataset.load_annotations(
img_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
img_suffix='.png',
ann_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_isaid_dataset/ann_dir'),
seg_map_suffix='.png',
split=osp.join(
osp.dirname(__file__),
'../data/pseudo_isaid_dataset/splits/train.txt'))
assert len(isaid_info) == 1
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) @patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__', @patch('mmseg.datasets.CustomDataset.__getitem__',
MagicMock(side_effect=lambda idx: idx)) MagicMock(side_effect=lambda idx: idx))

View File

@ -0,0 +1,244 @@
import argparse
import glob
import os
import os.path as osp
import shutil
import tempfile
import zipfile
import mmcv
import numpy as np
from PIL import Image
iSAID_palette = \
{
0: (0, 0, 0),
1: (0, 0, 63),
2: (0, 63, 63),
3: (0, 63, 0),
4: (0, 63, 127),
5: (0, 63, 191),
6: (0, 63, 255),
7: (0, 127, 63),
8: (0, 127, 127),
9: (0, 0, 127),
10: (0, 0, 191),
11: (0, 0, 255),
12: (0, 191, 127),
13: (0, 127, 191),
14: (0, 127, 255),
15: (0, 100, 155)
}
iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()}
def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette):
"""RGB-color encoding to grayscale labels."""
arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)
for c, i in palette.items():
m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
arr_2d[m] = i
return arr_2d
def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap):
img = np.asarray(Image.open(src_path).convert('RGB'))
img_H, img_W, _ = img.shape
if img_H < patch_H and img_W > patch_W:
img = mmcv.impad(img, shape=(patch_H, img_W), pad_val=0)
img_H, img_W, _ = img.shape
elif img_H > patch_H and img_W < patch_W:
img = mmcv.impad(img, shape=(img_H, patch_W), pad_val=0)
img_H, img_W, _ = img.shape
elif img_H < patch_H and img_W < patch_W:
img = mmcv.impad(img, shape=(patch_H, patch_W), pad_val=0)
img_H, img_W, _ = img.shape
for x in range(0, img_W, patch_W - overlap):
for y in range(0, img_H, patch_H - overlap):
x_str = x
x_end = x + patch_W
if x_end > img_W:
diff_x = x_end - img_W
x_str -= diff_x
x_end = img_W
y_str = y
y_end = y + patch_H
if y_end > img_H:
diff_y = y_end - img_H
y_str -= diff_y
y_end = img_H
img_patch = img[y_str:y_end, x_str:x_end, :]
img_patch = Image.fromarray(img_patch.astype(np.uint8))
image = osp.splitext(
src_path.split('/')[-1])[0] + '_' + str(y_str) + '_' + str(
y_end) + '_' + str(x_str) + '_' + str(x_end) + '.png'
# print(image)
save_path_image = osp.join(out_dir, 'img_dir', mode, str(image))
img_patch.save(save_path_image)
def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap):
label = mmcv.imread(src_path, channel_order='rgb')
label = iSAID_convert_from_color(label)
img_H, img_W = label.shape
if img_H < patch_H and img_W > patch_W:
label = mmcv.impad(label, shape=(patch_H, img_W), pad_val=255)
img_H = patch_H
elif img_H > patch_H and img_W < patch_W:
label = mmcv.impad(label, shape=(img_H, patch_W), pad_val=255)
img_W = patch_W
elif img_H < patch_H and img_W < patch_W:
label = mmcv.impad(label, shape=(patch_H, patch_W), pad_val=255)
img_H = patch_H
img_W = patch_W
for x in range(0, img_W, patch_W - overlap):
for y in range(0, img_H, patch_H - overlap):
x_str = x
x_end = x + patch_W
if x_end > img_W:
diff_x = x_end - img_W
x_str -= diff_x
x_end = img_W
y_str = y
y_end = y + patch_H
if y_end > img_H:
diff_y = y_end - img_H
y_str -= diff_y
y_end = img_H
lab_patch = label[y_str:y_end, x_str:x_end]
lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode='P')
image = osp.splitext(src_path.split('/')[-1])[0].split(
'_')[0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str(
x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png'
lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image)))
def parse_args():
parser = argparse.ArgumentParser(
description='Convert iSAID dataset to mmsegmentation format')
parser.add_argument('dataset_path', help='iSAID folder path')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--patch_width',
default=896,
type=int,
help='Width of the cropped image patch')
parser.add_argument(
'--patch_height',
default=896,
type=int,
help='Height of the cropped image patch')
parser.add_argument(
'--overlap_area', default=384, type=int, help='Overlap area')
args = parser.parse_args()
return args
def main():
args = parse_args()
dataset_path = args.dataset_path
# image patch width and height
patch_H, patch_W = args.patch_width, args.patch_height
overlap = args.overlap_area # overlap area
if args.out_dir is None:
out_dir = osp.join('data', 'iSAID')
else:
out_dir = args.out_dir
print('Making directories...')
mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test'))
assert os.path.exists(os.path.join(dataset_path, 'train')), \
'train is not in {}'.format(dataset_path)
assert os.path.exists(os.path.join(dataset_path, 'val')), \
'val is not in {}'.format(dataset_path)
assert os.path.exists(os.path.join(dataset_path, 'test')), \
'test is not in {}'.format(dataset_path)
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
for dataset_mode in ['train', 'val', 'test']:
# for dataset_mode in [ 'test']:
print('Extracting {}ing.zip...'.format(dataset_mode))
img_zipp_list = glob.glob(
os.path.join(dataset_path, dataset_mode, 'images', '*.zip'))
print('Find the data', img_zipp_list)
for img_zipp in img_zipp_list:
zip_file = zipfile.ZipFile(img_zipp)
zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img'))
src_path_list = glob.glob(
os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png'))
src_prog_bar = mmcv.ProgressBar(len(src_path_list))
for i, img_path in enumerate(src_path_list):
if dataset_mode != 'test':
slide_crop_image(img_path, out_dir, dataset_mode, patch_H,
patch_W, overlap)
else:
shutil.move(img_path,
os.path.join(out_dir, 'img_dir', dataset_mode))
src_prog_bar.update()
if dataset_mode != 'test':
label_zipp_list = glob.glob(
os.path.join(dataset_path, dataset_mode, 'Semantic_masks',
'*.zip'))
for label_zipp in label_zipp_list:
zip_file = zipfile.ZipFile(label_zipp)
zip_file.extractall(
os.path.join(tmp_dir, dataset_mode, 'lab'))
lab_path_list = glob.glob(
os.path.join(tmp_dir, dataset_mode, 'lab', 'images',
'*.png'))
lab_prog_bar = mmcv.ProgressBar(len(lab_path_list))
for i, lab_path in enumerate(lab_path_list):
slide_crop_label(lab_path, out_dir, dataset_mode, patch_H,
patch_W, overlap)
lab_prog_bar.update()
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
main()