2021-09-29 19:27:46 +08:00
|
|
|
|
# coding:utf8
|
|
|
|
|
import os
|
|
|
|
|
import shutil
|
|
|
|
|
import random
|
|
|
|
|
import argparse
|
|
|
|
|
|
2021-10-08 15:31:48 +08:00
|
|
|
|
|
2021-10-21 18:39:14 +08:00
|
|
|
|
# Delete an existing train/val/test split folder and recreate it as an empty folder
|
2021-09-29 19:27:46 +08:00
|
|
|
|
def isCreateOrDeleteFolder(path, flag):
    """Reset the sub-folder ``flag`` of ``path`` to an empty directory.

    Anything already present at ``path/flag`` is removed with
    ``shutil.rmtree`` and the directory is then created again
    (intermediate directories included).

    Args:
        path: Parent directory of the folder to reset.
        flag: Name of the sub-folder (the caller passes "train", "val"
            or "test").

    Returns:
        The absolute path of the freshly created, empty folder.
    """
    target_dir = os.path.join(path, flag)

    # Drop leftovers from a previous run so the split always starts clean.
    already_there = os.path.exists(target_dir)
    if already_there:
        shutil.rmtree(target_dir)

    os.makedirs(target_dir)
    return os.path.abspath(target_dir)
|
|
|
|
|
|
|
|
|
|
|
2023-10-13 04:27:26 +02:00
|
|
|
|
def _parse_split_ratio(ratio_text):
    """Convert a "train:val:test" string into cumulative split boundaries.

    The three parts are normalised by their sum, so "6:2:2", "60:20:20"
    and "0.6:0.2:0.2" all describe the same split.

    Args:
        ratio_text: Colon-separated ratio with exactly three numeric parts.

    Returns:
        A ``(train_end, val_end)`` tuple of fractions in [0, 1]: records
        whose relative position is below ``train_end`` go to train, those
        below ``val_end`` go to val, the rest go to test.

    Raises:
        ValueError: If the text does not hold three non-negative numbers
            with a positive sum.
    """
    # float() instead of eval(): the text comes from the command line and
    # must never be executed as code.
    parts = [float(part) for part in ratio_text.split(":")]
    if len(parts) != 3:
        raise ValueError(
            "trainValTestRatio must look like 'train:val:test', got {!r}".format(ratio_text))

    total = sum(parts)
    if total <= 0 or any(part < 0 for part in parts):
        raise ValueError(
            "trainValTestRatio parts must be non-negative with a positive sum, got {!r}".format(ratio_text))

    train_end = parts[0] / total
    val_end = train_end + parts[1] / total
    return train_end, val_end


def splitTrainVal(root, abs_train_root_path, abs_val_root_path, abs_test_root_path, train_txt, val_txt, test_txt, flag):
    """Shuffle one label file and distribute its records over train/val/test.

    Each record is a line of the form ``<image path>\\t<label>``. The image
    is located by its base name, copied into the folder of the split it
    falls into, and a ``<copied path>\\t<label>`` line is written to the
    matching text stream.

    NOTE: this function reads the module-level ``args`` namespace
    (``detLabelFileName``, ``recLabelFileName``, ``recImageDirName`` and
    ``trainValTestRatio``); it is not passed in as a parameter.

    Args:
        root: Directory containing the label file (and, for "rec", the
            cropped-image sub-folder).
        abs_train_root_path: Destination folder for training images.
        abs_val_root_path: Destination folder for validation images.
        abs_test_root_path: Destination folder for test images.
        train_txt: Writable text stream receiving the training records.
        val_txt: Writable text stream receiving the validation records.
        test_txt: Writable text stream receiving the test records.
        flag: "det" for detection data, "rec" for recognition data.

    Raises:
        ValueError: If ``flag`` is unknown, the ratio is malformed, or a
            record has no tab separator.
    """
    if flag not in ("det", "rec"):
        raise ValueError("flag must be 'det' or 'rec', got {!r}".format(flag))

    data_abs_path = os.path.abspath(root)
    if flag == "det":
        label_file_name = args.detLabelFileName
        image_dir = data_abs_path
    else:
        label_file_name = args.recLabelFileName
        image_dir = os.path.join(data_abs_path, args.recImageDirName)
    label_file_path = os.path.join(data_abs_path, label_file_name)

    # The ratio does not depend on the record, so parse it once up front.
    train_ratio, val_ratio = _parse_split_ratio(args.trainValTestRatio)

    with open(label_file_path, "r", encoding="UTF-8") as label_file:
        # Strip the line terminator here and add it back when writing: a
        # last line without a trailing newline would otherwise be glued to
        # the following record once the list has been shuffled. Blank
        # lines carry no record and are skipped.
        records = [line.rstrip("\n") for line in label_file if line.strip()]

    random.shuffle(records)
    record_count = len(records)

    for index, record in enumerate(records):
        image_relative_path, separator, image_label = record.partition("\t")
        if not separator:
            raise ValueError(
                "malformed record in {}: {!r}".format(label_file_path, record))

        image_name = os.path.basename(image_relative_path)
        image_path = os.path.join(image_dir, image_name)

        # Position of this record within the shuffled list, in [0, 1).
        cur_ratio = index / record_count
        if cur_ratio < train_ratio:
            target_root, target_txt = abs_train_root_path, train_txt
        elif cur_ratio < val_ratio:
            target_root, target_txt = abs_val_root_path, val_txt
        else:
            target_root, target_txt = abs_test_root_path, test_txt

        image_copy_path = os.path.join(target_root, image_name)
        shutil.copy(image_path, image_copy_path)
        target_txt.write("{}\t{}\n".format(image_copy_path, image_label))
|
2021-09-29 19:27:46 +08:00
|
|
|
|
|
|
|
|
|
|
2021-10-08 15:31:48 +08:00
|
|
|
|
# Delete the file if it already exists
|
|
|
|
|
def removeFile(path):
    """Delete the file at ``path``; do nothing when the path does not exist.

    Args:
        path: Location of the file to remove.
    """
    if not os.path.exists(path):
        # Nothing on disk, nothing to clean up.
        return
    os.remove(path)
|
|
|
|
|
|
|
|
|
|
|
2021-09-29 19:27:46 +08:00
|
|
|
|
def genDetRecTrainVal(args):
    """Build the detection and recognition train/val/test splits.

    Steps:
      1. Reset the ``train``/``val``/``test`` folders under
         ``args.detRootPath`` and ``args.recRootPath``.
      2. Remove any existing ``train.txt``/``val.txt``/``test.txt`` under
         both roots and open fresh ones.
      3. Split the detection data found directly in
         ``args.datasetRootPath``.
      4. If the top level of ``args.datasetRootPath`` contains the
         cropped-image folder (``args.recImageDirName``), split the
         recognition data as well.

    Args:
        args: Parsed command-line namespace; this function reads
            ``detRootPath``, ``recRootPath``, ``datasetRootPath`` and
            ``recImageDirName`` from it.
    """
    # Function-scope import so this block does not depend on the file header.
    from contextlib import ExitStack

    split_names = ("train", "val", "test")

    det_train_dir, det_val_dir, det_test_dir = [
        isCreateOrDeleteFolder(args.detRootPath, name) for name in split_names
    ]
    rec_train_dir, rec_val_dir, rec_test_dir = [
        isCreateOrDeleteFolder(args.recRootPath, name) for name in split_names
    ]

    # The list files are opened in append mode below, so stale ones from a
    # previous run have to be deleted first.
    for list_root in (args.detRootPath, args.recRootPath):
        for name in split_names:
            removeFile(os.path.join(list_root, name + ".txt"))

    # ExitStack closes all six list files even if a split raises; leaving
    # them open risks losing buffered records.
    with ExitStack() as stack:
        det_train_txt, det_val_txt, det_test_txt = [
            stack.enter_context(
                open(os.path.join(args.detRootPath, name + ".txt"), "a", encoding="UTF-8"))
            for name in split_names
        ]
        rec_train_txt, rec_val_txt, rec_test_txt = [
            stack.enter_context(
                open(os.path.join(args.recRootPath, name + ".txt"), "a", encoding="UTF-8"))
            for name in split_names
        ]

        splitTrainVal(args.datasetRootPath, det_train_dir, det_val_dir, det_test_dir,
                      det_train_txt, det_val_txt, det_test_txt, "det")

        for root, dir_names, _files in os.walk(args.datasetRootPath):
            # Look for the configured crop folder rather than a hard-coded
            # 'crop_img', so a custom --recImageDirName is honoured here
            # as well as inside splitTrainVal.
            if args.recImageDirName in dir_names:
                splitTrainVal(root, rec_train_dir, rec_val_dir, rec_test_dir,
                              rec_train_txt, rec_val_txt, rec_test_txt, "rec")
            # Only the top level of the dataset folder is inspected.
            break
|
|
|
|
|
|
|
|
|
|
|
2022-02-10 22:40:19 +08:00
|
|
|
|
|
2021-09-29 19:27:46 +08:00
|
|
|
|
if __name__ == "__main__":
    # Purpose: split the detection and recognition data into training,
    # validation and test sets.
    # Note: adjust the parameters to your own paths and needs. Image data is
    # often annotated in batches by several people, each batch kept in its
    # own folder and labelled with PPOCRLabel, so several annotated folders
    # need to be gathered and split into training, validation and test sets.

    # (option name, default value, help text) -- every option is a string.
    cli_options = (
        ("--trainValTestRatio", "6:2:2",
         "ratio of trainset:valset:testset"),
        ("--datasetRootPath", "../train_data/",
         "path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."),
        ("--detRootPath", "../train_data/det",
         "the path where the divided detection dataset is placed"),
        ("--recRootPath", "../train_data/rec",
         "the path where the divided recognition dataset is placed"),
        ("--detLabelFileName", "Label.txt",
         "the name of the detection annotation file"),
        ("--recLabelFileName", "rec_gt.txt",
         "the name of the recognition annotation file"),
        ("--recImageDirName", "crop_img",
         "the name of the folder where the cropped recognition dataset is located"),
    )

    parser = argparse.ArgumentParser()
    for option_name, option_default, option_help in cli_options:
        parser.add_argument(option_name, type=str, default=option_default, help=option_help)

    # `args` has to remain a module-level name: splitTrainVal reads it as a
    # global instead of receiving it as a parameter.
    args = parser.parse_args()
    genDetRecTrainVal(args)
|