mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
## Motivation Support GID dataset on project --------- Co-authored-by: 谢昕辰 <xiexinch@outlook.com>
182 lines
5.9 KiB
Python
182 lines
5.9 KiB
Python
import argparse
|
||
import glob
|
||
import math
|
||
import os
|
||
import os.path as osp
|
||
|
||
import mmcv
|
||
import numpy as np
|
||
from mmengine.utils import ProgressBar, mkdir_or_exist
|
||
|
||
|
||
def parse_args():
|
||
parser = argparse.ArgumentParser(
|
||
description='Convert GID dataset to mmsegmentation format')
|
||
parser.add_argument('dataset_img_path', help='GID images folder path')
|
||
parser.add_argument('dataset_label_path', help='GID labels folder path')
|
||
parser.add_argument('--tmp_dir', help='path of the temporary directory')
|
||
parser.add_argument(
|
||
'-o', '--out_dir', help='output path', default='data/gid')
|
||
parser.add_argument(
|
||
'--clip_size',
|
||
type=int,
|
||
help='clipped size of image after preparation',
|
||
default=256)
|
||
parser.add_argument(
|
||
'--stride_size',
|
||
type=int,
|
||
help='stride of clipping original images',
|
||
default=256)
|
||
args = parser.parse_args()
|
||
return args
|
||
|
||
|
||
GID_COLORMAP = dict(
|
||
Background=(0, 0, 0), # 0-背景-黑色
|
||
Building=(255, 0, 0), # 1-建筑-红色
|
||
Farmland=(0, 255, 0), # 2-农田-绿色
|
||
Forest=(0, 0, 255), # 3-森林-蓝色
|
||
Meadow=(255, 255, 0), # 4-草地-黄色
|
||
Water=(0, 0, 255) # 5-水-蓝色
|
||
)
|
||
palette = list(GID_COLORMAP.values())
|
||
classes = list(GID_COLORMAP.keys())
|
||
|
||
|
||
# 用列表来存一个 RGB 和一个类别的对应
|
||
def colormap2label(palette):
|
||
colormap2label_list = np.zeros(256**3, dtype=np.longlong)
|
||
for i, colormap in enumerate(palette):
|
||
colormap2label_list[(colormap[0] * 256 + colormap[1]) * 256 +
|
||
colormap[2]] = i
|
||
return colormap2label_list
|
||
|
||
|
||
# 给定那个列表,和vis_png然后生成masks_png
|
||
def label_indices(RGB_label, colormap2label_list):
|
||
RGB_label = RGB_label.astype('int32')
|
||
idx = (RGB_label[:, :, 0] * 256 +
|
||
RGB_label[:, :, 1]) * 256 + RGB_label[:, :, 2]
|
||
return colormap2label_list[idx]
|
||
|
||
|
||
def RGB2mask(RGB_label, colormap2label_list):
|
||
mask_label = label_indices(RGB_label, colormap2label_list)
|
||
return mask_label
|
||
|
||
|
||
colormap2label_list = colormap2label(palette)
|
||
|
||
|
||
def clip_big_image(image_path, clip_save_dir, args, to_label=False):
|
||
"""Original image of GID dataset is very large, thus pre-processing of them
|
||
is adopted.
|
||
|
||
Given fixed clip size and stride size to generate
|
||
clipped image, the intersection of width and height is determined.
|
||
For example, given one 6800 x 7200 original image, the clip size is
|
||
256 and stride size is 256, thus it would generate 29 x 27 = 783 images
|
||
whose size are all 256 x 256.
|
||
"""
|
||
|
||
image = mmcv.imread(image_path, channel_order='rgb')
|
||
# image = mmcv.bgr2gray(image)
|
||
|
||
h, w, c = image.shape
|
||
clip_size = args.clip_size
|
||
stride_size = args.stride_size
|
||
|
||
num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
|
||
(h - clip_size) /
|
||
stride_size) * stride_size + clip_size >= h else math.ceil(
|
||
(h - clip_size) / stride_size) + 1
|
||
num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
|
||
(w - clip_size) /
|
||
stride_size) * stride_size + clip_size >= w else math.ceil(
|
||
(w - clip_size) / stride_size) + 1
|
||
|
||
x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
|
||
xmin = x * clip_size
|
||
ymin = y * clip_size
|
||
|
||
xmin = xmin.ravel()
|
||
ymin = ymin.ravel()
|
||
xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
|
||
np.zeros_like(xmin))
|
||
ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
|
||
np.zeros_like(ymin))
|
||
boxes = np.stack([
|
||
xmin + xmin_offset, ymin + ymin_offset,
|
||
np.minimum(xmin + clip_size, w),
|
||
np.minimum(ymin + clip_size, h)
|
||
],
|
||
axis=1)
|
||
|
||
if to_label:
|
||
image = RGB2mask(image, colormap2label_list)
|
||
|
||
for count, box in enumerate(boxes):
|
||
start_x, start_y, end_x, end_y = box
|
||
clipped_image = image[start_y:end_y,
|
||
start_x:end_x] if to_label else image[
|
||
start_y:end_y, start_x:end_x, :]
|
||
img_name = osp.basename(image_path).replace('.tif', '')
|
||
img_name = img_name.replace('_label', '')
|
||
if count % 3 == 0:
|
||
mmcv.imwrite(
|
||
clipped_image.astype(np.uint8),
|
||
osp.join(
|
||
clip_save_dir.replace('train', 'val'),
|
||
f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
|
||
else:
|
||
mmcv.imwrite(
|
||
clipped_image.astype(np.uint8),
|
||
osp.join(
|
||
clip_save_dir,
|
||
f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
|
||
count += 1
|
||
|
||
|
||
def main():
|
||
args = parse_args()
|
||
"""
|
||
According to this paper: https://ieeexplore.ieee.org/document/9343296/
|
||
select 15 images contained in GID, , which cover the whole six
|
||
categories, to generate train set and validation set.
|
||
|
||
"""
|
||
|
||
if args.out_dir is None:
|
||
out_dir = osp.join('data', 'gid')
|
||
else:
|
||
out_dir = args.out_dir
|
||
|
||
print('Making directories...')
|
||
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
|
||
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
|
||
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
|
||
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
|
||
|
||
src_path_list = glob.glob(os.path.join(args.dataset_img_path, '*.tif'))
|
||
print(f'Find {len(src_path_list)} pictures')
|
||
|
||
prog_bar = ProgressBar(len(src_path_list))
|
||
|
||
dst_img_dir = osp.join(out_dir, 'img_dir', 'train')
|
||
dst_label_dir = osp.join(out_dir, 'ann_dir', 'train')
|
||
|
||
for i, img_path in enumerate(src_path_list):
|
||
label_path = osp.join(
|
||
args.dataset_label_path,
|
||
osp.basename(img_path.replace('.tif', '_label.tif')))
|
||
|
||
clip_big_image(img_path, dst_img_dir, args, to_label=False)
|
||
clip_big_image(label_path, dst_label_dir, args, to_label=True)
|
||
prog_bar.update()
|
||
|
||
print('Done!')
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|