[Dataset] Support GID dataset on project (#3038)

## Motivation
Support GID dataset on project

---------

Co-authored-by: 谢昕辰 <xiexinch@outlook.com>
This commit is contained in:
Tianlong Ai 2023-06-05 11:25:50 +08:00 committed by GitHub
parent b6ec4ab1e6
commit e5b8d72e01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 446 additions and 0 deletions

View File

@ -0,0 +1,67 @@
# dataset settings
dataset_type = 'GID_Dataset' # 注册的类名
data_root = 'data/gid/' # 数据集根目录
crop_size = (256, 256) # 图像裁剪大小
train_pipeline = [
dict(type='LoadImageFromFile'), # 从文件中加载图像
dict(type='LoadAnnotations'), # 从文件中加载标注
dict(
type='RandomResize', # 随机缩放
scale=(512, 512), # 缩放尺寸
ratio_range=(0.5, 2.0), # 缩放比例范围
keep_ratio=True), # 是否保持长宽比
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), # 随机裁剪
dict(type='RandomFlip', prob=0.5), # 随机翻转
dict(type='PhotoMetricDistortion'), # 图像增强
dict(type='PackSegInputs') # 打包数据
]
test_pipeline = [
dict(type='LoadImageFromFile'), # 从文件中加载图像
dict(type='Resize', scale=(256, 256), keep_ratio=True), # 缩放
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations'), # 从文件中加载标注
dict(type='PackSegInputs') # 打包数据
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # 多尺度预测缩放比例
tta_pipeline = [ # 多尺度测试
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( # 训练数据加载器
batch_size=2, # 训练时的数据批量大小
num_workers=4, # 数据加载线程数
persistent_workers=True, # 是否持久化线程
sampler=dict(type='InfiniteSampler', shuffle=True), # 无限采样器
dataset=dict(
type=dataset_type, # 数据集类名
data_root=data_root, # 数据集根目录
data_prefix=dict(
img_path='img_dir/train',
seg_map_path='ann_dir/train'), # 训练集图像和标注路径
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1, # 验证时的数据批量大小
num_workers=4, # 数据加载线程数
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator

View File

@ -0,0 +1,15 @@
_base_ = [
'../../../configs/_base_/models/deeplabv3plus_r50-d8.py',
'./_base_/datasets/gid.py', '../../../configs/_base_/default_runtime.py',
'../../../configs/_base_/schedules/schedule_240k.py'
]
custom_imports = dict(imports=['projects.gid_dataset.mmseg.datasets.gid'])
crop_size = (256, 256)
data_preprocessor = dict(size=crop_size)
model = dict(
data_preprocessor=data_preprocessor,
pretrained='open-mmlab://resnet101_v1c',
backbone=dict(depth=101),
decode_head=dict(num_classes=6),
auxiliary_head=dict(num_classes=6))

View File

@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.datasets.basesegdataset import BaseSegDataset
from mmseg.registry import DATASETS
# 注册数据集类
@DATASETS.register_module()
class GID_Dataset(BaseSegDataset):
"""Gaofen Image Dataset (GID)
Dataset paper link:
https://www.sciencedirect.com/science/article/pii/S0034425719303414
https://x-ytong.github.io/project/GID.html
GID 6 classes: others, built-up, farmland, forest, meadow, water
In this example, select 15 images from GID dataset as training set,
and select 5 images as validation set.
The selected images are listed as follows:
GF2_PMS1__L1A0000647767-MSS1
GF2_PMS1__L1A0001064454-MSS1
GF2_PMS1__L1A0001348919-MSS1
GF2_PMS1__L1A0001680851-MSS1
GF2_PMS1__L1A0001680853-MSS1
GF2_PMS1__L1A0001680857-MSS1
GF2_PMS1__L1A0001757429-MSS1
GF2_PMS2__L1A0000607681-MSS2
GF2_PMS2__L1A0000635115-MSS2
GF2_PMS2__L1A0000658637-MSS2
GF2_PMS2__L1A0001206072-MSS2
GF2_PMS2__L1A0001471436-MSS2
GF2_PMS2__L1A0001642620-MSS2
GF2_PMS2__L1A0001787089-MSS2
GF2_PMS2__L1A0001838560-MSS2
The ``img_suffix`` is fixed to '.tif' and ``seg_map_suffix`` is
fixed to '.tif' for GID.
"""
METAINFO = dict(
classes=('Others', 'Built-up', 'Farmland', 'Forest', 'Meadow',
'Water'),
palette=[[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 255, 255],
[255, 255, 0], [0, 0, 255]])
def __init__(self,
img_suffix='.png',
seg_map_suffix='.png',
reduce_zero_label=None,
**kwargs) -> None:
super().__init__(
img_suffix=img_suffix,
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)

View File

@ -0,0 +1,181 @@
import argparse
import glob
import math
import os
import os.path as osp
import mmcv
import numpy as np
from mmengine.utils import ProgressBar, mkdir_or_exist
def parse_args():
parser = argparse.ArgumentParser(
description='Convert GID dataset to mmsegmentation format')
parser.add_argument('dataset_img_path', help='GID images folder path')
parser.add_argument('dataset_label_path', help='GID labels folder path')
parser.add_argument('--tmp_dir', help='path of the temporary directory')
parser.add_argument(
'-o', '--out_dir', help='output path', default='data/gid')
parser.add_argument(
'--clip_size',
type=int,
help='clipped size of image after preparation',
default=256)
parser.add_argument(
'--stride_size',
type=int,
help='stride of clipping original images',
default=256)
args = parser.parse_args()
return args
GID_COLORMAP = dict(
Background=(0, 0, 0), # 0-背景-黑色
Building=(255, 0, 0), # 1-建筑-红色
Farmland=(0, 255, 0), # 2-农田-绿色
Forest=(0, 0, 255), # 3-森林-蓝色
Meadow=(255, 255, 0), # 4-草地-黄色
Water=(0, 0, 255) # 5-水-蓝色
)
palette = list(GID_COLORMAP.values())
classes = list(GID_COLORMAP.keys())
# 用列表来存一个 RGB 和一个类别的对应
def colormap2label(palette):
colormap2label_list = np.zeros(256**3, dtype=np.longlong)
for i, colormap in enumerate(palette):
colormap2label_list[(colormap[0] * 256 + colormap[1]) * 256 +
colormap[2]] = i
return colormap2label_list
# 给定那个列表和vis_png然后生成masks_png
def label_indices(RGB_label, colormap2label_list):
RGB_label = RGB_label.astype('int32')
idx = (RGB_label[:, :, 0] * 256 +
RGB_label[:, :, 1]) * 256 + RGB_label[:, :, 2]
return colormap2label_list[idx]
def RGB2mask(RGB_label, colormap2label_list):
mask_label = label_indices(RGB_label, colormap2label_list)
return mask_label
colormap2label_list = colormap2label(palette)
def clip_big_image(image_path, clip_save_dir, args, to_label=False):
"""Original image of GID dataset is very large, thus pre-processing of them
is adopted.
Given fixed clip size and stride size to generate
clipped image, the intersection of width and height is determined.
For example, given one 6800 x 7200 original image, the clip size is
256 and stride size is 256, thus it would generate 29 x 27 = 783 images
whose size are all 256 x 256.
"""
image = mmcv.imread(image_path, channel_order='rgb')
# image = mmcv.bgr2gray(image)
h, w, c = image.shape
clip_size = args.clip_size
stride_size = args.stride_size
num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
(h - clip_size) /
stride_size) * stride_size + clip_size >= h else math.ceil(
(h - clip_size) / stride_size) + 1
num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
(w - clip_size) /
stride_size) * stride_size + clip_size >= w else math.ceil(
(w - clip_size) / stride_size) + 1
x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
xmin = x * clip_size
ymin = y * clip_size
xmin = xmin.ravel()
ymin = ymin.ravel()
xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
np.zeros_like(xmin))
ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
np.zeros_like(ymin))
boxes = np.stack([
xmin + xmin_offset, ymin + ymin_offset,
np.minimum(xmin + clip_size, w),
np.minimum(ymin + clip_size, h)
],
axis=1)
if to_label:
image = RGB2mask(image, colormap2label_list)
for count, box in enumerate(boxes):
start_x, start_y, end_x, end_y = box
clipped_image = image[start_y:end_y,
start_x:end_x] if to_label else image[
start_y:end_y, start_x:end_x, :]
img_name = osp.basename(image_path).replace('.tif', '')
img_name = img_name.replace('_label', '')
if count % 3 == 0:
mmcv.imwrite(
clipped_image.astype(np.uint8),
osp.join(
clip_save_dir.replace('train', 'val'),
f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
else:
mmcv.imwrite(
clipped_image.astype(np.uint8),
osp.join(
clip_save_dir,
f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
count += 1
def main():
args = parse_args()
"""
According to this paper: https://ieeexplore.ieee.org/document/9343296/
select 15 images contained in GID, , which cover the whole six
categories, to generate train set and validation set.
"""
if args.out_dir is None:
out_dir = osp.join('data', 'gid')
else:
out_dir = args.out_dir
print('Making directories...')
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
src_path_list = glob.glob(os.path.join(args.dataset_img_path, '*.tif'))
print(f'Find {len(src_path_list)} pictures')
prog_bar = ProgressBar(len(src_path_list))
dst_img_dir = osp.join(out_dir, 'img_dir', 'train')
dst_label_dir = osp.join(out_dir, 'ann_dir', 'train')
for i, img_path in enumerate(src_path_list):
label_path = osp.join(
args.dataset_label_path,
osp.basename(img_path.replace('.tif', '_label.tif')))
clip_big_image(img_path, dst_img_dir, args, to_label=False)
clip_big_image(label_path, dst_label_dir, args, to_label=True)
prog_bar.update()
print('Done!')
if __name__ == '__main__':
main()

View File

@ -0,0 +1,75 @@
import argparse
import os
import shutil
# select 15 images from GID dataset
img_list = [
'GF2_PMS1__L1A0000647767-MSS1.tif', 'GF2_PMS1__L1A0001064454-MSS1.tif',
'GF2_PMS1__L1A0001348919-MSS1.tif', 'GF2_PMS1__L1A0001680851-MSS1.tif',
'GF2_PMS1__L1A0001680853-MSS1.tif', 'GF2_PMS1__L1A0001680857-MSS1.tif',
'GF2_PMS1__L1A0001757429-MSS1.tif', 'GF2_PMS2__L1A0000607681-MSS2.tif',
'GF2_PMS2__L1A0000635115-MSS2.tif', 'GF2_PMS2__L1A0000658637-MSS2.tif',
'GF2_PMS2__L1A0001206072-MSS2.tif', 'GF2_PMS2__L1A0001471436-MSS2.tif',
'GF2_PMS2__L1A0001642620-MSS2.tif', 'GF2_PMS2__L1A0001787089-MSS2.tif',
'GF2_PMS2__L1A0001838560-MSS2.tif'
]
labels_list = [
'GF2_PMS1__L1A0000647767-MSS1_label.tif',
'GF2_PMS1__L1A0001064454-MSS1_label.tif',
'GF2_PMS1__L1A0001348919-MSS1_label.tif',
'GF2_PMS1__L1A0001680851-MSS1_label.tif',
'GF2_PMS1__L1A0001680853-MSS1_label.tif',
'GF2_PMS1__L1A0001680857-MSS1_label.tif',
'GF2_PMS1__L1A0001757429-MSS1_label.tif',
'GF2_PMS2__L1A0000607681-MSS2_label.tif',
'GF2_PMS2__L1A0000635115-MSS2_label.tif',
'GF2_PMS2__L1A0000658637-MSS2_label.tif',
'GF2_PMS2__L1A0001206072-MSS2_label.tif',
'GF2_PMS2__L1A0001471436-MSS2_label.tif',
'GF2_PMS2__L1A0001642620-MSS2_label.tif',
'GF2_PMS2__L1A0001787089-MSS2_label.tif',
'GF2_PMS2__L1A0001838560-MSS2_label.tif'
]
def parse_args():
parser = argparse.ArgumentParser(
description='From 150 images of GID dataset to select 15 images')
parser.add_argument('dataset_img_dir', help='150 GID images folder path')
parser.add_argument('dataset_label_dir', help='150 GID labels folder path')
parser.add_argument('dest_img_dir', help='15 GID images folder path')
parser.add_argument('dest_label_dir', help='15 GID labels folder path')
args = parser.parse_args()
return args
def main():
"""This script is used to select 15 images from GID dataset, According to
paper: https://ieeexplore.ieee.org/document/9343296/"""
args = parse_args()
img_path = args.dataset_img_dir
label_path = args.dataset_label_dir
dest_img_dir = args.dest_img_dir
dest_label_dir = args.dest_label_dir
# copy images of 'img_list' to 'desr_dir'
print('Copy images of img_list to desr_dir ing...')
for img in img_list:
shutil.copy(os.path.join(img_path, img), dest_img_dir)
print('Done!')
print('copy labels of labels_list to desr_dir ing...')
for label in labels_list:
shutil.copy(os.path.join(label_path, label), dest_label_dir)
print('Done!')
if __name__ == '__main__':
main()

View File

@ -0,0 +1,53 @@
## Gaofen Image Dataset (GID)
- GID 数据集可在[此处](https://x-ytong.github.io/project/GID.html)进行下载。
- GID 数据集包含 150 张 6800x7200 的大尺寸图像,标签为 RGB 标签。
- 根据[文献](https://ieeexplore.ieee.org/document/9343296/),此处选择 15 张图像生成训练集和验证集,该 15 张图像包含了所有六类信息。所选的图像名称如下:
```None
GF2_PMS1__L1A0000647767-MSS1
GF2_PMS1__L1A0001064454-MSS1
GF2_PMS1__L1A0001348919-MSS1
GF2_PMS1__L1A0001680851-MSS1
GF2_PMS1__L1A0001680853-MSS1
GF2_PMS1__L1A0001680857-MSS1
GF2_PMS1__L1A0001757429-MSS1
GF2_PMS2__L1A0000607681-MSS2
GF2_PMS2__L1A0000635115-MSS2
GF2_PMS2__L1A0000658637-MSS2
GF2_PMS2__L1A0001206072-MSS2
GF2_PMS2__L1A0001471436-MSS2
GF2_PMS2__L1A0001642620-MSS2
GF2_PMS2__L1A0001787089-MSS2
GF2_PMS2__L1A0001838560-MSS2
```
这里也提供了一个脚本来方便的筛选出15张图像
```
python projects/gid_dataset/tools/dataset_converters/gid_select15imgFromAll.py {150 张图像的路径} {150 张标签的路径} {15 张图像的路径} {15 张标签的路径}
```
在选择出 15 张图像后,执行以下命令进行裁切及标签的转换,需要修改为您所存储 15 张图像及标签的路径。
```
python projects/gid_dataset/tools/dataset_converters/gid.py {15 张图像的路径} {15 张标签的路径}
```
完成裁切后的 GID 数据结构如下:
```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── gid
│ │ ├── ann_dir
| │ │ │ ├── train
| │ │ │ ├── val
│ │ ├── img_dir
| │ │ │ ├── train
| │ │ │ ├── val
```