From e5b8d72e01a99393504c0ba91e3f32e740f21ae9 Mon Sep 17 00:00:00 2001
From: Tianlong Ai <50650583+AI-Tianlong@users.noreply.github.com>
Date: Mon, 5 Jun 2023 11:25:50 +0800
Subject: [PATCH] [Dataset] Support GID dataset on project (#3038)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

Support GID dataset on project

---------

Co-authored-by: 谢昕辰
---
 .../configs/_base_/datasets/gid.py            |  67 +++++++
 ...labv3plus_r101-d8_4xb2-240k_gid-256x256.py |  15 ++
 projects/gid_dataset/mmseg/datasets/gid.py    |  55 ++++++
 .../tools/dataset_converters/gid.py           | 181 ++++++++++++++++++
 .../gid_select15imgFromAll.py                 |  75 ++++++++
 .../user_guides/2_dataset_prepare.md          |  53 +++++
 6 files changed, 446 insertions(+)
 create mode 100644 projects/gid_dataset/configs/_base_/datasets/gid.py
 create mode 100644 projects/gid_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_gid-256x256.py
 create mode 100644 projects/gid_dataset/mmseg/datasets/gid.py
 create mode 100644 projects/gid_dataset/tools/dataset_converters/gid.py
 create mode 100644 projects/gid_dataset/tools/dataset_converters/gid_select15imgFromAll.py
 create mode 100644 projects/gid_dataset/user_guides/2_dataset_prepare.md

diff --git a/projects/gid_dataset/configs/_base_/datasets/gid.py b/projects/gid_dataset/configs/_base_/datasets/gid.py
new file mode 100644
index 000000000..f7218105f
--- /dev/null
+++ b/projects/gid_dataset/configs/_base_/datasets/gid.py
@@ -0,0 +1,67 @@
+# dataset settings
+dataset_type = 'GID_Dataset'  # the registered dataset class name
+data_root = 'data/gid/'  # root directory of the dataset
+crop_size = (256, 256)  # crop size of the images
+train_pipeline = [
+    dict(type='LoadImageFromFile'),  # load the image from file
+    dict(type='LoadAnnotations'),  # load the annotation from file
+    dict(
+        type='RandomResize',  # random resize
+        scale=(512, 512),  # resize scale
+        ratio_range=(0.5, 2.0),  # resize ratio range
+        keep_ratio=True),  # whether to keep the aspect ratio
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),  # random crop
+    dict(type='RandomFlip', prob=0.5),  # random flip
+    dict(type='PhotoMetricDistortion'),  # photometric data augmentation
+    dict(type='PackSegInputs')  # pack the data
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),  # load the image from file
+    dict(type='Resize', scale=(256, 256), keep_ratio=True),  # resize
+    # add loading annotation after ``Resize`` because ground truth
+    # does not need to do resize data transform
+    dict(type='LoadAnnotations'),  # load the annotation from file
+    dict(type='PackSegInputs')  # pack the data
+]
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]  # scale ratios for multi-scale prediction
+tta_pipeline = [  # multi-scale test-time augmentation
+    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
+    dict(
+        type='TestTimeAug',
+        transforms=[
+            [
+                dict(type='Resize', scale_factor=r, keep_ratio=True)
+                for r in img_ratios
+            ],
+            [
+                dict(type='RandomFlip', prob=0., direction='horizontal'),
+                dict(type='RandomFlip', prob=1., direction='horizontal')
+            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
+        ])
+]
+train_dataloader = dict(  # train dataloader
+    batch_size=2,  # batch size during training
+    num_workers=4,  # number of data loading workers
+    persistent_workers=True,  # whether to keep the workers persistent
+    sampler=dict(type='InfiniteSampler', shuffle=True),  # infinite sampler
+    dataset=dict(
+        type=dataset_type,  # dataset class name
+        data_root=data_root,  # dataset root directory
+        data_prefix=dict(
+            img_path='img_dir/train',
+            seg_map_path='ann_dir/train'),  # paths of training images and annotations
+        pipeline=train_pipeline))
+val_dataloader = dict(
+    batch_size=1,  # batch size during validation
+    num_workers=4,  # number of data loading workers
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
+        pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
diff --git a/projects/gid_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_gid-256x256.py b/projects/gid_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_gid-256x256.py
new file mode 100644
index 000000000..70cb6005f
--- /dev/null
+++ b/projects/gid_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_gid-256x256.py
@@ -0,0 +1,15 @@
+_base_ = [
+    '../../../configs/_base_/models/deeplabv3plus_r50-d8.py',
+    './_base_/datasets/gid.py', '../../../configs/_base_/default_runtime.py',
+    '../../../configs/_base_/schedules/schedule_240k.py'
+]
+custom_imports = dict(imports=['projects.gid_dataset.mmseg.datasets.gid'])
+
+crop_size = (256, 256)
+data_preprocessor = dict(size=crop_size)
+model = dict(
+    data_preprocessor=data_preprocessor,
+    pretrained='open-mmlab://resnet101_v1c',
+    backbone=dict(depth=101),
+    decode_head=dict(num_classes=6),
+    auxiliary_head=dict(num_classes=6))
diff --git a/projects/gid_dataset/mmseg/datasets/gid.py b/projects/gid_dataset/mmseg/datasets/gid.py
new file mode 100644
index 000000000..a9e8c510b
--- /dev/null
+++ b/projects/gid_dataset/mmseg/datasets/gid.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmseg.datasets.basesegdataset import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+# register the dataset class
+@DATASETS.register_module()
+class GID_Dataset(BaseSegDataset):
+    """Gaofen Image Dataset (GID).
+
+    Dataset paper link:
+    https://www.sciencedirect.com/science/article/pii/S0034425719303414
+    https://x-ytong.github.io/project/GID.html
+
+    GID 6 classes: others, built-up, farmland, forest, meadow, water
+
+    In this example, 15 images are selected from the GID dataset and
+    clipped to generate the training and validation sets.
+    The selected images are listed as follows:
+
+    GF2_PMS1__L1A0000647767-MSS1
+    GF2_PMS1__L1A0001064454-MSS1
+    GF2_PMS1__L1A0001348919-MSS1
+    GF2_PMS1__L1A0001680851-MSS1
+    GF2_PMS1__L1A0001680853-MSS1
+    GF2_PMS1__L1A0001680857-MSS1
+    GF2_PMS1__L1A0001757429-MSS1
+    GF2_PMS2__L1A0000607681-MSS2
+    GF2_PMS2__L1A0000635115-MSS2
+    GF2_PMS2__L1A0000658637-MSS2
+    GF2_PMS2__L1A0001206072-MSS2
+    GF2_PMS2__L1A0001471436-MSS2
+    GF2_PMS2__L1A0001642620-MSS2
+    GF2_PMS2__L1A0001787089-MSS2
+    GF2_PMS2__L1A0001838560-MSS2
+
+    The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is
+    fixed to '.png' for GID, since the conversion script saves the
+    clipped patches as PNG files.
+    """
+    METAINFO = dict(
+        classes=('Others', 'Built-up', 'Farmland', 'Forest', 'Meadow',
+                 'Water'),
+        palette=[[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 255, 255],
+                 [255, 255, 0], [0, 0, 255]])
+
+    def __init__(self,
+                 img_suffix='.png',
+                 seg_map_suffix='.png',
+                 reduce_zero_label=None,
+                 **kwargs) -> None:
+        super().__init__(
+            img_suffix=img_suffix,
+            seg_map_suffix=seg_map_suffix,
+            reduce_zero_label=reduce_zero_label,
+            **kwargs)
diff --git a/projects/gid_dataset/tools/dataset_converters/gid.py b/projects/gid_dataset/tools/dataset_converters/gid.py
new file mode 100644
index 000000000..d95654aa1
--- /dev/null
+++ b/projects/gid_dataset/tools/dataset_converters/gid.py
@@ -0,0 +1,181 @@
+import argparse
+import glob
+import math
+import os
+import os.path as osp
+
+import mmcv
+import numpy as np
+from mmengine.utils import ProgressBar, mkdir_or_exist
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Convert GID dataset to mmsegmentation format')
+    parser.add_argument('dataset_img_path', help='GID images folder path')
+    parser.add_argument('dataset_label_path', help='GID labels folder path')
+    parser.add_argument('--tmp_dir', help='path of the temporary directory')
+    parser.add_argument(
+        '-o', '--out_dir', help='output path', default='data/gid')
+    parser.add_argument(
+        '--clip_size',
+        type=int,
+        help='clipped size of image after preparation',
+        default=256)
+    parser.add_argument(
+        '--stride_size',
+        type=int,
+        help='stride of clipping original images',
+        default=256)
+    args = parser.parse_args()
+    return args
+
+
+GID_COLORMAP = dict(
+    Background=(0, 0, 0),  # 0 - background - black
+    Building=(255, 0, 0),  # 1 - built-up - red
+    Farmland=(0, 255, 0),  # 2 - farmland - green
+    Forest=(0, 255, 255),  # 3 - forest - cyan
+    Meadow=(255, 255, 0),  # 4 - meadow - yellow
+    Water=(0, 0, 255)  # 5 - water - blue
+)
+palette = list(GID_COLORMAP.values())
+classes = list(GID_COLORMAP.keys())
+
+
+# build a lookup table that maps every RGB color to its class index
+def colormap2label(palette):
+    colormap2label_list = np.zeros(256**3, dtype=np.longlong)
+    for i, colormap in enumerate(palette):
+        colormap2label_list[(colormap[0] * 256 + colormap[1]) * 256 +
+                            colormap[2]] = i
+    return colormap2label_list
+
+
+# use the lookup table to convert an RGB label image into a class-index mask
+def label_indices(RGB_label, colormap2label_list):
+    RGB_label = RGB_label.astype('int32')
+    idx = (RGB_label[:, :, 0] * 256 +
+           RGB_label[:, :, 1]) * 256 + RGB_label[:, :, 2]
+    return colormap2label_list[idx]
+
+
+def RGB2mask(RGB_label, colormap2label_list):
+    mask_label = label_indices(RGB_label, colormap2label_list)
+    return mask_label
+
+
+colormap2label_list = colormap2label(palette)
+
+
+def clip_big_image(image_path, clip_save_dir, args, to_label=False):
+    """Clip one large GID image into small patches.
+
+    The original GID images are very large, so they are pre-processed by
+    clipping. Given a fixed clip size and stride size, the clipping
+    positions along the width and height are determined. For example,
+    given one 6800 x 7200 original image, a clip size of 256 and a stride
+    size of 256, it would generate 29 x 27 = 783 images whose size is
+    256 x 256.
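+
+    Patches whose right or bottom edge would extend beyond the image
+    border are shifted back inside the image, so every saved patch is
+    exactly clip_size x clip_size.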
+    """
+
+    image = mmcv.imread(image_path, channel_order='rgb')
+    # image = mmcv.bgr2gray(image)
+
+    h, w, c = image.shape
+    clip_size = args.clip_size
+    stride_size = args.stride_size
+
+    num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
+        (h - clip_size) /
+        stride_size) * stride_size + clip_size >= h else math.ceil(
+            (h - clip_size) / stride_size) + 1
+    num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
+        (w - clip_size) /
+        stride_size) * stride_size + clip_size >= w else math.ceil(
+            (w - clip_size) / stride_size) + 1
+
+    x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
+    xmin = x * clip_size
+    ymin = y * clip_size
+
+    xmin = xmin.ravel()
+    ymin = ymin.ravel()
+    xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
+                           np.zeros_like(xmin))
+    ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
+                           np.zeros_like(ymin))
+    boxes = np.stack([
+        xmin + xmin_offset, ymin + ymin_offset,
+        np.minimum(xmin + clip_size, w),
+        np.minimum(ymin + clip_size, h)
+    ],
+                     axis=1)
+
+    if to_label:
+        image = RGB2mask(image, colormap2label_list)
+
+    for count, box in enumerate(boxes):
+        start_x, start_y, end_x, end_y = box
+        clipped_image = image[start_y:end_y,
+                              start_x:end_x] if to_label else image[
+                                  start_y:end_y, start_x:end_x, :]
+        img_name = osp.basename(image_path).replace('.tif', '')
+        img_name = img_name.replace('_label', '')
+        # every third patch is written to the val split, the rest to train
+        if count % 3 == 0:
+            mmcv.imwrite(
+                clipped_image.astype(np.uint8),
+                osp.join(
+                    clip_save_dir.replace('train', 'val'),
+                    f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
+        else:
+            mmcv.imwrite(
+                clipped_image.astype(np.uint8),
+                osp.join(
+                    clip_save_dir,
+                    f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
+        count += 1
+
+
+def main():
+    args = parse_args()
+    """
+    According to this paper: https://ieeexplore.ieee.org/document/9343296/
+    select 15 images contained in GID, which cover all six categories,
+    to generate the training and validation sets.
+
+    """
+
+    if args.out_dir is None:
+        out_dir = osp.join('data', 'gid')
+    else:
+        out_dir = args.out_dir
+
+    print('Making directories...')
+    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
+    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
+    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
+    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
+
+    src_path_list = glob.glob(os.path.join(args.dataset_img_path, '*.tif'))
+    print(f'Find {len(src_path_list)} pictures')
+
+    prog_bar = ProgressBar(len(src_path_list))
+
+    dst_img_dir = osp.join(out_dir, 'img_dir', 'train')
+    dst_label_dir = osp.join(out_dir, 'ann_dir', 'train')
+
+    for i, img_path in enumerate(src_path_list):
+        label_path = osp.join(
+            args.dataset_label_path,
+            osp.basename(img_path.replace('.tif', '_label.tif')))
+
+        clip_big_image(img_path, dst_img_dir, args, to_label=False)
+        clip_big_image(label_path, dst_label_dir, args, to_label=True)
+        prog_bar.update()
+
+    print('Done!')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/projects/gid_dataset/tools/dataset_converters/gid_select15imgFromAll.py b/projects/gid_dataset/tools/dataset_converters/gid_select15imgFromAll.py
new file mode 100644
index 000000000..d3eeff269
--- /dev/null
+++ b/projects/gid_dataset/tools/dataset_converters/gid_select15imgFromAll.py
@@ -0,0 +1,75 @@
+import argparse
+import os
+import shutil
+
+# select 15 images from the GID dataset
+
+img_list = [
+    'GF2_PMS1__L1A0000647767-MSS1.tif', 'GF2_PMS1__L1A0001064454-MSS1.tif',
+    'GF2_PMS1__L1A0001348919-MSS1.tif', 'GF2_PMS1__L1A0001680851-MSS1.tif',
+    'GF2_PMS1__L1A0001680853-MSS1.tif', 'GF2_PMS1__L1A0001680857-MSS1.tif',
+    'GF2_PMS1__L1A0001757429-MSS1.tif', 'GF2_PMS2__L1A0000607681-MSS2.tif',
+    'GF2_PMS2__L1A0000635115-MSS2.tif', 'GF2_PMS2__L1A0000658637-MSS2.tif',
+    'GF2_PMS2__L1A0001206072-MSS2.tif', 'GF2_PMS2__L1A0001471436-MSS2.tif',
+    'GF2_PMS2__L1A0001642620-MSS2.tif', 'GF2_PMS2__L1A0001787089-MSS2.tif',
+    'GF2_PMS2__L1A0001838560-MSS2.tif'
+]
+
+labels_list = [
+    'GF2_PMS1__L1A0000647767-MSS1_label.tif',
+    'GF2_PMS1__L1A0001064454-MSS1_label.tif',
+    'GF2_PMS1__L1A0001348919-MSS1_label.tif',
+    'GF2_PMS1__L1A0001680851-MSS1_label.tif',
+    'GF2_PMS1__L1A0001680853-MSS1_label.tif',
+    'GF2_PMS1__L1A0001680857-MSS1_label.tif',
+    'GF2_PMS1__L1A0001757429-MSS1_label.tif',
+    'GF2_PMS2__L1A0000607681-MSS2_label.tif',
+    'GF2_PMS2__L1A0000635115-MSS2_label.tif',
+    'GF2_PMS2__L1A0000658637-MSS2_label.tif',
+    'GF2_PMS2__L1A0001206072-MSS2_label.tif',
+    'GF2_PMS2__L1A0001471436-MSS2_label.tif',
+    'GF2_PMS2__L1A0001642620-MSS2_label.tif',
+    'GF2_PMS2__L1A0001787089-MSS2_label.tif',
+    'GF2_PMS2__L1A0001838560-MSS2_label.tif'
+]
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Select 15 images from the 150 images of the GID dataset')
+    parser.add_argument('dataset_img_dir', help='150 GID images folder path')
+    parser.add_argument('dataset_label_dir', help='150 GID labels folder path')
+
+    parser.add_argument('dest_img_dir', help='15 GID images folder path')
+    parser.add_argument('dest_label_dir', help='15 GID labels folder path')
+
+    args = parser.parse_args()
+
+    return args
+
+
+def main():
+    """Select 15 images from the GID dataset, according to the paper:
+    https://ieeexplore.ieee.org/document/9343296/"""
+    args = parse_args()
+
+    img_path = args.dataset_img_dir
+    label_path = args.dataset_label_dir
+
+    dest_img_dir = args.dest_img_dir
+    dest_label_dir = args.dest_label_dir
+
+    # copy images in 'img_list' to 'dest_img_dir'
+    print('Copying images in img_list to dest_img_dir...')
+    for img in img_list:
+        shutil.copy(os.path.join(img_path, img), dest_img_dir)
+    print('Done!')
+
+    print('Copying labels in labels_list to dest_label_dir...')
+    for label in labels_list:
+        shutil.copy(os.path.join(label_path, label), dest_label_dir)
+    print('Done!')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/projects/gid_dataset/user_guides/2_dataset_prepare.md b/projects/gid_dataset/user_guides/2_dataset_prepare.md
new file mode 100644
index 000000000..63bd4d46f
--- /dev/null
+++ b/projects/gid_dataset/user_guides/2_dataset_prepare.md
@@ -0,0 +1,53 @@
+## Gaofen Image Dataset (GID)
+
+- The GID dataset can be downloaded [here](https://x-ytong.github.io/project/GID.html).
+- The GID dataset contains 150 large 6800x7200 images with RGB labels.
+- Following this [paper](https://ieeexplore.ieee.org/document/9343296/), 15 images that cover all six categories are selected here to generate the training and validation sets. The names of the selected images are as follows:
+
+```None
+  GF2_PMS1__L1A0000647767-MSS1
+  GF2_PMS1__L1A0001064454-MSS1
+  GF2_PMS1__L1A0001348919-MSS1
+  GF2_PMS1__L1A0001680851-MSS1
+  GF2_PMS1__L1A0001680853-MSS1
+  GF2_PMS1__L1A0001680857-MSS1
+  GF2_PMS1__L1A0001757429-MSS1
+  GF2_PMS2__L1A0000607681-MSS2
+  GF2_PMS2__L1A0000635115-MSS2
+  GF2_PMS2__L1A0000658637-MSS2
+  GF2_PMS2__L1A0001206072-MSS2
+  GF2_PMS2__L1A0001471436-MSS2
+  GF2_PMS2__L1A0001642620-MSS2
+  GF2_PMS2__L1A0001787089-MSS2
+  GF2_PMS2__L1A0001838560-MSS2
+```
+
+A script is also provided to conveniently select these 15 images:
+
+```
+python projects/gid_dataset/tools/dataset_converters/gid_select15imgFromAll.py {path to the 150 images} {path to the 150 labels} {path to the 15 images} {path to the 15 labels}
+```
+
+After the 15 images are selected, run the following command to clip the images and convert the labels. Replace the arguments with the paths where the 15 images and labels are stored.
+
+```
+python projects/gid_dataset/tools/dataset_converters/gid.py {path to the 15 images} {path to the 15 labels}
+```
+
+After clipping, the GID data structure is as follows:
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│   ├── gid
+│   │   ├── ann_dir
+│   │   │   ├── train
+│   │   │   ├── val
+│   │   ├── img_dir
+│   │   │   ├── train
+│   │   │   ├── val
+
+```
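+
+After the dataset is prepared, training can be launched with the config provided in this project. The command below is a minimal sketch that assumes the standard mmsegmentation `tools/train.py` entry point and that it is run from the repository root; adjust it to your environment.
+
+```
+python tools/train.py projects/gid_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_gid-256x256.py
+```
+
+The config sets `custom_imports` so that the `GID_Dataset` class in `projects/gid_dataset/mmseg/datasets/gid.py` is registered automatically.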