43 lines
1.6 KiB
Python
43 lines
1.6 KiB
Python
import argparse
|
|
import glob
|
|
import os
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
def save_anno(img_list, file_path, remove_suffix=True):
    """Write one annotation entry per line to ``file_path``.

    Args:
        img_list: Iterable of image path strings.
        file_path: Destination text file; it is created or overwritten.
        remove_suffix: When true (the default), each path is reduced to its
            last two ``/``-separated components and the extension of the
            file name is dropped, e.g. ``data/images/train/0001.png`` is
            written as ``train/0001``. When false, the paths are written
            unchanged.
    """
    if remove_suffix:
        entries = []
        for img_path in img_list:
            # Paths built by the OS may use the native separator; normalise
            # it before splitting on '/' (a no-op where os.sep is '/').
            parts = img_path.replace(os.sep, '/').split('/')[-2:]
            # splitext strips only a real extension of the file name: a
            # name without a dot is kept intact (joining split('.')[:-1]
            # would reduce it to an empty string) and dots in the parent
            # directory name are left alone.
            parts[-1] = os.path.splitext(parts[-1])[0]
            entries.append('/'.join(parts))
        img_list = entries
    with open(file_path, 'w') as file_:
        file_.writelines(entry + '\n' for entry in img_list)
|
|
|
|
|
|
if __name__ == '__main__':
    # Generate train/val/test annotation lists for the dataset located at
    # --data_root (expects images/<split>/*.png below it).
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--data_root', default='data/')
    data_root = arg_parser.parse_args().data_root

    # A split is treated as provided when its mask directory exists.
    has_val = os.path.exists(os.path.join(data_root, 'masks/val'))
    has_test = os.path.exists(os.path.join(data_root, 'masks/test'))

    if has_val:
        val_imgs = sorted(glob.glob(data_root + '/images/val/*.png'))
        save_anno(val_imgs, data_root + '/val.txt')
    if has_test:
        test_imgs = sorted(glob.glob(data_root + '/images/test/*.png'))
        save_anno(test_imgs, data_root + '/test.txt')

    train_imgs = sorted(glob.glob(data_root + '/images/train/*.png'))
    if has_val or has_test:
        # At least one held-out split is provided: list every training image.
        save_anno(train_imgs, data_root + '/train.txt')
    else:
        # Neither split is provided: hold out 20% of the training images as
        # a validation set, using a fixed seed so the split is reproducible.
        train_part, val_part = train_test_split(
            train_imgs, test_size=0.2, random_state=0)
        save_anno(train_part, data_root + '/train.txt')
        save_anno(val_part, data_root + '/val.txt')
|