import argparse
import glob
import os

try:
    from sklearn.model_selection import train_test_split
except ImportError:
    # scikit-learn is only needed when a train/val split has to be created;
    # datasets that already ship val/test masks can be processed without it.
    train_test_split = None


def _relative_stem(img_path):
    """Return the last two components of *img_path* without its extension.

    The result always uses ``/`` as separator, e.g.
    ``data/images/train/a.png`` -> ``train/a``.
    """
    # glob joins the file name with os.sep, so normalise before splitting;
    # this is a no-op on POSIX where os.sep is already '/'.
    normalized = img_path.replace(os.sep, '/')
    short = '/'.join(normalized.split('/')[-2:])
    # splitext only strips a real extension of the final component, so names
    # without a dot (or directories containing one) are left intact.
    return os.path.splitext(short)[0]


def save_anno(img_list, file_path, remove_suffix=True):
    """Write one entry per line of *img_list* to *file_path*.

    Args:
        img_list: Iterable of image path strings.
        file_path: Destination text file; it is overwritten.
        remove_suffix: When true, each path is reduced to its last two
            components (``<dir>/<name>``) and the file extension is dropped.
            When false, the paths are written unchanged.
    """
    if remove_suffix:
        img_list = [_relative_stem(img_path) for img_path in img_list]
    with open(file_path, 'w') as file_:
        file_.writelines(entry + '\n' for entry in img_list)


def _list_images(data_root, split):
    """Return the sorted ``*.png`` paths under ``<data_root>/images/<split>``."""
    return sorted(glob.glob(os.path.join(data_root, 'images', split, '*.png')))


def main(argv=None):
    """Generate ``train.txt`` / ``val.txt`` / ``test.txt`` under the data root.

    A ``val.txt`` / ``test.txt`` list is written for each of ``masks/val`` and
    ``masks/test`` that exists. If neither exists, the training images are
    split 80/20 (fixed ``random_state=0``) into ``train.txt`` and ``val.txt``;
    otherwise every training image goes into ``train.txt``.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``.

    Raises:
        ImportError: If a train/val split is required but scikit-learn is
            not installed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', default='data/')
    args = parser.parse_args(argv)
    data_root = args.data_root

    has_val = os.path.exists(os.path.join(data_root, 'masks/val'))
    has_test = os.path.exists(os.path.join(data_root, 'masks/test'))

    if has_val:
        save_anno(_list_images(data_root, 'val'),
                  os.path.join(data_root, 'val.txt'))
    if has_test:
        save_anno(_list_images(data_root, 'test'),
                  os.path.join(data_root, 'test.txt'))

    train_imgs = _list_images(data_root, 'train')
    if has_val or has_test:
        # A predefined split exists, so all training images stay in train.
        save_anno(train_imgs, os.path.join(data_root, 'train.txt'))
        return

    if train_test_split is None:
        raise ImportError(
            'scikit-learn is required to create a train/val split')
    x_train, x_val = train_test_split(
        train_imgs, test_size=0.2, random_state=0)
    save_anno(x_train, os.path.join(data_root, 'train.txt'))
    save_anno(x_val, os.path.join(data_root, 'val.txt'))


if __name__ == '__main__':
    main()