mirror of https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Dataset] Support GID dataset on project (#3038)
## Motivation

Support GID dataset on project

---------

Co-authored-by: 谢昕辰 <xiexinch@outlook.com>
This commit is contained in:
parent b6ec4ab1e6
commit e5b8d72e01
projects/gid_dataset/configs/_base_/datasets/gid.py (new file, 67 lines)
@@ -0,0 +1,67 @@
# dataset settings
dataset_type = 'GID_Dataset'  # name of the registered dataset class
data_root = 'data/gid/'  # root directory of the dataset
crop_size = (256, 256)  # crop size of the images
train_pipeline = [
    dict(type='LoadImageFromFile'),  # load image from file
    dict(type='LoadAnnotations'),  # load annotation from file
    dict(
        type='RandomResize',  # random resize
        scale=(512, 512),  # resize scale
        ratio_range=(0.5, 2.0),  # resize ratio range
        keep_ratio=True),  # whether to keep the aspect ratio
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),  # random crop
    dict(type='RandomFlip', prob=0.5),  # random flip
    dict(type='PhotoMetricDistortion'),  # photometric distortion
    dict(type='PackSegInputs')  # pack the data
]
test_pipeline = [
    dict(type='LoadImageFromFile'),  # load image from file
    dict(type='Resize', scale=(256, 256), keep_ratio=True),  # resize
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations'),  # load annotation from file
    dict(type='PackSegInputs')  # pack the data
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]  # scale ratios for multi-scale testing
tta_pipeline = [  # test-time augmentation
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(  # training dataloader
    batch_size=2,  # batch size during training
    num_workers=4,  # number of data loading workers
    persistent_workers=True,  # whether to keep the workers alive between iterations
    sampler=dict(type='InfiniteSampler', shuffle=True),  # infinite sampler
    dataset=dict(
        type=dataset_type,  # dataset class name
        data_root=data_root,  # root directory of the dataset
        data_prefix=dict(
            img_path='img_dir/train',
            seg_map_path='ann_dir/train'),  # image and annotation paths of the training set
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,  # batch size during validation
    num_workers=4,  # number of data loading workers
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
@@ -0,0 +1,15 @@
_base_ = [
    '../../../configs/_base_/models/deeplabv3plus_r50-d8.py',
    './_base_/datasets/gid.py', '../../../configs/_base_/default_runtime.py',
    '../../../configs/_base_/schedules/schedule_240k.py'
]
custom_imports = dict(imports=['projects.gid_dataset.mmseg.datasets.gid'])

crop_size = (256, 256)
data_preprocessor = dict(size=crop_size)
model = dict(
    data_preprocessor=data_preprocessor,
    pretrained='open-mmlab://resnet101_v1c',
    backbone=dict(depth=101),
    decode_head=dict(num_classes=6),
    auxiliary_head=dict(num_classes=6))
projects/gid_dataset/mmseg/datasets/gid.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.datasets.basesegdataset import BaseSegDataset
from mmseg.registry import DATASETS


# register the dataset class
@DATASETS.register_module()
class GID_Dataset(BaseSegDataset):
    """Gaofen Image Dataset (GID).

    Dataset paper link:
    https://www.sciencedirect.com/science/article/pii/S0034425719303414
    https://x-ytong.github.io/project/GID.html

    GID 6 classes: others, built-up, farmland, forest, meadow, water

    In this example, 15 images are selected from the GID dataset as the
    training set and 5 images as the validation set.
    The selected images are listed as follows:

    GF2_PMS1__L1A0000647767-MSS1
    GF2_PMS1__L1A0001064454-MSS1
    GF2_PMS1__L1A0001348919-MSS1
    GF2_PMS1__L1A0001680851-MSS1
    GF2_PMS1__L1A0001680853-MSS1
    GF2_PMS1__L1A0001680857-MSS1
    GF2_PMS1__L1A0001757429-MSS1
    GF2_PMS2__L1A0000607681-MSS2
    GF2_PMS2__L1A0000635115-MSS2
    GF2_PMS2__L1A0000658637-MSS2
    GF2_PMS2__L1A0001206072-MSS2
    GF2_PMS2__L1A0001471436-MSS2
    GF2_PMS2__L1A0001642620-MSS2
    GF2_PMS2__L1A0001787089-MSS2
    GF2_PMS2__L1A0001838560-MSS2

    The ``img_suffix`` and ``seg_map_suffix`` are both fixed to '.png',
    since the conversion script saves the clipped images and labels as PNG.
    """
    METAINFO = dict(
        classes=('Others', 'Built-up', 'Farmland', 'Forest', 'Meadow',
                 'Water'),
        palette=[[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 255, 255],
                 [255, 255, 0], [0, 0, 255]])

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=None,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
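A minimal sketch (not part of the PR) of building the registered class through the mmseg registry, assuming the converted data already sits under `data/gid/` with the prefixes used in the dataset config above:

```python
from mmengine.registry import init_default_scope

from mmseg.registry import DATASETS

import projects.gid_dataset.mmseg.datasets.gid  # noqa: F401, registers GID_Dataset

init_default_scope('mmseg')
dataset = DATASETS.build(
    dict(
        type='GID_Dataset',
        data_root='data/gid/',
        data_prefix=dict(img_path='img_dir/train', seg_map_path='ann_dir/train')))
print(len(dataset), dataset.metainfo['classes'])
```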
projects/gid_dataset/tools/dataset_converters/gid.py (new file, 181 lines)
@@ -0,0 +1,181 @@
import argparse
import glob
import math
import os
import os.path as osp

import mmcv
import numpy as np
from mmengine.utils import ProgressBar, mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert GID dataset to mmsegmentation format')
    parser.add_argument('dataset_img_path', help='GID images folder path')
    parser.add_argument('dataset_label_path', help='GID labels folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument(
        '-o', '--out_dir', help='output path', default='data/gid')
    parser.add_argument(
        '--clip_size',
        type=int,
        help='clipped size of image after preparation',
        default=256)
    parser.add_argument(
        '--stride_size',
        type=int,
        help='stride of clipping original images',
        default=256)
    args = parser.parse_args()
    return args
GID_COLORMAP = dict(
    Background=(0, 0, 0),  # 0 - background - black
    Building=(255, 0, 0),  # 1 - building - red
    Farmland=(0, 255, 0),  # 2 - farmland - green
    Forest=(0, 255, 255),  # 3 - forest - cyan (matches the METAINFO palette)
    Meadow=(255, 255, 0),  # 4 - meadow - yellow
    Water=(0, 0, 255)  # 5 - water - blue
)
palette = list(GID_COLORMAP.values())
classes = list(GID_COLORMAP.keys())
# lookup table mapping each RGB colour to its class index
def colormap2label(palette):
    colormap2label_list = np.zeros(256**3, dtype=np.longlong)
    for i, colormap in enumerate(palette):
        colormap2label_list[(colormap[0] * 256 + colormap[1]) * 256 +
                            colormap[2]] = i
    return colormap2label_list


# given the lookup table and an RGB label image, generate the mask image
def label_indices(RGB_label, colormap2label_list):
    RGB_label = RGB_label.astype('int32')
    idx = (RGB_label[:, :, 0] * 256 +
           RGB_label[:, :, 1]) * 256 + RGB_label[:, :, 2]
    return colormap2label_list[idx]


def RGB2mask(RGB_label, colormap2label_list):
    mask_label = label_indices(RGB_label, colormap2label_list)
    return mask_label


colormap2label_list = colormap2label(palette)
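As an illustrative sanity check of the lookup table (not part of the converter), a pure red pixel should map to class index 1 (Building) and a yellow pixel to index 4 (Meadow):

```python
import numpy as np

# 1 x 2 RGB "image": one red pixel, one yellow pixel
demo = np.array([[[255, 0, 0], [255, 255, 0]]], dtype=np.uint8)
print(RGB2mask(demo, colormap2label_list))  # expected: [[1 4]]
```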
def clip_big_image(image_path, clip_save_dir, args, to_label=False):
    """Clip the very large original GID images into small patches.

    Given a fixed clip size and stride size, the number of clipped images
    along the width and height is determined. For example, for one
    6800 x 7200 original image with a clip size of 256 and a stride size
    of 256, it generates 29 x 27 = 783 images whose size is 256 x 256.
    """

    image = mmcv.imread(image_path, channel_order='rgb')
    # image = mmcv.bgr2gray(image)

    h, w, c = image.shape
    clip_size = args.clip_size
    stride_size = args.stride_size

    num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
        (h - clip_size) /
        stride_size) * stride_size + clip_size >= h else math.ceil(
            (h - clip_size) / stride_size) + 1
    num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
        (w - clip_size) /
        stride_size) * stride_size + clip_size >= w else math.ceil(
            (w - clip_size) / stride_size) + 1

    x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
    xmin = x * clip_size
    ymin = y * clip_size

    xmin = xmin.ravel()
    ymin = ymin.ravel()
    xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
                           np.zeros_like(xmin))
    ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
                           np.zeros_like(ymin))
    boxes = np.stack([
        xmin + xmin_offset, ymin + ymin_offset,
        np.minimum(xmin + clip_size, w),
        np.minimum(ymin + clip_size, h)
    ], axis=1)

    if to_label:
        image = RGB2mask(image, colormap2label_list)

    for count, box in enumerate(boxes):
        start_x, start_y, end_x, end_y = box
        clipped_image = image[start_y:end_y,
                              start_x:end_x] if to_label else image[
                                  start_y:end_y, start_x:end_x, :]
        img_name = osp.basename(image_path).replace('.tif', '')
        img_name = img_name.replace('_label', '')
        # every third patch goes to the validation split, the rest to train
        if count % 3 == 0:
            mmcv.imwrite(
                clipped_image.astype(np.uint8),
                osp.join(
                    clip_save_dir.replace('train', 'val'),
                    f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
        else:
            mmcv.imwrite(
                clipped_image.astype(np.uint8),
                osp.join(
                    clip_save_dir,
                    f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
def main():
    """Generate the GID train and validation sets.

    According to this paper: https://ieeexplore.ieee.org/document/9343296/,
    15 images contained in GID, which cover all six categories, are selected
    to generate the train set and validation set.
    """
    args = parse_args()

    if args.out_dir is None:
        out_dir = osp.join('data', 'gid')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))

    src_path_list = glob.glob(os.path.join(args.dataset_img_path, '*.tif'))
    print(f'Found {len(src_path_list)} pictures')

    prog_bar = ProgressBar(len(src_path_list))

    dst_img_dir = osp.join(out_dir, 'img_dir', 'train')
    dst_label_dir = osp.join(out_dir, 'ann_dir', 'train')

    for img_path in src_path_list:
        label_path = osp.join(
            args.dataset_label_path,
            osp.basename(img_path.replace('.tif', '_label.tif')))

        clip_big_image(img_path, dst_img_dir, args, to_label=False)
        clip_big_image(label_path, dst_label_dir, args, to_label=True)
        prog_bar.update()

    print('Done!')


if __name__ == '__main__':
    main()
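As a side note, a short sketch (assuming the 6800 x 7200 image size quoted in the ``clip_big_image`` docstring, not part of the converter) reproduces the stated patch count:

```python
import math

# Assumed image size: h x w = 6800 x 7200, clip and stride both 256.
h, w, clip_size, stride_size = 6800, 7200, 256, 256

num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
    (h - clip_size) / stride_size) * stride_size + clip_size >= h else math.ceil(
        (h - clip_size) / stride_size) + 1
num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
    (w - clip_size) / stride_size) * stride_size + clip_size >= w else math.ceil(
        (w - clip_size) / stride_size) + 1

# The meshgrid in ``clip_big_image`` uses arange(n + 1), so the grid is
# (num_rows + 1) x (num_cols + 1) boxes.
print(num_rows + 1, num_cols + 1, (num_rows + 1) * (num_cols + 1))  # 27 29 783
```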
projects/gid_dataset/tools/dataset_converters/gid_select15imgFromAll.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import argparse
import os
import shutil

# select 15 images from the GID dataset

img_list = [
    'GF2_PMS1__L1A0000647767-MSS1.tif', 'GF2_PMS1__L1A0001064454-MSS1.tif',
    'GF2_PMS1__L1A0001348919-MSS1.tif', 'GF2_PMS1__L1A0001680851-MSS1.tif',
    'GF2_PMS1__L1A0001680853-MSS1.tif', 'GF2_PMS1__L1A0001680857-MSS1.tif',
    'GF2_PMS1__L1A0001757429-MSS1.tif', 'GF2_PMS2__L1A0000607681-MSS2.tif',
    'GF2_PMS2__L1A0000635115-MSS2.tif', 'GF2_PMS2__L1A0000658637-MSS2.tif',
    'GF2_PMS2__L1A0001206072-MSS2.tif', 'GF2_PMS2__L1A0001471436-MSS2.tif',
    'GF2_PMS2__L1A0001642620-MSS2.tif', 'GF2_PMS2__L1A0001787089-MSS2.tif',
    'GF2_PMS2__L1A0001838560-MSS2.tif'
]

labels_list = [
    'GF2_PMS1__L1A0000647767-MSS1_label.tif',
    'GF2_PMS1__L1A0001064454-MSS1_label.tif',
    'GF2_PMS1__L1A0001348919-MSS1_label.tif',
    'GF2_PMS1__L1A0001680851-MSS1_label.tif',
    'GF2_PMS1__L1A0001680853-MSS1_label.tif',
    'GF2_PMS1__L1A0001680857-MSS1_label.tif',
    'GF2_PMS1__L1A0001757429-MSS1_label.tif',
    'GF2_PMS2__L1A0000607681-MSS2_label.tif',
    'GF2_PMS2__L1A0000635115-MSS2_label.tif',
    'GF2_PMS2__L1A0000658637-MSS2_label.tif',
    'GF2_PMS2__L1A0001206072-MSS2_label.tif',
    'GF2_PMS2__L1A0001471436-MSS2_label.tif',
    'GF2_PMS2__L1A0001642620-MSS2_label.tif',
    'GF2_PMS2__L1A0001787089-MSS2_label.tif',
    'GF2_PMS2__L1A0001838560-MSS2_label.tif'
]


def parse_args():
    parser = argparse.ArgumentParser(
        description='Select 15 images from the 150 images of the GID dataset')
    parser.add_argument('dataset_img_dir', help='150 GID images folder path')
    parser.add_argument('dataset_label_dir', help='150 GID labels folder path')

    parser.add_argument('dest_img_dir', help='15 GID images folder path')
    parser.add_argument('dest_label_dir', help='15 GID labels folder path')

    args = parser.parse_args()

    return args


def main():
    """Select 15 images from the GID dataset according to the paper:
    https://ieeexplore.ieee.org/document/9343296/"""
    args = parse_args()

    img_path = args.dataset_img_dir
    label_path = args.dataset_label_dir

    dest_img_dir = args.dest_img_dir
    dest_label_dir = args.dest_label_dir

    # copy images of 'img_list' to 'dest_img_dir'
    print("Copying images of 'img_list' to dest_img_dir ...")
    for img in img_list:
        shutil.copy(os.path.join(img_path, img), dest_img_dir)
    print('Done!')

    # copy labels of 'labels_list' to 'dest_label_dir'
    print("Copying labels of 'labels_list' to dest_label_dir ...")
    for label in labels_list:
        shutil.copy(os.path.join(label_path, label), dest_label_dir)
    print('Done!')


if __name__ == '__main__':
    main()
projects/gid_dataset/user_guides/2_dataset_prepare.md (new file, 53 lines)
@@ -0,0 +1,53 @@
## Gaofen Image Dataset (GID)

- The GID dataset can be downloaded [here](https://x-ytong.github.io/project/GID.html).
- The GID dataset contains 150 large-size 6800x7200 images with RGB labels.
- Following [this paper](https://ieeexplore.ieee.org/document/9343296/), 15 images that cover all six classes are selected here to generate the training and validation sets. The names of the selected images are as follows:

```None
GF2_PMS1__L1A0000647767-MSS1
GF2_PMS1__L1A0001064454-MSS1
GF2_PMS1__L1A0001348919-MSS1
GF2_PMS1__L1A0001680851-MSS1
GF2_PMS1__L1A0001680853-MSS1
GF2_PMS1__L1A0001680857-MSS1
GF2_PMS1__L1A0001757429-MSS1
GF2_PMS2__L1A0000607681-MSS2
GF2_PMS2__L1A0000635115-MSS2
GF2_PMS2__L1A0000658637-MSS2
GF2_PMS2__L1A0001206072-MSS2
GF2_PMS2__L1A0001471436-MSS2
GF2_PMS2__L1A0001642620-MSS2
GF2_PMS2__L1A0001787089-MSS2
GF2_PMS2__L1A0001838560-MSS2
```

A script is also provided to conveniently select these 15 images:

```
python projects/gid_dataset/tools/dataset_converters/gid_select15imgFromAll.py {path of the 150 images} {path of the 150 labels} {path of the 15 images} {path of the 15 labels}
```

After selecting the 15 images, run the following command to clip the images and convert the labels. Replace the paths with where your 15 images and labels are stored.

```
python projects/gid_dataset/tools/dataset_converters/gid.py {path of the 15 images} {path of the 15 labels}
```

The GID data structure after clipping is as follows:

```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│   ├── gid
│   │   ├── ann_dir
│   │   │   ├── train
│   │   │   ├── val
│   │   ├── img_dir
│   │   │   ├── train
│   │   │   ├── val

```
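After conversion, an optional quick check (an illustrative sketch assuming the default `data/gid/` output path) is to count the clipped patches per split and confirm that images and annotations match:

```python
import glob

# Counts should match between img_dir and ann_dir for each split.
for split in ('train', 'val'):
    imgs = glob.glob(f'data/gid/img_dir/{split}/*.png')
    anns = glob.glob(f'data/gid/ann_dir/{split}/*.png')
    print(split, len(imgs), len(anns))
```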