mirror of https://github.com/PyRetri/PyRetri.git
38 lines
1.4 KiB
Python
38 lines
1.4 KiB
Python
# -*- coding: utf-8 -*-
|
|
import os
|
|
from shutil import copyfile
|
|
|
|
|
|
def split_dataset(dataset_path: str, split_file: str) -> None:
    """
    Split the dataset according to the given splitting rules.

    Each non-blank line of ``split_file`` holds a '/'-separated path relative
    to ``dataset_path`` followed by a flag, separated by whitespace
    (extra columns are ignored). A flag of ``'0'`` places the image in the
    ``query`` subset and ``'1'`` in the ``gallery`` subset; any other flag
    is ignored. The image is exposed by a symlink whose path is the original
    relative path with its first component replaced by the subset name, e.g.
    ``jpg/cls/a.jpg 0`` -> ``<dataset_path>/query/cls/a.jpg``.

    Existing entries at a destination path are left untouched, so the
    function can be re-run safely.

    Args:
        dataset_path (str): the path of the dataset.
        split_file (str): the path of the file containing the splitting rules.

    Raises:
        ValueError: if a non-blank line has fewer than two fields.
    """
    subsets = {'0': 'query', '1': 'gallery'}

    with open(split_file, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            # split() with no argument also drops '\r' and repeated spaces.
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines (e.g. a trailing newline)
            if len(fields) < 2:
                raise ValueError(
                    f"{split_file}:{line_no}: expected '<path> <flag>', got {line!r}"
                )
            path, flag = fields[0], fields[1]
            subset = subsets.get(flag)
            if subset is None:
                continue  # unknown flags are skipped, as before
            _link_into_subset(dataset_path, path, subset)


def _link_into_subset(dataset_path: str, path: str, subset: str) -> None:
    """
    Symlink ``dataset_path/path`` into ``dataset_path/<subset>/...``.

    Only the first component of ``path`` is swapped for ``subset``. Other
    occurrences of that component, in ``dataset_path`` or in the file name,
    are left alone.

    Args:
        dataset_path (str): the path of the dataset.
        path (str): '/'-separated image path relative to ``dataset_path``.
        subset (str): name of the subset directory, e.g. 'query' or 'gallery'.
    """
    src = os.path.join(dataset_path, path)
    parts = path.split('/')
    dst = os.path.join(dataset_path, subset, *parts[1:])

    dst_dir = os.path.dirname(dst)
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)

    # lexists() also sees dangling links, so re-runs do not hit FileExistsError.
    if not os.path.lexists(dst):
        # A relative symlink target resolves against the link's own directory,
        # not the CWD, so always point at the absolute source path.
        os.symlink(os.path.abspath(src), dst)
|