[Feature] Support iSAID aerial dataset. (#1115)

* support iSAID aerial dataset

* Update and rename docs/dataset_prepare.md to 博士/dataset_prepare.md

* Update dataset_prepare.md

* fix typo

* fix typo

* fix typo

* remove imgviz

* fix wrong order in annotation name

* upload models&logs

* upload models&logs

* add load_annotations

* fix unittest coverage

* fix unittest coverage

* fix correct crop size in config

* fix iSAID unit test

* fix iSAID unit test

* fix typos

* fix wrong crop size in readme

* use smaller figure as test data

* add smaller dataset in test data

* add blank in docs

* use 0 bytes pseudo data

* add footnote and comments for crop size

* change iSAID to isaid and add default value in it

* change iSAID to isaid in _base_

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
This commit is contained in:
Waterman0524 2022-02-17 19:07:32 +08:00 committed by GitHub
parent 9522b4fc97
commit 4f4e7728b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 783 additions and 6 deletions

View File

@ -138,6 +138,7 @@ Supported datasets:
- [x] [LoveDA](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#loveda)
- [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isprs-potsdam)
- [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isprs-vaihingen)
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isaid)
## Installation

View File

@ -137,6 +137,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [LoveDA](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#loveda)
- [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isprs-potsdam)
- [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isprs-vaihingen)
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isaid)
## 安装

View File

@ -0,0 +1,62 @@
# dataset settings
dataset_type = 'iSAIDDataset'
data_root = 'data/iSAID'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
"""
This crop_size setting is followed by the implementation of
`PointFlow: Flowing Semantics Through Points for Aerial Image
Segmentation <https://arxiv.org/pdf/2103.06564.pdf>`_.
"""
crop_size = (896, 896)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(896, 896), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(896, 896),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/train',
ann_dir='ann_dir/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline))

View File

@ -114,8 +114,16 @@ Spatial pyramid pooling module or encode-decoder structure are used in deep neur
| DeepLabV3+ | R-50-D8 | 512x512 | 80000 | 7.36 | 26.91 | 73.97 | 75.05 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen_20211231_230816-5040938d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen_20211231_230816.log.json) |
| DeepLabV3+ | R-101-D8 | 512x512 | 80000 | 10.83 | 18.59 | 73.06 | 74.14 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816-8a095afa.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816.log.json) |
### iSAID
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| DeepLabV3+ | R-18-D8 | 896x896 | 80000 | 6.19 | 24.81 | 61.35 | 62.61 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526-7059991d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) |
| DeepLabV3+ | R-50-D8 | 896x896 | 80000 | 21.45 | 8.42 | 67.06 | 68.02 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526-598be439.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) |
Note:
- `D-8`/`D-16` here corresponding to the output stride 8/16 setting for DeepLab series.
- `MG-124` stands for multi-grid dilation in the last stage of ResNet.
- `FP16` means Mixed Precision (FP16) is adopted in training.
- `896x896` is the Crop Size of iSAID dataset, which is followed by the implementation of [PointFlow: Flowing Semantics Through Points for Aerial Image Segmentation](https://arxiv.org/pdf/2103.06564.pdf)

View File

@ -10,6 +10,7 @@ Collections:
- LoveDA
- Potsdam
- Vaihingen
- iSAID
Paper:
URL: https://arxiv.org/abs/1802.02611
Title: Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation
@ -803,3 +804,47 @@ Models:
mIoU(ms+flip): 74.14
Config: configs/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816-8a095afa.pth
- Name: deeplabv3plus_r18-d8_4x4_896x896_80k_isaid
In Collection: deeplabv3plus
Metadata:
backbone: R-18-D8
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 40.31
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 6.19
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 61.35
mIoU(ms+flip): 62.61
Config: configs/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526-7059991d.pth
- Name: deeplabv3plus_r50-d8_4x4_896x896_80k_isaid
In Collection: deeplabv3plus
Metadata:
backbone: R-50-D8
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 118.76
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 21.45
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 67.06
mIoU(ms+flip): 68.02
Config: configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526-598be439.pth

View File

@ -0,0 +1,11 @@
_base_ = './deeplabv3plus_r50-d8_4x4_896x896_80k_isaid.py'
model = dict(
pretrained='open-mmlab://resnet18_v1c',
backbone=dict(depth=18),
decode_head=dict(
c1_in_channels=64,
c1_channels=12,
in_channels=512,
channels=128,
),
auxiliary_head=dict(in_channels=256, channels=64))

View File

@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/deeplabv3plus_r50-d8.py', '../_base_/datasets/isaid.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
model = dict(
decode_head=dict(num_classes=16), auxiliary_head=dict(num_classes=16))

View File

@ -107,3 +107,15 @@ High-resolution representations are essential for position-sensitive vision prob
| FCN | HRNetV2p-W18-Small | 512x512 | 80000 | 1.58 | 38.11 | 71.81 | 73.1 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18s_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_512x512_80k_vaihingen/fcn_hr18s_4x4_512x512_80k_vaihingen_20211231_230909-b23aae02.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_512x512_80k_vaihingen/fcn_hr18s_4x4_512x512_80k_vaihingen_20211231_230909.log.json) |
| FCN | HRNetV2p-W18 | 512x512 | 80000 | 2.76 | 19.55 | 72.57 | 74.09 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_512x512_80k_vaihingen/fcn_hr18_4x4_512x512_80k_vaihingen_20211231_231216-2ec3ae8a.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_512x512_80k_vaihingen/fcn_hr18_4x4_512x512_80k_vaihingen_20211231_231216.log.json) |
| FCN | HRNetV2p-W48 | 512x512 | 80000 | 6.20 | 17.25 | 72.50 | 73.52 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen/fcn_hr48_4x4_512x512_80k_vaihingen_20211231_231244-7133cb22.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen/fcn_hr48_4x4_512x512_80k_vaihingen_20211231_231244.log.json) |
### iSAID
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| FCN | HRNetV2p-W18-Small | 896x896 | 80000 | 4.95 | 13.84 | 62.30 | 62.97 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18s_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_896x896_80k_isaid/fcn_hr18s_4x4_896x896_80k_isaid_20220118_001603-3cc0769b.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_896x896_80k_isaid/fcn_hr18s_4x4_896x896_80k_isaid_20220118_001603.log.json) |
| FCN | HRNetV2p-W18 | 896x896 | 80000 | 8.30 | 7.71 | 65.06 | 65.60 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_896x896_80k_isaid/fcn_hr18_4x4_896x896_80k_isaid_20220110_182230-49bf752e.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_896x896_80k_isaid/fcn_hr18_4x4_896x896_80k_isaid_20220110_182230.log.json) |
| FCN | HRNetV2p-W48 | 896x896 | 80000 | 16.89 | 7.34 | 67.80 | 68.53 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr48_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_896x896_80k_isaid/fcn_hr48_4x4_896x896_80k_isaid_20220114_174643-547fc420.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_896x896_80k_isaid/fcn_hr48_4x4_896x896_80k_isaid_20220114_174643.log.json) |
Note:
- `896x896` is the Crop Size of iSAID dataset, which is followed by the implementation of [PointFlow: Flowing Semantics Through Points for Aerial Image Segmentation](https://arxiv.org/pdf/2103.06564.pdf)

View File

@ -0,0 +1,5 @@
_base_ = [
'../_base_/models/fcn_hr18.py', '../_base_/datasets/isaid.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
model = dict(decode_head=dict(num_classes=16))

View File

@ -0,0 +1,9 @@
_base_ = './fcn_hr18_4x4_896x896_80k_isaid.py'
model = dict(
pretrained='open-mmlab://msra/hrnetv2_w18_small',
backbone=dict(
extra=dict(
stage1=dict(num_blocks=(2, )),
stage2=dict(num_blocks=(2, 2)),
stage3=dict(num_modules=3, num_blocks=(2, 2, 2)),
stage4=dict(num_modules=2, num_blocks=(2, 2, 2, 2)))))

View File

@ -0,0 +1,10 @@
_base_ = './fcn_hr18_4x4_896x896_80k_isaid.py'
model = dict(
pretrained='open-mmlab://msra/hrnetv2_w48',
backbone=dict(
extra=dict(
stage2=dict(num_channels=(48, 96)),
stage3=dict(num_channels=(48, 96, 192)),
stage4=dict(num_channels=(48, 96, 192, 384)))),
decode_head=dict(
in_channels=[48, 96, 192, 384], channels=sum([48, 96, 192, 384])))

View File

@ -10,6 +10,7 @@ Collections:
- LoveDA
- Potsdam
- Vaihingen
- iSAID
Paper:
URL: https://arxiv.org/abs/1908.07919
Title: Deep High-Resolution Representation Learning for Human Pose Estimation
@ -648,3 +649,69 @@ Models:
mIoU(ms+flip): 73.52
Config: configs/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen/fcn_hr48_4x4_512x512_80k_vaihingen_20211231_231244-7133cb22.pth
- Name: fcn_hr18s_4x4_896x896_80k_isaid
In Collection: hrnet
Metadata:
backbone: HRNetV2p-W18-Small
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 72.25
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 4.95
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 62.3
mIoU(ms+flip): 62.97
Config: configs/hrnet/fcn_hr18s_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_896x896_80k_isaid/fcn_hr18s_4x4_896x896_80k_isaid_20220118_001603-3cc0769b.pth
- Name: fcn_hr18_4x4_896x896_80k_isaid
In Collection: hrnet
Metadata:
backbone: HRNetV2p-W18
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 129.7
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 8.3
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 65.06
mIoU(ms+flip): 65.6
Config: configs/hrnet/fcn_hr18_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_896x896_80k_isaid/fcn_hr18_4x4_896x896_80k_isaid_20220110_182230-49bf752e.pth
- Name: fcn_hr48_4x4_896x896_80k_isaid
In Collection: hrnet
Metadata:
backbone: HRNetV2p-W48
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 136.24
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 16.89
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 67.8
mIoU(ms+flip): 68.53
Config: configs/hrnet/fcn_hr48_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_896x896_80k_isaid/fcn_hr48_4x4_896x896_80k_isaid_20220114_174643-547fc420.pth

View File

@ -148,6 +148,14 @@ We support evaluation results on these two datasets using models above trained o
| PSPNet | R-50-D8 | 512x512 | 80000 | 6.14 | 30.29 | 72.36 | 73.75 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r50-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_vaihingen/pspnet_r50-d8_4x4_512x512_80k_vaihingen_20211228_160355-382f8f5b.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_vaihingen/pspnet_r50-d8_4x4_512x512_80k_vaihingen_20211228_160355.log.json) |
| PSPNet | R-101-D8 | 512x512 | 80000 | 9.61 | 19.97 | 72.61 | 74.18 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen/pspnet_r101-d8_4x4_512x512_80k_vaihingen_20211231_230806-8eba0a09.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen/pspnet_r101-d8_4x4_512x512_80k_vaihingen_20211231_230806.log.json) |
### iSAID
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| PSPNet | R-18-D8 | 896x896 | 80000 | 4.52 | 26.91 | 60.22 | 61.25 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid/pspnet_r18-d8_4x4_896x896_80k_isaid_20220110_180526-e84c0b6a.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid/pspnet_r18-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) |
| PSPNet | R-50-D8 | 896x896 | 80000 | 16.58 | 8.88 | 65.36 | 66.48 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid/pspnet_r50-d8_4x4_896x896_80k_isaid_20220110_180629-1f21dc32.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid/pspnet_r50-d8_4x4_896x896_80k_isaid_20220110_180629.log.json) |
Note:
- `FP16` means Mixed Precision (FP16) is adopted in training.
- `896x896` is the Crop Size of iSAID dataset, which is followed by the implementation of [PointFlow: Flowing Semantics Through Points for Aerial Image Segmentation](https://arxiv.org/pdf/2103.06564.pdf)

View File

@ -13,6 +13,7 @@ Collections:
- LoveDA
- Potsdam
- Vaihingen
- iSAID
Paper:
URL: https://arxiv.org/abs/1612.01105
Title: Pyramid Scene Parsing Network
@ -942,3 +943,47 @@ Models:
mIoU(ms+flip): 74.18
Config: configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen/pspnet_r101-d8_4x4_512x512_80k_vaihingen_20211231_230806-8eba0a09.pth
- Name: pspnet_r18-d8_4x4_896x896_80k_isaid
In Collection: pspnet
Metadata:
backbone: R-18-D8
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 37.16
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 4.52
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 60.22
mIoU(ms+flip): 61.25
Config: configs/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid/pspnet_r18-d8_4x4_896x896_80k_isaid_20220110_180526-e84c0b6a.pth
- Name: pspnet_r50-d8_4x4_896x896_80k_isaid
In Collection: pspnet
Metadata:
backbone: R-50-D8
crop size: (896,896)
lr schd: 80000
inference time (ms/im):
- value: 112.61
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (896,896)
Training Memory (GB): 16.58
Results:
- Task: Semantic Segmentation
Dataset: iSAID
Metrics:
mIoU: 65.36
mIoU(ms+flip): 66.48
Config: configs/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid/pspnet_r50-d8_4x4_896x896_80k_isaid_20220110_180629-1f21dc32.pth

View File

@ -0,0 +1,9 @@
_base_ = './pspnet_r50-d8_4x4_896x896_80k_isaid.py'
model = dict(
pretrained='open-mmlab://resnet18_v1c',
backbone=dict(depth=18),
decode_head=dict(
in_channels=512,
channels=128,
),
auxiliary_head=dict(in_channels=256, channels=64))

View File

@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/pspnet_r50-d8.py', '../_base_/datasets/isaid.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
model = dict(
decode_head=dict(num_classes=16), auxiliary_head=dict(num_classes=16))

View File

@ -123,6 +123,21 @@ mmsegmentation
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
│ ├── vaihingen
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ ├── val
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
│ ├── iSAID
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ ├── val
│ │ │ ├── test
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
```
### Cityscapes
@ -325,3 +340,38 @@ python tools/convert_datasets/vaihingen.py /path/to/vaihingen
```
In our default setting (`clip_size` =512, `stride_size`=256), it will generate 344 images for training and 398 images for validation.
### iSAID
The data images could be download from [DOTA-v1.0](https://captain-whu.github.io/DOTA/dataset.html) (train/val/test)
The data annotations could be download from [iSAID](https://captain-whu.github.io/iSAID/dataset.html) (train/val)
The dataset is a Large-scale Dataset for Instance Segmentation (also have segmantic segmentation) in Aerial Images.
You may need to follow the following structure for dataset preparation after downloading iSAID dataset.
```
│ ├── iSAID
│ │ ├── train
│ │ │ ├── images
│ │ │ │ ├── part1.zip
│ │ │ │ ├── part2.zip
│ │ │ │ ├── part3.zip
│ │ │ ├── Semantic_masks
│ │ │ │ ├── images.zip
│ │ ├── val
│ │ │ ├── images
│ │ │ │ ├── part1.zip
│ │ │ ├── Semantic_masks
│ │ │ │ ├── images.zip
│ │ ├── test
│ │ │ ├── images
│ │ │ │ ├── part1.zip
│ │ │ │ ├── part2.zip
```
```shell
python tools/convert_datasets/isaid.py /path/to/iSAID
```
In our default setting (`clip_size` =512, `stride_size`=256), it will generate 33978 images for training and 11644 images for validation.

View File

@ -104,6 +104,21 @@ mmsegmentation
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
│ ├── vaihingen
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ ├── val
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
│ ├── iSAID
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ ├── val
│ │ │ ├── test
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ ├── val
```
### Cityscapes
@ -265,4 +280,39 @@ python tools/convert_datasets/potsdam.py /path/to/potsdam
python tools/convert_datasets/vaihingen.py /path/to/vaihingen
```
使用我们默认的配置 (`clip_size` =512, `stride_size`=256) 将生成 344 张图片的训练集和 398 张图片的验证集。
使用我们默认的配置 (`clip_size`=512, `stride_size`=256) 将生成 344 张图片的训练集和 398 张图片的验证集。
### iSAID
iSAID 数据集(训练集/验证集/测试集)的图像可以从 [DOTA-v1.0](https://captain-whu.github.io/DOTA/dataset.html) 下载.
iSAID 数据集(训练集/验证集)的注释可以从 [iSAID](https://captain-whu.github.io/iSAID/dataset.html) 下载.
该数据集是一个大规模的实例分割(也可以用于语义分割)的遥感数据集.
下载后,在数据集转换前,您需要将数据集文件夹调整成如下格式.
```
│ ├── iSAID
│ │ ├── train
│ │ │ ├── images
│ │ │ │ ├── part1.zip
│ │ │ │ ├── part2.zip
│ │ │ │ ├── part3.zip
│ │ │ ├── Semantic_masks
│ │ │ │ ├── images.zip
│ │ ├── val
│ │ │ ├── images
│ │ │ │ ├── part1.zip
│ │ │ ├── Semantic_masks
│ │ │ │ ├── images.zip
│ │ ├── test
│ │ │ ├── images
│ │ │ │ ├── part1.zip
│ │ │ │ ├── part2.zip
```
```shell
python tools/convert_datasets/isaid.py /path/to/iSAID
```
使用我们默认的配置 (`patch_width`=896, `patch_height`=896, `overlap_area`=384) 将生成 33978 张图片的训练集和 11644 张图片的验证集。

View File

@ -111,6 +111,16 @@ def vaihingen_classes():
]
def isaid_classes():
"""iSAID class names for external use."""
return [
'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court',
'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle',
'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout',
'Soccer_ball_field', 'plane', 'Harbor'
]
def cityscapes_palette():
"""Cityscapes palette for external use."""
return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
@ -236,6 +246,15 @@ def vaihingen_palette():
[255, 255, 0], [255, 0, 0]]
def isaid_palette():
"""iSAID palette for external use."""
return [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
[0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127,
127], [0, 0, 127],
[0, 0, 191], [0, 0, 255], [0, 191, 127], [0, 127, 191],
[0, 127, 255], [0, 100, 155]]
dataset_aliases = {
'cityscapes': ['cityscapes'],
'ade': ['ade', 'ade20k'],
@ -247,7 +266,8 @@ dataset_aliases = {
'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff',
'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k',
'coco_stuff164k'
]
],
'isaid': ['isaid', 'iSAID']
}

View File

@ -10,6 +10,7 @@ from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
RepeatDataset)
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
@ -25,5 +26,5 @@ __all__ = [
'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
'ISPRSDataset', 'PotsdamDataset'
'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset'
]

82
mmseg/datasets/isaid.py Normal file
View File

@ -0,0 +1,82 @@
import os.path as osp
import mmcv
from mmcv.utils import print_log
from ..utils import get_root_logger
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class iSAIDDataset(CustomDataset):
""" iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images
In segmentation map annotation for iSAID dataset, which is included
in 16 categories. ``reduce_zero_label`` is fixed to False. The
``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
'_manual1.png'.
"""
CLASSES = ('background', 'ship', 'store_tank', 'baseball_diamond',
'tennis_court', 'basketball_court', 'Ground_Track_Field',
'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
'Harbor')
PALETTE = [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
[0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
[0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127],
[0, 127, 191], [0, 127, 255], [0, 100, 155]]
def __init__(self, **kwargs):
super(iSAIDDataset, self).__init__(
img_suffix='.png',
seg_map_suffix='.png',
ignore_index=255,
**kwargs)
assert osp.exists(self.img_dir)
def load_annotations(self,
img_dir,
img_suffix,
ann_dir,
seg_map_suffix=None,
split=None):
"""Load annotation from directory.
Args:
img_dir (str): Path to image directory
img_suffix (str): Suffix of images.
ann_dir (str|None): Path to annotation directory.
seg_map_suffix (str|None): Suffix of segmentation maps.
split (str|None): Split txt file. If split is specified, only file
with suffix in the splits will be loaded. Otherwise, all images
in img_dir/ann_dir will be loaded. Default: None
Returns:
list[dict]: All image info of dataset.
"""
img_infos = []
if split is not None:
with open(split) as f:
for line in f:
name = line.strip()
img_info = dict(filename=name + img_suffix)
if ann_dir is not None:
ann_name = name + '_instance_color_RGB'
seg_map = ann_name + seg_map_suffix
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
else:
for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
img_info = dict(filename=img)
if ann_dir is not None:
seg_img = img
seg_map = seg_img.replace(
img_suffix, '_instance_color_RGB' + seg_map_suffix)
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
return img_infos

View File

@ -16,4 +16,4 @@ default_section = THIRDPARTY
skip = *.po,*.ts,*.ipynb
count =
quiet-level = 3
ignore-words-list = formating,sur,hist
ignore-words-list = formating,sur,hist,dota

View File

@ -0,0 +1 @@
P0000_0_896_1536_2432

View File

@ -0,0 +1 @@
P0000_0_896_1024_1920

View File

@ -16,7 +16,7 @@ from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
COCOStuffDataset, ConcatDataset, CustomDataset,
ISPRSDataset, LoveDADataset, MultiImageMixDataset,
PascalVOCDataset, PotsdamDataset, RepeatDataset,
build_dataset)
build_dataset, iSAIDDataset)
def test_classes():
@ -25,10 +25,11 @@ def test_classes():
'pascal_voc')
assert list(
ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k')
assert list(COCOStuffDataset.CLASSES) == get_classes('cocostuff')
assert list(LoveDADataset.CLASSES) == get_classes('loveda')
assert list(PotsdamDataset.CLASSES) == get_classes('potsdam')
assert list(ISPRSDataset.CLASSES) == get_classes('vaihingen')
assert list(COCOStuffDataset.CLASSES) == get_classes('cocostuff')
assert list(iSAIDDataset.CLASSES) == get_classes('isaid')
with pytest.raises(ValueError):
get_classes('unsupported')
@ -73,6 +74,7 @@ def test_palette():
assert LoveDADataset.PALETTE == get_palette('loveda')
assert PotsdamDataset.PALETTE == get_palette('potsdam')
assert COCOStuffDataset.PALETTE == get_palette('cocostuff')
assert iSAIDDataset.PALETTE == get_palette('isaid')
with pytest.raises(ValueError):
get_palette('unsupported')
@ -730,6 +732,27 @@ def test_vaihingen():
assert len(test_dataset) == 1
def test_isaid():
test_dataset = iSAIDDataset(
pipeline=[],
img_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
ann_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_isaid_dataset/ann_dir'))
assert len(test_dataset) == 2
isaid_info = test_dataset.load_annotations(
img_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
img_suffix='.png',
ann_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_isaid_dataset/ann_dir'),
seg_map_suffix='.png',
split=osp.join(
osp.dirname(__file__),
'../data/pseudo_isaid_dataset/splits/train.txt'))
assert len(isaid_info) == 1
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
MagicMock(side_effect=lambda idx: idx))

View File

@ -0,0 +1,244 @@
import argparse
import glob
import os
import os.path as osp
import shutil
import tempfile
import zipfile
import mmcv
import numpy as np
from PIL import Image
iSAID_palette = \
{
0: (0, 0, 0),
1: (0, 0, 63),
2: (0, 63, 63),
3: (0, 63, 0),
4: (0, 63, 127),
5: (0, 63, 191),
6: (0, 63, 255),
7: (0, 127, 63),
8: (0, 127, 127),
9: (0, 0, 127),
10: (0, 0, 191),
11: (0, 0, 255),
12: (0, 191, 127),
13: (0, 127, 191),
14: (0, 127, 255),
15: (0, 100, 155)
}
iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()}
def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette):
"""RGB-color encoding to grayscale labels."""
arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)
for c, i in palette.items():
m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
arr_2d[m] = i
return arr_2d
def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap):
img = np.asarray(Image.open(src_path).convert('RGB'))
img_H, img_W, _ = img.shape
if img_H < patch_H and img_W > patch_W:
img = mmcv.impad(img, shape=(patch_H, img_W), pad_val=0)
img_H, img_W, _ = img.shape
elif img_H > patch_H and img_W < patch_W:
img = mmcv.impad(img, shape=(img_H, patch_W), pad_val=0)
img_H, img_W, _ = img.shape
elif img_H < patch_H and img_W < patch_W:
img = mmcv.impad(img, shape=(patch_H, patch_W), pad_val=0)
img_H, img_W, _ = img.shape
for x in range(0, img_W, patch_W - overlap):
for y in range(0, img_H, patch_H - overlap):
x_str = x
x_end = x + patch_W
if x_end > img_W:
diff_x = x_end - img_W
x_str -= diff_x
x_end = img_W
y_str = y
y_end = y + patch_H
if y_end > img_H:
diff_y = y_end - img_H
y_str -= diff_y
y_end = img_H
img_patch = img[y_str:y_end, x_str:x_end, :]
img_patch = Image.fromarray(img_patch.astype(np.uint8))
image = osp.splitext(
src_path.split('/')[-1])[0] + '_' + str(y_str) + '_' + str(
y_end) + '_' + str(x_str) + '_' + str(x_end) + '.png'
# print(image)
save_path_image = osp.join(out_dir, 'img_dir', mode, str(image))
img_patch.save(save_path_image)
def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap):
label = mmcv.imread(src_path, channel_order='rgb')
label = iSAID_convert_from_color(label)
img_H, img_W = label.shape
if img_H < patch_H and img_W > patch_W:
label = mmcv.impad(label, shape=(patch_H, img_W), pad_val=255)
img_H = patch_H
elif img_H > patch_H and img_W < patch_W:
label = mmcv.impad(label, shape=(img_H, patch_W), pad_val=255)
img_W = patch_W
elif img_H < patch_H and img_W < patch_W:
label = mmcv.impad(label, shape=(patch_H, patch_W), pad_val=255)
img_H = patch_H
img_W = patch_W
for x in range(0, img_W, patch_W - overlap):
for y in range(0, img_H, patch_H - overlap):
x_str = x
x_end = x + patch_W
if x_end > img_W:
diff_x = x_end - img_W
x_str -= diff_x
x_end = img_W
y_str = y
y_end = y + patch_H
if y_end > img_H:
diff_y = y_end - img_H
y_str -= diff_y
y_end = img_H
lab_patch = label[y_str:y_end, x_str:x_end]
lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode='P')
image = osp.splitext(src_path.split('/')[-1])[0].split(
'_')[0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str(
x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png'
lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image)))
def parse_args():
parser = argparse.ArgumentParser(
description='Convert iSAID dataset to mmsegmentation format')
parser.add_argument('dataset_path', help='iSAID folder path')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--patch_width',
default=896,
type=int,
help='Width of the cropped image patch')
parser.add_argument(
'--patch_height',
default=896,
type=int,
help='Height of the cropped image patch')
parser.add_argument(
'--overlap_area', default=384, type=int, help='Overlap area')
args = parser.parse_args()
return args
def main():
args = parse_args()
dataset_path = args.dataset_path
# image patch width and height
patch_H, patch_W = args.patch_width, args.patch_height
overlap = args.overlap_area # overlap area
if args.out_dir is None:
out_dir = osp.join('data', 'iSAID')
else:
out_dir = args.out_dir
print('Making directories...')
mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test'))
assert os.path.exists(os.path.join(dataset_path, 'train')), \
'train is not in {}'.format(dataset_path)
assert os.path.exists(os.path.join(dataset_path, 'val')), \
'val is not in {}'.format(dataset_path)
assert os.path.exists(os.path.join(dataset_path, 'test')), \
'test is not in {}'.format(dataset_path)
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
for dataset_mode in ['train', 'val', 'test']:
# for dataset_mode in [ 'test']:
print('Extracting {}ing.zip...'.format(dataset_mode))
img_zipp_list = glob.glob(
os.path.join(dataset_path, dataset_mode, 'images', '*.zip'))
print('Find the data', img_zipp_list)
for img_zipp in img_zipp_list:
zip_file = zipfile.ZipFile(img_zipp)
zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img'))
src_path_list = glob.glob(
os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png'))
src_prog_bar = mmcv.ProgressBar(len(src_path_list))
for i, img_path in enumerate(src_path_list):
if dataset_mode != 'test':
slide_crop_image(img_path, out_dir, dataset_mode, patch_H,
patch_W, overlap)
else:
shutil.move(img_path,
os.path.join(out_dir, 'img_dir', dataset_mode))
src_prog_bar.update()
if dataset_mode != 'test':
label_zipp_list = glob.glob(
os.path.join(dataset_path, dataset_mode, 'Semantic_masks',
'*.zip'))
for label_zipp in label_zipp_list:
zip_file = zipfile.ZipFile(label_zipp)
zip_file.extractall(
os.path.join(tmp_dir, dataset_mode, 'lab'))
lab_path_list = glob.glob(
os.path.join(tmp_dir, dataset_mode, 'lab', 'images',
'*.png'))
lab_prog_bar = mmcv.ProgressBar(len(lab_path_list))
for i, lab_path in enumerate(lab_path_list):
slide_crop_label(lab_path, out_dir, dataset_mode, patch_H,
patch_W, overlap)
lab_prog_bar.update()
print('Removing the temporary files...')
print('Done!')
if __name__ == '__main__':
main()