mmsegmentation/projects/medical/2d_image/tools/split_seg_dataset.py

43 lines
1.6 KiB
Python

import argparse
import glob
import os
from sklearn.model_selection import train_test_split
def save_anno(img_list, file_path, remove_suffix=True):
    """Write one image entry per line to an annotation list file.

    Args:
        img_list (Iterable[str]): Image paths to record.
        file_path (str): Destination text file. It is opened in write mode,
            so existing content is replaced.
        remove_suffix (bool): When True (default), every path is reduced to
            its last two ``/``-separated components (``<dir>/<name>``) and
            the file extension is removed from ``<name>``. When False, the
            paths are written unchanged.
    """
    if remove_suffix:
        entries = []
        for img_path in img_list:
            # Paths produced by glob use the native separator; normalise it
            # so the '/' split also works where os.sep is not '/'.
            parts = img_path.replace(os.sep, '/').split('/')[-2:]
            # Strip the extension from the file name only. Splitting the
            # whole path on '.' would turn an extension-less name into an
            # empty entry and would cut the path at a dot in the directory.
            parts[-1] = os.path.splitext(parts[-1])[0]
            entries.append('/'.join(parts))
    else:
        entries = img_list
    # Explicit encoding keeps the output independent of the platform locale.
    with open(file_path, 'w', encoding='utf-8') as file_:
        file_.writelines(entry + '\n' for entry in entries)
if __name__ == '__main__':
    # Generate the split list files (train.txt / val.txt / test.txt) inside
    # the dataset root, based on which mask sub-directories are present.
    cli = argparse.ArgumentParser()
    cli.add_argument('--data_root', default='data/')
    data_root = cli.parse_args().data_root

    has_val = os.path.exists(os.path.join(data_root, 'masks/val'))
    has_test = os.path.exists(os.path.join(data_root, 'masks/test'))

    # A split gets its own list only when its mask directory exists; the
    # entries themselves come from the matching images directory.
    for present, pattern, list_name in (
            (has_val, '/images/val/*.png', '/val.txt'),
            (has_test, '/images/test/*.png', '/test.txt')):
        if present:
            save_anno(
                sorted(glob.glob(data_root + pattern)),
                data_root + list_name)

    train_imgs = sorted(glob.glob(data_root + '/images/train/*.png'))
    if has_val or has_test:
        # A held-out split already exists: every train image goes to train.
        save_anno(train_imgs, data_root + '/train.txt')
    else:
        # Neither val nor test masks exist: hold out 20% of the train images
        # as validation, with a fixed seed so the split is repeatable.
        train_part, val_part = train_test_split(
            train_imgs, test_size=0.2, random_state=0)
        save_anno(train_part, data_root + '/train.txt')
        save_anno(val_part, data_root + '/val.txt')