add polygon params

pull/7840/head
LDOUBLEV 2022-10-09 09:58:52 +08:00
parent 059349ab74
commit 81e4c3d821
2 changed files with 8 additions and 1 deletions

View File

@ -67,6 +67,7 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
postprocess_params["use_polygon"] = args.det_use_polygon
elif self.det_algorithm == "DB++":
postprocess_params['name'] = 'DBPostProcess'
postprocess_params["thresh"] = args.det_db_thresh
@ -75,6 +76,7 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
postprocess_params["use_polygon"] = args.det_use_polygon
pre_process_list[1] = {
'NormalizeImage': {
'std': [1.0, 1.0, 1.0],
@ -204,6 +206,8 @@ class TextDetector(object):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if type(box) is list:
box = np.array(box)
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
@ -262,13 +266,15 @@ class TextDetector(object):
else:
raise NotImplementedError
#self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
print("det_boxes", dt_boxes)
if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
self.det_algorithm in ["PSE", "FCE", "CT"] and
self.postprocess_op.box_type == 'poly'):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
elif 'DB' in self.det_algorithm and self.postprocess_op.use_polygon is True:
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)

View File

@ -58,6 +58,7 @@ def init_args():
parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--use_dilation", type=str2bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
parser.add_argument("--det_use_polygon", type=str2bool, default=False)
# EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)