diff --git a/configs/_base_/datasets/potsdam.py b/configs/_base_/datasets/potsdam.py new file mode 100644 index 000000000..f74c4a56c --- /dev/null +++ b/configs/_base_/datasets/potsdam.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'PotsdamDataset' +data_root = 'data/potsdam' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(512, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='img_dir/train', + ann_dir='ann_dir/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='img_dir/val', + ann_dir='ann_dir/val', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='img_dir/val', + ann_dir='ann_dir/val', + pipeline=test_pipeline)) diff --git a/configs/deeplabv3plus/README.md b/configs/deeplabv3plus/README.md index 16b4a96c9..b6063662d 100644 --- a/configs/deeplabv3plus/README.md +++ b/configs/deeplabv3plus/README.md @@ -104,6 +104,14 @@ 
Spatial pyramid pooling module or encode-decoder structure are used in deep neur | DeepLabV3+ | R-50-D8 | 512x512 | 80000 | 7.37 | 6.00 | 50.99 | 50.65 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_loveda/deeplabv3plus_r50-d8_512x512_80k_loveda_20211105_080442-f0720392.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_loveda/deeplabv3plus_r50-d8_512x512_80k_loveda_20211105_080442.log.json) | | DeepLabV3+ | R-101-D8 | 512x512 | 80000 | 10.84 | 4.33 | 51.47 | 51.32 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_loveda/deeplabv3plus_r101-d8_512x512_80k_loveda_20211105_110759-4c1f297e.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_loveda/deeplabv3plus_r101-d8_512x512_80k_loveda_20211105_110759.log.json) | +### Potsdam + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | +| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| DeepLabV3+ | 
R-18-D8 | 512x512 | 80000 | 1.91 | 81.68 | 77.09 | 78.44 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam/deeplabv3plus_r18-d8_512x512_80k_potsdam_20211219_020601-75fd5bc3.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam/deeplabv3plus_r18-d8_512x512_80k_potsdam_20211219_020601.log.json) | +| DeepLabV3+ | R-50-D8 | 512x512 | 80000 | 7.36 | 26.44 | 78.33 | 79.27 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam/deeplabv3plus_r50-d8_512x512_80k_potsdam_20211219_031508-7e7a2b24.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam/deeplabv3plus_r50-d8_512x512_80k_potsdam_20211219_031508.log.json) | +| DeepLabV3+ | R-101-D8 | 512x512 | 80000 | 10.83 | 17.56 | 78.7 | 79.47 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam/deeplabv3plus_r101-d8_512x512_80k_potsdam_20211219_031508-8b112708.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam/deeplabv3plus_r101-d8_512x512_80k_potsdam_20211219_031508.log.json) | + Note: - `FP16` means Mixed Precision (FP16) is adopted in training. 
diff --git a/configs/deeplabv3plus/deeplabv3plus.yml b/configs/deeplabv3plus/deeplabv3plus.yml index 43e7a3c14..997e2607d 100644 --- a/configs/deeplabv3plus/deeplabv3plus.yml +++ b/configs/deeplabv3plus/deeplabv3plus.yml @@ -8,6 +8,7 @@ Collections: - Pascal Context - Pascal Context 59 - LoveDA + - Potsdam Paper: URL: https://arxiv.org/abs/1802.02611 Title: Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation @@ -669,3 +670,69 @@ Models: mIoU(ms+flip): 51.32 Config: configs/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_loveda.py Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_loveda/deeplabv3plus_r101-d8_512x512_80k_loveda_20211105_110759-4c1f297e.pth +- Name: deeplabv3plus_r18-d8_512x512_80k_potsdam + In Collection: deeplabv3plus + Metadata: + backbone: R-18-D8 + crop size: (512,512) + lr schd: 80000 + inference time (ms/im): + - value: 12.24 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 1.91 + Results: + - Task: Semantic Segmentation + Dataset: Potsdam + Metrics: + mIoU: 77.09 + mIoU(ms+flip): 78.44 + Config: configs/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam/deeplabv3plus_r18-d8_512x512_80k_potsdam_20211219_020601-75fd5bc3.pth +- Name: deeplabv3plus_r50-d8_512x512_80k_potsdam + In Collection: deeplabv3plus + Metadata: + backbone: R-50-D8 + crop size: (512,512) + lr schd: 80000 + inference time (ms/im): + - value: 37.82 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 7.36 + Results: + - Task: Semantic Segmentation + Dataset: Potsdam + Metrics: + mIoU: 78.33 + mIoU(ms+flip): 79.27 + Config: configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam.py + Weights: 
https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam/deeplabv3plus_r50-d8_512x512_80k_potsdam_20211219_031508-7e7a2b24.pth +- Name: deeplabv3plus_r101-d8_512x512_80k_potsdam + In Collection: deeplabv3plus + Metadata: + backbone: R-101-D8 + crop size: (512,512) + lr schd: 80000 + inference time (ms/im): + - value: 56.95 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 10.83 + Results: + - Task: Semantic Segmentation + Dataset: Potsdam + Metrics: + mIoU: 78.7 + mIoU(ms+flip): 79.47 + Config: configs/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam/deeplabv3plus_r101-d8_512x512_80k_potsdam_20211219_031508-8b112708.pth diff --git a/configs/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam.py b/configs/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam.py new file mode 100644 index 000000000..d89491440 --- /dev/null +++ b/configs/deeplabv3plus/deeplabv3plus_r101-d8_512x512_80k_potsdam.py @@ -0,0 +1,2 @@ +_base_ = './deeplabv3plus_r50-d8_512x512_80k_potsdam.py' +model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) diff --git a/configs/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam.py b/configs/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam.py new file mode 100644 index 000000000..ffb20df72 --- /dev/null +++ b/configs/deeplabv3plus/deeplabv3plus_r18-d8_512x512_80k_potsdam.py @@ -0,0 +1,11 @@ +_base_ = './deeplabv3plus_r50-d8_512x512_80k_potsdam.py' +model = dict( + pretrained='open-mmlab://resnet18_v1c', + backbone=dict(depth=18), + decode_head=dict( + c1_in_channels=64, + c1_channels=12, + in_channels=512, + channels=128, + ), + auxiliary_head=dict(in_channels=256, channels=64)) diff --git a/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam.py 
b/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam.py new file mode 100644 index 000000000..d5ae03fd5 --- /dev/null +++ b/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_potsdam.py @@ -0,0 +1,7 @@ +_base_ = [ + '../_base_/models/deeplabv3plus_r50-d8.py', + '../_base_/datasets/potsdam.py', '../_base_/default_runtime.py', + '../_base_/schedules/schedule_80k.py' +] +model = dict( + decode_head=dict(num_classes=6), auxiliary_head=dict(num_classes=6)) diff --git a/configs/hrnet/README.md b/configs/hrnet/README.md index 345d175df..9b7d98495 100644 --- a/configs/hrnet/README.md +++ b/configs/hrnet/README.md @@ -92,3 +92,11 @@ High-resolution representations are essential for position-sensitive vision prob | FCN | HRNetV2p-W18-Small | 512x512 | 80000 | 1.59 | 24.87 | 49.28 | 49.42 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18s_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_512x512_80k_loveda/fcn_hr18s_512x512_80k_loveda_20211210_203228-60a86a7a.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_512x512_80k_loveda/fcn_hr18s_512x512_80k_loveda_20211210_203228.log.json) | | FCN | HRNetV2p-W18 | 512x512 | 80000 | 2.76 | 12.92 | 50.81 | 50.95 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_512x512_80k_loveda/fcn_hr18_512x512_80k_loveda_20211210_203952-93d9c3b3.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_512x512_80k_loveda/fcn_hr18_512x512_80k_loveda_20211210_203952.log.json) | | FCN | HRNetV2p-W48 | 512x512 | 80000 | 6.20 | 9.61 | 51.42 | 51.64 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr48_512x512_80k_loveda.py) | 
[model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_512x512_80k_loveda/fcn_hr48_512x512_80k_loveda_20211211_044756-67072f55.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_512x512_80k_loveda/fcn_hr48_512x512_80k_loveda_20211211_044756.log.json) | + +### Potsdam + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | +| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| FCN | HRNetV2p-W18-Small | 512x512 | 80000 | 1.58 | 36.00 | 77.64 | 78.8 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18s_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_512x512_80k_potsdam/fcn_hr18s_512x512_80k_potsdam_20211218_205517-ba32af63.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_512x512_80k_potsdam/fcn_hr18s_512x512_80k_potsdam_20211218_205517.log.json) | +| FCN | HRNetV2p-W18 | 512x512 | 80000 | 2.76 | 19.25 | 78.26 | 79.24 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_512x512_80k_potsdam/fcn_hr18_512x512_80k_potsdam_20211218_205517-5d0387ad.pth) | 
[log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_512x512_80k_potsdam/fcn_hr18_512x512_80k_potsdam_20211218_205517.log.json) | +| FCN | HRNetV2p-W48 | 512x512 | 80000 | 6.20 | 16.42 | 78.39 | 79.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr48_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_512x512_80k_potsdam/fcn_hr48_512x512_80k_potsdam_20211219_020601-97434c78.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_512x512_80k_potsdam/fcn_hr48_512x512_80k_potsdam_20211219_020601.log.json) | diff --git a/configs/hrnet/fcn_hr18_512x512_80k_potsdam.py b/configs/hrnet/fcn_hr18_512x512_80k_potsdam.py new file mode 100644 index 000000000..043017f91 --- /dev/null +++ b/configs/hrnet/fcn_hr18_512x512_80k_potsdam.py @@ -0,0 +1,5 @@ +_base_ = [ + '../_base_/models/fcn_hr18.py', '../_base_/datasets/potsdam.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' +] +model = dict(decode_head=dict(num_classes=6)) diff --git a/configs/hrnet/fcn_hr18s_512x512_80k_potsdam.py b/configs/hrnet/fcn_hr18s_512x512_80k_potsdam.py new file mode 100644 index 000000000..05551271a --- /dev/null +++ b/configs/hrnet/fcn_hr18s_512x512_80k_potsdam.py @@ -0,0 +1,9 @@ +_base_ = './fcn_hr18_512x512_80k_potsdam.py' +model = dict( + pretrained='open-mmlab://msra/hrnetv2_w18_small', + backbone=dict( + extra=dict( + stage1=dict(num_blocks=(2, )), + stage2=dict(num_blocks=(2, 2)), + stage3=dict(num_modules=3, num_blocks=(2, 2, 2)), + stage4=dict(num_modules=2, num_blocks=(2, 2, 2, 2))))) diff --git a/configs/hrnet/fcn_hr48_512x512_80k_potsdam.py b/configs/hrnet/fcn_hr48_512x512_80k_potsdam.py new file mode 100644 index 000000000..608fee387 --- /dev/null +++ b/configs/hrnet/fcn_hr48_512x512_80k_potsdam.py @@ -0,0 +1,10 @@ +_base_ = './fcn_hr18_512x512_80k_potsdam.py' +model = dict( + pretrained='open-mmlab://msra/hrnetv2_w48', + 
backbone=dict( + extra=dict( + stage2=dict(num_channels=(48, 96)), + stage3=dict(num_channels=(48, 96, 192)), + stage4=dict(num_channels=(48, 96, 192, 384)))), + decode_head=dict( + in_channels=[48, 96, 192, 384], channels=sum([48, 96, 192, 384]))) diff --git a/configs/hrnet/hrnet.yml b/configs/hrnet/hrnet.yml index 5cef15ef2..ae6c2665b 100644 --- a/configs/hrnet/hrnet.yml +++ b/configs/hrnet/hrnet.yml @@ -8,6 +8,7 @@ Collections: - Pascal Context - Pascal Context 59 - LoveDA + - Potsdam Paper: URL: https://arxiv.org/abs/1908.07919 Title: Deep High-Resolution Representation Learning for Human Pose Estimation @@ -514,3 +515,69 @@ Models: mIoU(ms+flip): 51.64 Config: configs/hrnet/fcn_hr48_512x512_80k_loveda.py Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_512x512_80k_loveda/fcn_hr48_512x512_80k_loveda_20211211_044756-67072f55.pth +- Name: fcn_hr18s_512x512_80k_potsdam + In Collection: hrnet + Metadata: + backbone: HRNetV2p-W18-Small + crop size: (512,512) + lr schd: 80000 + inference time (ms/im): + - value: 27.78 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 1.58 + Results: + - Task: Semantic Segmentation + Dataset: Potsdam + Metrics: + mIoU: 77.64 + mIoU(ms+flip): 78.8 + Config: configs/hrnet/fcn_hr18s_512x512_80k_potsdam.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_512x512_80k_potsdam/fcn_hr18s_512x512_80k_potsdam_20211218_205517-ba32af63.pth +- Name: fcn_hr18_512x512_80k_potsdam + In Collection: hrnet + Metadata: + backbone: HRNetV2p-W18 + crop size: (512,512) + lr schd: 80000 + inference time (ms/im): + - value: 51.95 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 2.76 + Results: + - Task: Semantic Segmentation + Dataset: Potsdam + Metrics: + mIoU: 78.26 + mIoU(ms+flip): 79.24 + Config: configs/hrnet/fcn_hr18_512x512_80k_potsdam.py + Weights: 
https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_512x512_80k_potsdam/fcn_hr18_512x512_80k_potsdam_20211218_205517-5d0387ad.pth +- Name: fcn_hr48_512x512_80k_potsdam + In Collection: hrnet + Metadata: + backbone: HRNetV2p-W48 + crop size: (512,512) + lr schd: 80000 + inference time (ms/im): + - value: 60.9 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 6.2 + Results: + - Task: Semantic Segmentation + Dataset: Potsdam + Metrics: + mIoU: 78.39 + mIoU(ms+flip): 79.34 + Config: configs/hrnet/fcn_hr48_512x512_80k_potsdam.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_512x512_80k_potsdam/fcn_hr48_512x512_80k_potsdam_20211219_020601-97434c78.pth diff --git a/configs/pspnet/README.md b/configs/pspnet/README.md index 15e8ad3c0..367abff32 100644 --- a/configs/pspnet/README.md +++ b/configs/pspnet/README.md @@ -133,6 +133,14 @@ We support evaluation results on these two datasets using models above trained o | PSPNet | R-50-D8 | 512x512 | 80000 | 6.14 | 6.60 | 50.46 | 50.19 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r50-d8_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x512_80k_loveda/pspnet_r50-d8_512x512_80k_loveda_20211104_155728-88610f9f.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x512_80k_loveda/pspnet_r50-d8_512x512_80k_loveda_20211104_155728.log.json) | | PSPNet | R-101-D8 | 512x512 | 80000 | 9.61 | 4.58 | 51.86 | 51.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r101-d8_512x512_80k_loveda.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_512x512_80k_loveda/pspnet_r101-d8_512x512_80k_loveda_20211104_153212-1c06c6a8.pth) | 
[log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_512x512_80k_loveda/pspnet_r101-d8_512x512_80k_loveda_20211104_153212.log.json) | +### Potsdam + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | +| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| PSPNet | R-18-D8 | 512x512 | 80000 | 1.50 | 85.12 | 77.09 | 78.30 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam/pspnet_r18-d8_4x4_512x512_80k_potsdam_20211220_125612-7cd046e1.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam/pspnet_r18-d8_4x4_512x512_80k_potsdam_20211220_125612.log.json) | +| PSPNet | R-50-D8 | 512x512 | 80000 | 6.14 | 30.21 | 78.12 | 78.98 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam/pspnet_r50-d8_4x4_512x512_80k_potsdam_20211219_043541-2dd5fe67.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam/pspnet_r50-d8_4x4_512x512_80k_potsdam_20211219_043541.log.json) | +| PSPNet | 
R-101-D8 | 512x512 | 80000 | 9.61 | 19.40 | 78.62 | 79.47 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam/pspnet_r101-d8_4x4_512x512_80k_potsdam_20211220_125612-aed036c4.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam/pspnet_r101-d8_4x4_512x512_80k_potsdam_20211220_125612.log.json) | + Note: - `FP16` means Mixed Precision (FP16) is adopted in training. diff --git a/configs/pspnet/pspnet.yml b/configs/pspnet/pspnet.yml index b5de88c14..ed10c65de 100644 --- a/configs/pspnet/pspnet.yml +++ b/configs/pspnet/pspnet.yml @@ -11,6 +11,7 @@ Collections: - COCO-Stuff 10k - COCO-Stuff 164k - LoveDA + - Potsdam Paper: URL: https://arxiv.org/abs/1612.01105 Title: Pyramid Scene Parsing Network @@ -808,3 +809,69 @@ Models: mIoU(ms+flip): 51.34 Config: configs/pspnet/pspnet_r101-d8_512x512_80k_loveda.py Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_512x512_80k_loveda/pspnet_r101-d8_512x512_80k_loveda_20211104_153212-1c06c6a8.pth +- Name: pspnet_r18-d8_4x4_512x512_80k_potsdam + In Collection: pspnet + Metadata: + backbone: R-18-D8 + crop size: (512,512) + lr schd: 80000 + inference time (ms/im): + - value: 11.75 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 1.5 + Results: + - Task: Semantic Segmentation + Dataset: Potsdam + Metrics: + mIoU: 77.09 + mIoU(ms+flip): 78.3 + Config: configs/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam/pspnet_r18-d8_4x4_512x512_80k_potsdam_20211220_125612-7cd046e1.pth +- Name: pspnet_r50-d8_4x4_512x512_80k_potsdam + In Collection: pspnet + Metadata: + backbone: R-50-D8 + crop size: (512,512) + lr schd: 
80000 + inference time (ms/im): + - value: 33.1 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 6.14 + Results: + - Task: Semantic Segmentation + Dataset: Potsdam + Metrics: + mIoU: 78.12 + mIoU(ms+flip): 78.98 + Config: configs/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam/pspnet_r50-d8_4x4_512x512_80k_potsdam_20211219_043541-2dd5fe67.pth +- Name: pspnet_r101-d8_4x4_512x512_80k_potsdam + In Collection: pspnet + Metadata: + backbone: R-101-D8 + crop size: (512,512) + lr schd: 80000 + inference time (ms/im): + - value: 51.55 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 9.61 + Results: + - Task: Semantic Segmentation + Dataset: Potsdam + Metrics: + mIoU: 78.62 + mIoU(ms+flip): 79.47 + Config: configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam/pspnet_r101-d8_4x4_512x512_80k_potsdam_20211220_125612-aed036c4.pth diff --git a/configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam.py b/configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam.py new file mode 100644 index 000000000..98343dd76 --- /dev/null +++ b/configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_potsdam.py @@ -0,0 +1,2 @@ +_base_ = './pspnet_r50-d8_4x4_512x512_80k_potsdam.py' +model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) diff --git a/configs/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam.py b/configs/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam.py new file mode 100644 index 000000000..be9dc7254 --- /dev/null +++ b/configs/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam.py @@ -0,0 +1,9 @@ +_base_ = './pspnet_r50-d8_4x4_512x512_80k_potsdam.py' +model = dict( + pretrained='open-mmlab://resnet18_v1c', + backbone=dict(depth=18), + 
decode_head=dict( + in_channels=512, + channels=128, + ), + auxiliary_head=dict(in_channels=256, channels=64)) diff --git a/configs/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam.py b/configs/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam.py new file mode 100644 index 000000000..f78faff0a --- /dev/null +++ b/configs/pspnet/pspnet_r50-d8_4x4_512x512_80k_potsdam.py @@ -0,0 +1,6 @@ +_base_ = [ + '../_base_/models/pspnet_r50-d8.py', '../_base_/datasets/potsdam.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' +] +model = dict( + decode_head=dict(num_classes=6), auxiliary_head=dict(num_classes=6)) diff --git a/docs/en/dataset_prepare.md b/docs/en/dataset_prepare.md index 8e89f156c..0a13e2eac 100644 --- a/docs/en/dataset_prepare.md +++ b/docs/en/dataset_prepare.md @@ -116,6 +116,13 @@ mmsegmentation │ │ ├── ann_dir │ │ │ ├── train │ │ │ ├── val +│ ├── potsdam +│ │ ├── img_dir +│ │ │ ├── train +│ │ │ ├── val +│ │ ├── ann_dir +│ │ │ ├── train +│ │ │ ├── val ``` ### Cityscapes @@ -286,3 +293,19 @@ python tools/convert_datasets/loveda.py /path/to/loveDA Using trained model to predict test set of LoveDA and submit it to server can be found [here](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/inference.md). More details about LoveDA can be found [here](https://github.com/Junjue-Wang/LoveDA). + +### ISPRS Potsdam + +The [Potsdam](https://www2.isprs.org/commissions/comm2/wg4/benchmark/2d-sem-label-potsdam/) +dataset is for urban semantic segmentation used in the 2D Semantic Labeling Contest - Potsdam. + +The dataset can be requested at the challenge [homepage](https://www2.isprs.org/commissions/comm2/wg4/benchmark/data-request-form/). +The '2_Ortho_RGB.zip' and '5_Labels_all_noBoundary.zip' are required. + +For Potsdam dataset, please run the following command to re-organize the dataset. 
+ +```shell +python tools/convert_datasets/potsdam.py /path/to/potsdam +``` + +In our default setting, it will generate 3456 images for training and 2016 images for validation. diff --git a/docs/zh_cn/dataset_prepare.md b/docs/zh_cn/dataset_prepare.md index bffc3d1da..5c171f89a 100644 --- a/docs/zh_cn/dataset_prepare.md +++ b/docs/zh_cn/dataset_prepare.md @@ -97,6 +97,13 @@ mmsegmentation │ │ ├── ann_dir │ │ │ ├── train │ │ │ ├── val +│ ├── potsdam +│ │ ├── img_dir +│ │ │ ├── train +│ │ │ ├── val +│ │ ├── ann_dir +│ │ │ ├── train +│ │ │ ├── val ``` ### Cityscapes @@ -228,3 +235,18 @@ python tools/convert_datasets/loveda.py /path/to/loveDA 请参照 [这里](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/inference.md) 来使用训练好的模型去预测 LoveDA 测试集并且提交到官网。 关于 LoveDA 的更多细节可以在[这里](https://github.com/Junjue-Wang/LoveDA) 找到。 + +### ISPRS Potsdam + +[Potsdam](https://www2.isprs.org/commissions/comm2/wg4/benchmark/2d-sem-label-potsdam/) +数据集是一个有着2D 语义分割内容标注的城市遥感数据集。 +数据集可以从挑战[主页](https://www2.isprs.org/commissions/comm2/wg4/benchmark/data-request-form/) 获得。 +需要其中的 '2_Ortho_RGB.zip' 和 '5_Labels_all_noBoundary.zip'。 + +对于 Potsdam 数据集,请运行以下命令重新组织数据集 + +```shell +python tools/convert_datasets/potsdam.py /path/to/potsdam +``` + +使用我们默认的配置, 将生成 3456 张图片的训练集和 2016 张图片的验证集。 diff --git a/mmseg/core/evaluation/class_names.py b/mmseg/core/evaluation/class_names.py index 4527fbaf1..d0c51fd04 100644 --- a/mmseg/core/evaluation/class_names.py +++ b/mmseg/core/evaluation/class_names.py @@ -52,6 +52,22 @@ def voc_classes(): ] +def loveda_classes(): + """LoveDA class names for external use.""" + return [ + 'background', 'building', 'road', 'water', 'barren', 'forest', + 'agricultural' + ] + + +def potsdam_classes(): + """Potsdam class names for external use.""" + return [ + 'impervious_surface', 'building', 'low_vegetation', 'tree', 'car', + 'clutter' + ] + + def cityscapes_palette(): """Cityscapes palette for external use.""" return [[128, 64, 128], [244, 35, 232], [70, 70, 70], 
[102, 102, 156], @@ -112,10 +128,24 @@ def voc_palette(): [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] +def loveda_palette(): + """LoveDA palette for external use.""" + return [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], + [159, 129, 183], [0, 255, 0], [255, 195, 128]] + + +def potsdam_palette(): + """Potsdam palette for external use.""" + return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]] + + dataset_aliases = { 'cityscapes': ['cityscapes'], 'ade': ['ade', 'ade20k'], - 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'] + 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'], + 'loveda': ['loveda'], + 'potsdam': ['potsdam'] } diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py index a9f80a920..3e4d83165 100644 --- a/mmseg/datasets/__init__.py +++ b/mmseg/datasets/__init__.py @@ -13,6 +13,7 @@ from .hrf import HRFDataset from .loveda import LoveDADataset from .night_driving import NightDrivingDataset from .pascal_context import PascalContextDataset, PascalContextDataset59 +from .potsdam import PotsdamDataset from .stare import STAREDataset from .voc import PascalVOCDataset @@ -22,5 +23,6 @@ __all__ = [ 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset', - 'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset' + 'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset', + 'PotsdamDataset' ] diff --git a/mmseg/datasets/potsdam.py b/mmseg/datasets/potsdam.py new file mode 100644 index 000000000..2986b8faa --- /dev/null +++ b/mmseg/datasets/potsdam.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class PotsdamDataset(CustomDataset): + """ISPRS Potsdam dataset. 
+ + In segmentation map annotation for Potsdam dataset, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree', + 'car', 'clutter') + + PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]] + + def __init__(self, **kwargs): + super(PotsdamDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) diff --git a/tests/data/pseudo_potsdam_dataset/ann_dir/2_10_0_0_512_512.png b/tests/data/pseudo_potsdam_dataset/ann_dir/2_10_0_0_512_512.png new file mode 100644 index 000000000..6f2227851 Binary files /dev/null and b/tests/data/pseudo_potsdam_dataset/ann_dir/2_10_0_0_512_512.png differ diff --git a/tests/data/pseudo_potsdam_dataset/img_dir/2_10_0_0_512_512.png b/tests/data/pseudo_potsdam_dataset/img_dir/2_10_0_0_512_512.png new file mode 100644 index 000000000..7821a1862 Binary files /dev/null and b/tests/data/pseudo_potsdam_dataset/img_dir/2_10_0_0_512_512.png differ diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py index 58c7275ab..4c4b81c62 100644 --- a/tests/test_data/test_dataset.py +++ b/tests/test_data/test_dataset.py @@ -15,7 +15,7 @@ from mmseg.core.evaluation import get_classes, get_palette from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset, ConcatDataset, CustomDataset, LoveDADataset, MultiImageMixDataset, PascalVOCDataset, - RepeatDataset, build_dataset) + PotsdamDataset, RepeatDataset, build_dataset) def test_classes(): @@ -24,6 +24,8 @@ def test_classes(): 'pascal_voc') assert list( ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k') + assert list(LoveDADataset.CLASSES) == get_classes('loveda') + assert list(PotsdamDataset.CLASSES) == get_classes('potsdam') with pytest.raises(ValueError): get_classes('unsupported') @@ -65,6 +67,8 @@ def 
test_palette(): assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette( 'pascal_voc') assert ADE20KDataset.PALETTE == get_palette('ade') == get_palette('ade20k') + assert LoveDADataset.PALETTE == get_palette('loveda') + assert PotsdamDataset.PALETTE == get_palette('potsdam') with pytest.raises(ValueError): get_palette('unsupported') @@ -709,6 +713,16 @@ def test_loveda(): shutil.rmtree('.format_loveda') +def test_potsdam(): + test_dataset = PotsdamDataset( + pipeline=[], + img_dir=osp.join( + osp.dirname(__file__), '../data/pseudo_potsdam_dataset/img_dir'), + ann_dir=osp.join( + osp.dirname(__file__), '../data/pseudo_potsdam_dataset/ann_dir')) + assert len(test_dataset) == 1 + + @patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) @patch('mmseg.datasets.CustomDataset.__getitem__', MagicMock(side_effect=lambda idx: idx)) diff --git a/tools/convert_datasets/potsdam.py b/tools/convert_datasets/potsdam.py new file mode 100644 index 000000000..95a97f6ee --- /dev/null +++ b/tools/convert_datasets/potsdam.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
import argparse
import glob
import math
import os
import os.path as osp
import tempfile
import zipfile

import mmcv
import numpy as np


def parse_args():
    """Parse the command-line arguments of the Potsdam conversion script.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes
            ``dataset_path``, ``tmp_dir``, ``out_dir``, ``clip_size`` and
            ``stride_size``.
    """
    parser = argparse.ArgumentParser(
        description='Convert potsdam dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='potsdam folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--clip_size',
        type=int,
        help='clipped size of image after preparation',
        default=512)
    parser.add_argument(
        '--stride_size',
        type=int,
        help='stride of clipping original images',
        default=256)
    return parser.parse_args()


def _window_starts(length, clip_size, stride_size):
    """Return the start offsets of the sliding windows along one axis.

    Windows advance by ``stride_size``. The last window is shifted back so
    that it ends exactly at ``length`` instead of running past the border.
    For example, length 5120 with clip size 512 and stride 256 gives the 19
    starts 0, 256, ..., 4608. An axis shorter than ``clip_size`` gives the
    single start 0.

    Args:
        length (int): Size of the image along this axis.
        clip_size (int): Size of one window.
        stride_size (int): Distance between two consecutive windows.

    Returns:
        list[int]: Start offsets in ascending order, without duplicates.
    """
    last_start = max(length - clip_size, 0)
    num_steps = math.ceil(last_start / stride_size)
    starts = np.arange(num_steps + 1) * stride_size
    # Only the final start can exceed ``last_start``; clamp it to the border.
    return np.minimum(starts, last_start).tolist()


def _color_to_label(image):
    """Convert a colour-coded annotation image into a label-id mask.

    Args:
        image (np.ndarray): Annotation of shape (h, w, 3) in the channel
            order returned by ``mmcv.imread`` (BGR by default).

    Returns:
        np.ndarray: Mask of shape (h, w). Each pixel holds the index of its
            colour in ``color_map``; colours that are not listed keep 0.
    """
    h, w, c = image.shape
    color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
                          [255, 255, 0], [0, 255, 0], [0, 255, 255],
                          [0, 0, 255]])
    # Collapse the three channels into one scalar code per pixel so that a
    # colour comparison becomes a single integer comparison. The weights
    # give every entry of ``color_map`` a distinct code.
    weights = np.array([2, 3, 4]).reshape(3, 1)
    pixel_codes = np.matmul(image.reshape(-1, c), weights)
    labels = np.zeros_like(pixel_codes)
    for label_id, color_code in enumerate(np.matmul(color_map, weights)):
        labels[pixel_codes == color_code] = label_id
    return labels.reshape(h, w)


def clip_big_image(image_path, clip_save_dir, args, to_label=False):
    """Cut one large Potsdam image into fixed-size tiles and save them.

    The original Potsdam images are very large, so they are pre-processed
    into tiles. A window of ``args.clip_size`` slides over the image with a
    step of ``args.stride_size``; windows that would cross the right or
    bottom border are shifted back inside the image. For example, a
    5120 x 5120 image with clip size 512 and stride 256 is cut into
    19 x 19 = 361 tiles of 512 x 512.

    NOTE: with a stride smaller than the clip size the tiles overlap. Pass
    ``--stride_size`` equal to ``--clip_size`` for a non-overlapping tiling.

    Args:
        image_path (str): Path of the source image. The tile indices are
            taken from the third and fourth ``_``-separated fields of the
            file name.
        clip_save_dir (str): Directory the tiles are written to.
        args (argparse.Namespace): Must provide ``clip_size`` and
            ``stride_size``.
        to_label (bool): If True, the image is a colour-coded annotation and
            is converted to a label-id mask before clipping.

    Raises:
        ValueError: If ``clip_size`` or ``stride_size`` is not positive.
    """
    clip_size = args.clip_size
    stride_size = args.stride_size
    if clip_size <= 0 or stride_size <= 0:
        raise ValueError(
            'clip_size and stride_size must be positive, but got '
            f'{clip_size} and {stride_size}')

    image = mmcv.imread(image_path)
    h, w = image.shape[:2]
    if to_label:
        image = _color_to_label(image)

    # The file name prefix is the same for every tile of this image.
    idx_i, idx_j = osp.basename(image_path).split('_')[2:4]
    x_starts = _window_starts(w, clip_size, stride_size)
    for start_y in _window_starts(h, clip_size, stride_size):
        end_y = min(start_y + clip_size, h)
        for start_x in x_starts:
            end_x = min(start_x + clip_size, w)
            # Slicing the two leading axes works for masks and colour images.
            clipped_image = image[start_y:end_y, start_x:end_x]
            mmcv.imwrite(
                clipped_image.astype(np.uint8),
                osp.join(
                    clip_save_dir,
                    f'{idx_i}_{idx_j}_{start_x}_{start_y}_{end_x}_{end_y}.png'
                ))


def _find_tifs(root):
    """Return the ``.tif`` files of an extracted archive.

    Files directly under ``root`` are preferred. If there are none, the
    directories below ``root`` are searched instead, because some archives
    wrap their content in a folder.
    """
    tif_paths = glob.glob(osp.join(root, '*.tif'))
    if not tif_paths:
        tif_paths = glob.glob(osp.join(root, '**', '*.tif'), recursive=True)
    return sorted(tif_paths)


def main():
    """Convert the zipped Potsdam dataset into the mmsegmentation layout.

    Every ``*.zip`` in ``dataset_path`` is extracted, each ``.tif`` inside is
    cut into tiles and the tiles are written to
    ``<out_dir>/{img_dir,ann_dir}/{train,val}``.
    """
    args = parse_args()
    splits = {
        'train': [
            '2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11',
            '4_12', '5_10', '5_11', '5_12', '6_10', '6_11', '6_12', '6_7',
            '6_8', '6_9', '7_10', '7_11', '7_12', '7_7', '7_8', '7_9'
        ],
        'val': [
            '5_15', '6_15', '6_13', '3_13', '4_14', '6_14', '5_14', '2_13',
            '4_15', '2_14', '5_13', '4_13', '3_14', '7_13'
        ]
    }

    dataset_path = args.dataset_path
    if args.out_dir is None:
        out_dir = osp.join('data', 'potsdam')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    for kind_dir in ('img_dir', 'ann_dir'):
        for split_name in ('train', 'val'):
            mmcv.mkdir_or_exist(osp.join(out_dir, kind_dir, split_name))

    zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
    print('Find the data', zipp_list)

    for zipp in zipp_list:
        # Each archive gets its own scratch directory, so files extracted
        # from a previous archive can never be picked up a second time.
        with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
            with zipfile.ZipFile(zipp) as zip_file:
                zip_file.extractall(tmp_dir)
            src_path_list = _find_tifs(tmp_dir)

            prog_bar = mmcv.ProgressBar(len(src_path_list))
            for src_path in src_path_list:
                file_name = osp.basename(src_path)
                idx_i, idx_j = file_name.split('_')[2:4]
                # Every tile that is not listed for training goes to val.
                data_type = 'train' if f'{idx_i}_{idx_j}' in splits[
                    'train'] else 'val'
                # Decide on the file name only: the temporary directory path
                # itself may contain the word "label".
                is_label = 'label' in file_name
                dst_dir = osp.join(
                    out_dir, 'ann_dir' if is_label else 'img_dir', data_type)
                clip_big_image(src_path, dst_dir, args, to_label=is_label)
                prog_bar.update()

            print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()