mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
Merge pull request #4230 from MrCuiHao/dygraph
使用PPOCRLabel标注多个图像文件夹后,此脚本用于汇总按照比例划分文本检测和文本识别的训练集和验证集
This commit is contained in:
commit
61e40f66bc
@ -207,6 +207,24 @@ For some data that are difficult to recognize, the recognition results will not
|
||||
pip install opencv-contrib-python-headless==4.2.0.32
|
||||
```
|
||||
|
||||
### Dataset division
|
||||
|
||||
- Enter the following command in the terminal to execute the dataset division script:
|
||||
```
|
||||
cd ./PPOCRLabel # Change the directory to the PPOCRLabel folder
|
||||
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --labelRootPath ../train_data/label --detRootPath ../train_data/det --recRootPath ../train_data/rec
|
||||
```
|
||||
|
||||
- Parameter Description:
|
||||
|
||||
trainValTestRatio is the division ratio of the number of images in the training set, validation set, and test set, set according to your actual situation, the default is 6:2:2
|
||||
|
||||
labelRootPath is the storage path of the dataset labeled by PPOCRLabel, the default is ../train_data/label
|
||||
|
||||
detRootPath is the path where the text detection dataset is divided according to the dataset marked by PPOCRLabel. The default is ../train_data/det
|
||||
|
||||
recRootPath is the path where the character recognition dataset is divided according to the dataset marked by PPOCRLabel. The default is ../train_data/rec
|
||||
|
||||
### Related
|
||||
|
||||
1.[Tzutalin. LabelImg. Git code (2015)](https://github.com/tzutalin/labelImg)
|
@ -193,7 +193,23 @@ PPOCRLabel支持三种导出方式:
|
||||
```
|
||||
pip install opencv-contrib-python-headless==4.2.0.32
|
||||
```
|
||||
### 数据集划分
|
||||
- 在终端中输入以下命令执行数据集划分脚本:
|
||||
```
|
||||
cd ./PPOCRLabel # 将目录切换到PPOCRLabel文件夹下
|
||||
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --labelRootPath ../train_data/label --detRootPath ../train_data/det --recRootPath ../train_data/rec
|
||||
```
|
||||
- 参数说明:
|
||||
|
||||
trainValTestRatio是训练集、验证集、测试集的图像数量划分比例,根据你的实际情况设定,默认是6:2:2
|
||||
|
||||
labelRootPath是PPOCRLabel标注的数据集存放路径,默认是../train_data/label
|
||||
|
||||
detRootPath是根据PPOCRLabel标注的数据集划分后的文本检测数据集存放的路径,默认是../train_data/det
|
||||
|
||||
recRootPath是根据PPOCRLabel标注的数据集划分后的字符识别数据集存放的路径,默认是../train_data/rec
|
||||
|
||||
|
||||
### 4. 参考资料
|
||||
|
||||
1.[Tzutalin. LabelImg. Git code (2015)](https://github.com/tzutalin/labelImg)
|
||||
|
147
PPOCRLabel/gen_ocr_train_val_test.py
Normal file
147
PPOCRLabel/gen_ocr_train_val_test.py
Normal file
@ -0,0 +1,147 @@
|
||||
# coding:utf8
|
||||
import os
|
||||
import shutil
|
||||
import random
|
||||
import argparse
|
||||
|
||||
|
||||
# 删除划分的训练集、验证集、测试集文件夹,重新创建一个空的文件夹
|
||||
def isCreateOrDeleteFolder(path, flag):
|
||||
flagPath = os.path.join(path, flag)
|
||||
|
||||
if os.path.exists(flagPath):
|
||||
shutil.rmtree(flagPath)
|
||||
|
||||
os.makedirs(flagPath)
|
||||
flagAbsPath = os.path.abspath(flagPath)
|
||||
return flagAbsPath
|
||||
|
||||
|
||||
def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag):
|
||||
# 按照指定的比例划分训练集、验证集、测试集
|
||||
labelPath = os.path.join(root, dir)
|
||||
labelAbsPath = os.path.abspath(labelPath)
|
||||
|
||||
if flag == "det":
|
||||
labelFilePath = os.path.join(labelAbsPath, args.detLabelFileName)
|
||||
elif flag == "rec":
|
||||
labelFilePath = os.path.join(labelAbsPath, args.recLabelFileName)
|
||||
|
||||
labelFileRead = open(labelFilePath, "r", encoding="UTF-8")
|
||||
labelFileContent = labelFileRead.readlines()
|
||||
random.shuffle(labelFileContent)
|
||||
labelRecordLen = len(labelFileContent)
|
||||
|
||||
for index, labelRecordInfo in enumerate(labelFileContent):
|
||||
imageRelativePath = labelRecordInfo.split('\t')[0]
|
||||
imageLabel = labelRecordInfo.split('\t')[1]
|
||||
imageName = os.path.basename(imageRelativePath)
|
||||
|
||||
if flag == "det":
|
||||
imagePath = os.path.join(labelAbsPath, imageName)
|
||||
elif flag == "rec":
|
||||
imagePath = os.path.join(labelAbsPath, "{}\\{}".format(args.recImageDirName, imageName))
|
||||
|
||||
# 按预设的比例划分训练集、验证集、测试集
|
||||
trainValTestRatio = args.trainValTestRatio.split(":")
|
||||
trainRatio = eval(trainValTestRatio[0]) / 10
|
||||
valRatio = trainRatio + eval(trainValTestRatio[1]) / 10
|
||||
curRatio = index / labelRecordLen
|
||||
|
||||
if curRatio < trainRatio:
|
||||
imageCopyPath = os.path.join(absTrainRootPath, imageName)
|
||||
shutil.copy(imagePath, imageCopyPath)
|
||||
trainTxt.write("{}\t{}".format(imageCopyPath, imageLabel))
|
||||
elif curRatio >= trainRatio and curRatio < valRatio:
|
||||
imageCopyPath = os.path.join(absValRootPath, imageName)
|
||||
shutil.copy(imagePath, imageCopyPath)
|
||||
valTxt.write("{}\t{}".format(imageCopyPath, imageLabel))
|
||||
else:
|
||||
imageCopyPath = os.path.join(absTestRootPath, imageName)
|
||||
shutil.copy(imagePath, imageCopyPath)
|
||||
testTxt.write("{}\t{}".format(imageCopyPath, imageLabel))
|
||||
|
||||
|
||||
# 删掉存在的文件
|
||||
def removeFile(path):
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
|
||||
|
||||
def genDetRecTrainVal(args):
|
||||
detAbsTrainRootPath = isCreateOrDeleteFolder(args.detRootPath, "train")
|
||||
detAbsValRootPath = isCreateOrDeleteFolder(args.detRootPath, "val")
|
||||
detAbsTestRootPath = isCreateOrDeleteFolder(args.detRootPath, "test")
|
||||
recAbsTrainRootPath = isCreateOrDeleteFolder(args.recRootPath, "train")
|
||||
recAbsValRootPath = isCreateOrDeleteFolder(args.recRootPath, "val")
|
||||
recAbsTestRootPath = isCreateOrDeleteFolder(args.recRootPath, "test")
|
||||
|
||||
removeFile(os.path.join(args.detRootPath, "train.txt"))
|
||||
removeFile(os.path.join(args.detRootPath, "val.txt"))
|
||||
removeFile(os.path.join(args.detRootPath, "test.txt"))
|
||||
removeFile(os.path.join(args.recRootPath, "train.txt"))
|
||||
removeFile(os.path.join(args.recRootPath, "val.txt"))
|
||||
removeFile(os.path.join(args.recRootPath, "test.txt"))
|
||||
|
||||
detTrainTxt = open(os.path.join(args.detRootPath, "train.txt"), "a", encoding="UTF-8")
|
||||
detValTxt = open(os.path.join(args.detRootPath, "val.txt"), "a", encoding="UTF-8")
|
||||
detTestTxt = open(os.path.join(args.detRootPath, "test.txt"), "a", encoding="UTF-8")
|
||||
recTrainTxt = open(os.path.join(args.recRootPath, "train.txt"), "a", encoding="UTF-8")
|
||||
recValTxt = open(os.path.join(args.recRootPath, "val.txt"), "a", encoding="UTF-8")
|
||||
recTestTxt = open(os.path.join(args.recRootPath, "test.txt"), "a", encoding="UTF-8")
|
||||
|
||||
for root, dirs, files in os.walk(args.labelRootPath):
|
||||
for dir in dirs:
|
||||
splitTrainVal(root, dir, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt,
|
||||
detTestTxt, "det")
|
||||
splitTrainVal(root, dir, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt,
|
||||
recTestTxt, "rec")
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 功能描述:分别划分检测和识别的训练集、验证集、测试集
|
||||
# 说明:可以根据自己的路径和需求调整参数,图像数据往往多人合作分批标注,每一批图像数据放在一个文件夹内用PPOCRLabel进行标注,
|
||||
# 如此会有多个标注好的图像文件夹汇总并划分训练集、验证集、测试集的需求
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--trainValTestRatio",
|
||||
type=str,
|
||||
default="6:2:2",
|
||||
help="ratio of trainset:valset:testset")
|
||||
parser.add_argument(
|
||||
"--labelRootPath",
|
||||
type=str,
|
||||
default="../train_data/label",
|
||||
help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--detRootPath",
|
||||
type=str,
|
||||
default="../train_data/det",
|
||||
help="the path where the divided detection dataset is placed")
|
||||
parser.add_argument(
|
||||
"--recRootPath",
|
||||
type=str,
|
||||
default="../train_data/rec",
|
||||
help="the path where the divided recognition dataset is placed"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--detLabelFileName",
|
||||
type=str,
|
||||
default="Label.txt",
|
||||
help="the name of the detection annotation file")
|
||||
parser.add_argument(
|
||||
"--recLabelFileName",
|
||||
type=str,
|
||||
default="rec_gt.txt",
|
||||
help="the name of the recognition annotation file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recImageDirName",
|
||||
type=str,
|
||||
default="crop_img",
|
||||
help="the name of the folder where the cropped recognition dataset is located"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
genDetRecTrainVal(args)
|
Loading…
x
Reference in New Issue
Block a user