Tianlong Ai e5b8d72e01
[Dataset] Support GID dataset on project (#3038)
## Motivation
Support GID dataset on project

---------

Co-authored-by: 谢昕辰 <xiexinch@outlook.com>
2023-06-05 11:25:50 +08:00


import argparse
import glob
import math
import os
import os.path as osp

import mmcv
import numpy as np
from mmengine.utils import ProgressBar, mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert GID dataset to mmsegmentation format')
    parser.add_argument('dataset_img_path', help='GID images folder path')
    parser.add_argument('dataset_label_path', help='GID labels folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument(
        '-o', '--out_dir', help='output path', default='data/gid')
    parser.add_argument(
        '--clip_size',
        type=int,
        help='clipped size of image after preparation',
        default=256)
    parser.add_argument(
        '--stride_size',
        type=int,
        help='stride of clipping original images',
        default=256)
    args = parser.parse_args()
    return args
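
# Usage sketch (a hypothetical invocation, assuming this script is saved as
# gid.py; the two dataset paths are placeholders for wherever the GID images
# and labels are stored):
#
#   python gid.py /path/to/GID/images /path/to/GID/labels \
#       -o data/gid --clip_size 256 --stride_size 256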


GID_COLORMAP = dict(
    Background=(0, 0, 0),  # 0-Background-Black
    Building=(255, 0, 0),  # 1-Building-Red
    Farmland=(0, 255, 0),  # 2-Farmland-Green
    Forest=(0, 255, 255),  # 3-Forest-Cyan
    Meadow=(255, 255, 0),  # 4-Meadow-Yellow
    Water=(0, 0, 255)  # 5-Water-Blue
)
palette = list(GID_COLORMAP.values())
classes = list(GID_COLORMAP.keys())


# Build a lookup table that stores the mapping from an RGB colour
# (encoded as a single integer) to its class index.
def colormap2label(palette):
    colormap2label_list = np.zeros(256**3, dtype=np.longlong)
    for i, colormap in enumerate(palette):
        colormap2label_list[(colormap[0] * 256 + colormap[1]) * 256 +
                            colormap[2]] = i
    return colormap2label_list


# Given the lookup table and an RGB label image (vis png), generate the
# class-index mask (mask png).
def label_indices(RGB_label, colormap2label_list):
    RGB_label = RGB_label.astype('int32')
    idx = (RGB_label[:, :, 0] * 256 +
           RGB_label[:, :, 1]) * 256 + RGB_label[:, :, 2]
    return colormap2label_list[idx]


def RGB2mask(RGB_label, colormap2label_list):
    mask_label = label_indices(RGB_label, colormap2label_list)
    return mask_label


colormap2label_list = colormap2label(palette)
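
# A minimal sketch of the lookup in action (the pixel values below are
# illustrative, not taken from a real GID tile): a 1 x 2 RGB patch holding one
# Building pixel and one Water pixel maps to class indices 1 and 5.
#
# >>> sample = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)
# >>> RGB2mask(sample, colormap2label_list)
# array([[1, 5]])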


def clip_big_image(image_path, clip_save_dir, args, to_label=False):
    """Clip a large GID image into small patches.

    Original images of the GID dataset are very large, so they are
    pre-processed by clipping. Given a fixed clip size and stride size,
    the number of rows and columns of clipped patches is determined.
    For example, for one 6800 x 7200 original image with a clip size of
    256 and a stride size of 256, 29 x 27 = 783 images of size 256 x 256
    are generated.
    """
    image = mmcv.imread(image_path, channel_order='rgb')

    h, w, c = image.shape
    clip_size = args.clip_size
    stride_size = args.stride_size

    num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
        (h - clip_size) /
        stride_size) * stride_size + clip_size >= h else math.ceil(
            (h - clip_size) / stride_size) + 1
    num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
        (w - clip_size) /
        stride_size) * stride_size + clip_size >= w else math.ceil(
            (w - clip_size) / stride_size) + 1

    # Top-left corners of the clipping grid; patches that would run past the
    # image border are shifted back so that every patch stays inside the image.
    x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
    xmin = x * clip_size
    ymin = y * clip_size

    xmin = xmin.ravel()
    ymin = ymin.ravel()
    xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
                           np.zeros_like(xmin))
    ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
                           np.zeros_like(ymin))
    boxes = np.stack([
        xmin + xmin_offset, ymin + ymin_offset,
        np.minimum(xmin + clip_size, w),
        np.minimum(ymin + clip_size, h)
    ], axis=1)

    if to_label:
        image = RGB2mask(image, colormap2label_list)

    for count, box in enumerate(boxes):
        start_x, start_y, end_x, end_y = box
        clipped_image = image[start_y:end_y,
                              start_x:end_x] if to_label else image[
                                  start_y:end_y, start_x:end_x, :]
        img_name = osp.basename(image_path).replace('.tif', '')
        img_name = img_name.replace('_label', '')
        # Every third patch (count % 3 == 0) is written to the validation
        # split; the rest go to the train split.
        if count % 3 == 0:
            mmcv.imwrite(
                clipped_image.astype(np.uint8),
                osp.join(
                    clip_save_dir.replace('train', 'val'),
                    f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
        else:
            mmcv.imwrite(
                clipped_image.astype(np.uint8),
                osp.join(
                    clip_save_dir,
                    f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
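
# A worked example of the grid arithmetic in clip_big_image, using the
# 6800 x 7200 image from the docstring with clip_size = stride_size = 256:
#   num_rows = ceil((6800 - 256) / 256) = 26 and 26 * 256 + 256 >= 6800,
#   num_cols = ceil((7200 - 256) / 256) = 28 and 28 * 256 + 256 >= 7200,
#   so the meshgrid over (num_cols + 1) x (num_rows + 1) yields 29 x 27 = 783
#   boxes; boxes overhanging the right/bottom border are shifted back so that
#   every written patch is exactly 256 x 256.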


def main():
    args = parse_args()
    # According to this paper: https://ieeexplore.ieee.org/document/9343296/,
    # 15 images contained in GID, which cover all six categories, are selected
    # to generate the train set and the validation set.
    if args.out_dir is None:
        out_dir = osp.join('data', 'gid')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))

    src_path_list = glob.glob(os.path.join(args.dataset_img_path, '*.tif'))
    print(f'Found {len(src_path_list)} images')
    prog_bar = ProgressBar(len(src_path_list))

    dst_img_dir = osp.join(out_dir, 'img_dir', 'train')
    dst_label_dir = osp.join(out_dir, 'ann_dir', 'train')

    for img_path in src_path_list:
        label_path = osp.join(
            args.dataset_label_path,
            osp.basename(img_path.replace('.tif', '_label.tif')))
        clip_big_image(img_path, dst_img_dir, args, to_label=False)
        clip_big_image(label_path, dst_label_dir, args, to_label=True)
        prog_bar.update()

    print('Done!')


if __name__ == '__main__':
    main()