# PaddleClas/deploy/lite_shitu/generate_json_config.py
#
# File-viewer metadata carried over with the source: 139 lines, 4.6 KiB, Python.

import argparse
import json
import os
import yaml
def parse_args(argv=None):
    """Parse the command-line options used to build ``shitu_config.json``.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave as before.

    Returns:
        argparse.Namespace holding the parsed options. ``fpn_stride`` is a
        list of ints; ``img_dir`` and ``img_path`` are ``None`` when not given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--yaml_path', type=str, default='../configs/inference_drink.yaml')
    parser.add_argument(
        '--img_dir',
        type=str,
        default=None,
        help='The dir path for inference images')
    parser.add_argument(
        '--img_path',
        type=str,
        default=None,
        help='The path of a single inference image')
    parser.add_argument(
        '--det_model_path',
        type=str,
        default='./det.nb',
        help="The model path for mainbody detection")
    parser.add_argument(
        '--rec_model_path',
        type=str,
        default='./rec.nb',
        help="The rec model path")
    parser.add_argument(
        '--rec_label_path',
        type=str,
        default='./label.txt',
        help='The rec model label')
    parser.add_argument(
        '--arch',
        type=str,
        default='PicoDet',
        help='The model structure for detection model')
    # `type=list` would split the raw string into single characters
    # (e.g. "16" -> ['1', '6']); take one or more space-separated ints instead.
    parser.add_argument(
        '--fpn-stride',
        type=int,
        nargs='+',
        default=[8, 16, 32, 64],
        help="The fpn strides for detection model, e.g. --fpn-stride 8 16 32 64")
    parser.add_argument(
        '--keep_top_k',
        type=int,
        default=100,
        help='The params for nms(postprocess for detection)')
    parser.add_argument(
        '--nms-name',
        type=str,
        default='MultiClassNMS',
        help='The nms name for postprocess of detection model')
    parser.add_argument(
        '--nms_threshold',
        type=float,
        default=0.5,
        help='The nms nms_threshold for detection postprocess')
    parser.add_argument(
        '--nms_top_k',
        type=int,
        default=1000,
        help='The nms_top_k in postprocess of detection model')
    parser.add_argument(
        '--score_threshold',
        type=float,
        default=0.3,
        help='The score_threshold for postprocess of detection')
    return parser.parse_args(argv)
def _flatten_transform_ops(ops, skip_null=False):
    """Convert ``[{OpName: params}, ...]`` into ``[{**params, "type": OpName}, ...]``.

    Args:
        ops: List of single-key dicts as loaded from the YAML ``transform_ops``.
        skip_null: When True, ops whose params are ``None`` are dropped;
            otherwise they are emitted with only the ``"type"`` entry.

    Returns:
        A new list of dicts; the input dicts are not modified.
    """
    flat_ops = []
    for op in ops:
        op_name = next(iter(op))
        params = op[op_name]
        if params is None:
            if skip_null:
                continue
            params = {}
        # Copy so the loaded YAML structure is left untouched.
        flat_ops.append(dict(params, type=op_name))
    return flat_ops


def main():
    """Build ``shitu_config.json`` from the YAML config and CLI overrides.

    Reads the YAML file named by ``--yaml_path``, combines selected entries
    with the command-line options, and writes the result to
    ``shitu_config.json`` in the current working directory.

    Raises:
        KeyError: If a required section or key is missing from the YAML file.
        OSError: If the YAML file cannot be read or the JSON cannot be written.
    """
    args = parse_args()
    # Use a context manager so the YAML file handle is closed promptly.
    with open(args.yaml_path) as yaml_file:
        config_yaml = yaml.safe_load(yaml_file)
    yaml_global = config_yaml["Global"]

    global_cfg = {}
    if args.img_dir is not None:
        # A directory of images takes precedence over any single image path.
        global_cfg["infer_imgs"] = None
        global_cfg["infer_imgs_dir"] = args.img_dir
    else:
        global_cfg["infer_imgs"] = (args.img_path if args.img_path else
                                    yaml_global["infer_imgs"])
    global_cfg["batch_size"] = yaml_global["batch_size"]
    # The thread count taken from the YAML is capped at 4.
    global_cfg["cpu_num_threads"] = min(yaml_global["cpu_num_threads"], 4)
    global_cfg["image_shape"] = yaml_global["image_shape"]
    global_cfg["det_model_path"] = args.det_model_path
    global_cfg["rec_model_path"] = args.rec_model_path
    global_cfg["rec_label_path"] = args.rec_label_path
    # NOTE(review): the YAML key is read as "labe_list" (sic); presumably this
    # matches the spelling in the config files - confirm before renaming.
    global_cfg["label_list"] = yaml_global["labe_list"]
    global_cfg["rec_nms_thresold"] = yaml_global["rec_nms_thresold"]
    global_cfg["max_det_results"] = yaml_global["max_det_results"]
    global_cfg["det_fpn_stride"] = args.fpn_stride
    global_cfg["det_arch"] = args.arch
    global_cfg["return_k"] = config_yaml["IndexProcess"]["return_k"]

    config_json = {
        "Global": global_cfg,
        # The JSON stores each op as one flat dict carrying a "type" field
        # rather than the YAML's {OpName: params} nesting.
        "DetPreProcess": {
            "transform_ops": _flatten_transform_ops(
                config_yaml["DetPreProcess"]["transform_ops"])
        },
        "DetPostProcess": {
            "keep_top_k": args.keep_top_k,
            "name": args.nms_name,
            "nms_threshold": args.nms_threshold,
            "nms_top_k": args.nms_top_k,
            "score_threshold": args.score_threshold
        },
        # Recognition ops declared without parameters are left out.
        "RecPreProcess": {
            "transform_ops": _flatten_transform_ops(
                config_yaml["RecPreProcess"]["transform_ops"], skip_null=True)
        },
    }
    with open('shitu_config.json', 'w') as fd:
        json.dump(config_json, fd, indent=4)
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()