mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
fix all minor bugs
This commit is contained in:
parent
7cacfc97d9
commit
c4720557e8
@ -117,7 +117,7 @@ class OCRService(WebService):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ocr_service = OCRService(name="ocr")
|
ocr_service = OCRService(name="ocr")
|
||||||
ocr_service.load_model_config("cls_server")
|
ocr_service.load_model_config(global_args.cls_model_dir)
|
||||||
ocr_service.init_rec()
|
ocr_service.init_rec()
|
||||||
if global_args.use_gpu:
|
if global_args.use_gpu:
|
||||||
ocr_service.prepare_server(
|
ocr_service.prepare_server(
|
||||||
|
@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir):
|
|||||||
data = {"feed": [{"image": image}], "fetch": ["res"]}
|
data = {"feed": [{"image": image}], "fetch": ["res"]}
|
||||||
r = requests.post(url=url, headers=headers, data=json.dumps(data))
|
r = requests.post(url=url, headers=headers, data=json.dumps(data))
|
||||||
print(r.json())
|
print(r.json())
|
||||||
break
|
|
||||||
|
@ -96,7 +96,7 @@ class DetService(WebService):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ocr_service = DetService(name="ocr")
|
ocr_service = DetService(name="ocr")
|
||||||
ocr_service.load_model_config("serving_server_dir")
|
ocr_service.load_model_config(global_args.det_model_dir)
|
||||||
ocr_service.init_det()
|
ocr_service.init_det()
|
||||||
if global_args.use_gpu:
|
if global_args.use_gpu:
|
||||||
ocr_service.prepare_server(
|
ocr_service.prepare_server(
|
||||||
|
@ -79,7 +79,6 @@ class TextDetectorHelper(TextDetector):
|
|||||||
class DetService(WebService):
|
class DetService(WebService):
|
||||||
def init_det(self):
|
def init_det(self):
|
||||||
self.text_detector = TextDetectorHelper(global_args)
|
self.text_detector = TextDetectorHelper(global_args)
|
||||||
print("init finish")
|
|
||||||
|
|
||||||
def preprocess(self, feed=[], fetch=[]):
|
def preprocess(self, feed=[], fetch=[]):
|
||||||
data = base64.b64decode(feed[0]["image"].encode('utf8'))
|
data = base64.b64decode(feed[0]["image"].encode('utf8'))
|
||||||
@ -96,7 +95,7 @@ class DetService(WebService):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ocr_service = DetService(name="ocr")
|
ocr_service = DetService(name="ocr")
|
||||||
ocr_service.load_model_config("serving_server_dir")
|
ocr_service.load_model_config(global_args.det_model_dir)
|
||||||
ocr_service.init_det()
|
ocr_service.init_det()
|
||||||
if global_args.use_gpu:
|
if global_args.use_gpu:
|
||||||
ocr_service.prepare_server(
|
ocr_service.prepare_server(
|
||||||
|
@ -44,17 +44,16 @@ class TextSystemHelper(TextSystem):
|
|||||||
if self.use_angle_cls:
|
if self.use_angle_cls:
|
||||||
self.clas_client = Debugger()
|
self.clas_client = Debugger()
|
||||||
self.clas_client.load_model_config(
|
self.clas_client.load_model_config(
|
||||||
"ocr_clas_server", gpu=True, profile=False)
|
global_args.cls_model_dir, gpu=True, profile=False)
|
||||||
self.text_classifier = TextClassifierHelper(args)
|
self.text_classifier = TextClassifierHelper(args)
|
||||||
self.det_client = Debugger()
|
self.det_client = Debugger()
|
||||||
self.det_client.load_model_config(
|
self.det_client.load_model_config(
|
||||||
"serving_server_dir", gpu=True, profile=False)
|
global_args.det_model_dir, gpu=True, profile=False)
|
||||||
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
|
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
|
||||||
|
|
||||||
def preprocess(self, img):
|
def preprocess(self, img):
|
||||||
feed, fetch, self.tmp_args = self.text_detector.preprocess(img)
|
feed, fetch, self.tmp_args = self.text_detector.preprocess(img)
|
||||||
fetch_map = self.det_client.predict(feed, fetch)
|
fetch_map = self.det_client.predict(feed, fetch)
|
||||||
print("det fetch_map", fetch_map)
|
|
||||||
outputs = [fetch_map[x] for x in fetch]
|
outputs = [fetch_map[x] for x in fetch]
|
||||||
dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
|
dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
|
||||||
if dt_boxes is None:
|
if dt_boxes is None:
|
||||||
@ -90,12 +89,10 @@ class OCRService(WebService):
|
|||||||
|
|
||||||
def preprocess(self, feed=[], fetch=[]):
|
def preprocess(self, feed=[], fetch=[]):
|
||||||
# TODO: to handle batch rec images
|
# TODO: to handle batch rec images
|
||||||
print("start preprocess")
|
|
||||||
data = base64.b64decode(feed[0]["image"].encode('utf8'))
|
data = base64.b64decode(feed[0]["image"].encode('utf8'))
|
||||||
data = np.fromstring(data, np.uint8)
|
data = np.fromstring(data, np.uint8)
|
||||||
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
|
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
|
||||||
feed, fetch, self.tmp_args = self.text_system.preprocess(im)
|
feed, fetch, self.tmp_args = self.text_system.preprocess(im)
|
||||||
print("ocr preprocess done")
|
|
||||||
return feed, fetch
|
return feed, fetch
|
||||||
|
|
||||||
def postprocess(self, feed={}, fetch=[], fetch_map=None):
|
def postprocess(self, feed={}, fetch=[], fetch_map=None):
|
||||||
|
@ -25,7 +25,7 @@ from clas_rpc_server import TextClassifierHelper
|
|||||||
from det_rpc_server import TextDetectorHelper
|
from det_rpc_server import TextDetectorHelper
|
||||||
from rec_rpc_server import TextRecognizerHelper
|
from rec_rpc_server import TextRecognizerHelper
|
||||||
import tools.infer.utility as utility
|
import tools.infer.utility as utility
|
||||||
from tools.infer.predict_system import TextSystem
|
from tools.infer.predict_system import TextSystem, sorted_boxes
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
global_args = utility.parse_args()
|
global_args = utility.parse_args()
|
||||||
@ -48,7 +48,7 @@ class TextSystemHelper(TextSystem):
|
|||||||
self.text_classifier = TextClassifierHelper(args)
|
self.text_classifier = TextClassifierHelper(args)
|
||||||
self.det_client = Client()
|
self.det_client = Client()
|
||||||
self.det_client.load_client_config(
|
self.det_client.load_client_config(
|
||||||
"ocr_det_server/serving_client_conf.prototxt")
|
"det_db_client/serving_client_conf.prototxt")
|
||||||
self.det_client.connect(["127.0.0.1:9293"])
|
self.det_client.connect(["127.0.0.1:9293"])
|
||||||
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
|
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
|
||||||
|
|
||||||
@ -57,10 +57,10 @@ class TextSystemHelper(TextSystem):
|
|||||||
fetch_map = self.det_client.predict(feed, fetch)
|
fetch_map = self.det_client.predict(feed, fetch)
|
||||||
outputs = [fetch_map[x] for x in fetch]
|
outputs = [fetch_map[x] for x in fetch]
|
||||||
dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
|
dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
|
||||||
|
print(dt_boxes)
|
||||||
if dt_boxes is None:
|
if dt_boxes is None:
|
||||||
return None, None
|
return None, None
|
||||||
img_crop_list = []
|
img_crop_list = []
|
||||||
sorted_boxes = SortedBoxes()
|
|
||||||
dt_boxes = sorted_boxes(dt_boxes)
|
dt_boxes = sorted_boxes(dt_boxes)
|
||||||
for bno in range(len(dt_boxes)):
|
for bno in range(len(dt_boxes)):
|
||||||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||||||
@ -70,6 +70,7 @@ class TextSystemHelper(TextSystem):
|
|||||||
feed, fetch, self.tmp_args = self.text_classifier.preprocess(
|
feed, fetch, self.tmp_args = self.text_classifier.preprocess(
|
||||||
img_crop_list)
|
img_crop_list)
|
||||||
fetch_map = self.clas_client.predict(feed, fetch)
|
fetch_map = self.clas_client.predict(feed, fetch)
|
||||||
|
print(fetch_map)
|
||||||
outputs = [fetch_map[x] for x in self.text_classifier.fetch]
|
outputs = [fetch_map[x] for x in self.text_classifier.fetch]
|
||||||
for x in fetch_map.keys():
|
for x in fetch_map.keys():
|
||||||
if ".lod" in x:
|
if ".lod" in x:
|
||||||
|
@ -36,8 +36,5 @@ for img_file in os.listdir(test_img_dir):
|
|||||||
image = cv2_to_base64(image_data1)
|
image = cv2_to_base64(image_data1)
|
||||||
data = {"feed": [{"image": image}], "fetch": ["res"]}
|
data = {"feed": [{"image": image}], "fetch": ["res"]}
|
||||||
r = requests.post(url=url, headers=headers, data=json.dumps(data))
|
r = requests.post(url=url, headers=headers, data=json.dumps(data))
|
||||||
print(r)
|
|
||||||
rjson = r.json()
|
rjson = r.json()
|
||||||
print(rjson)
|
print(rjson)
|
||||||
#for x in rjson["result"]["pred_text"]:
|
|
||||||
# print(x)
|
|
||||||
|
@ -85,7 +85,6 @@ class TextRecognizerHelper(TextRecognizer):
|
|||||||
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
|
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
|
||||||
predict_lod = args["softmax_0.tmp_0.lod"]
|
predict_lod = args["softmax_0.tmp_0.lod"]
|
||||||
indices = args["indices"]
|
indices = args["indices"]
|
||||||
print("indices", indices, rec_idx_lod)
|
|
||||||
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
|
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
|
||||||
for rno in range(len(rec_idx_lod) - 1):
|
for rno in range(len(rec_idx_lod) - 1):
|
||||||
beg = rec_idx_lod[rno]
|
beg = rec_idx_lod[rno]
|
||||||
@ -155,7 +154,6 @@ class OCRService(WebService):
|
|||||||
if ".lod" in x:
|
if ".lod" in x:
|
||||||
self.tmp_args[x] = fetch_map[x]
|
self.tmp_args[x] = fetch_map[x]
|
||||||
rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args)
|
rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args)
|
||||||
print("rec_res", rec_res)
|
|
||||||
res = {
|
res = {
|
||||||
"pred_text": [x[0] for x in rec_res],
|
"pred_text": [x[0] for x in rec_res],
|
||||||
"score": [str(x[1]) for x in rec_res]
|
"score": [str(x[1]) for x in rec_res]
|
||||||
|
@ -91,7 +91,6 @@ class TextRecognizerHelper(TextRecognizer):
|
|||||||
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
|
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
|
||||||
predict_lod = args["softmax_0.tmp_0.lod"]
|
predict_lod = args["softmax_0.tmp_0.lod"]
|
||||||
indices = args["indices"]
|
indices = args["indices"]
|
||||||
print("indices", indices, rec_idx_lod)
|
|
||||||
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
|
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
|
||||||
for rno in range(len(rec_idx_lod) - 1):
|
for rno in range(len(rec_idx_lod) - 1):
|
||||||
beg = rec_idx_lod[rno]
|
beg = rec_idx_lod[rno]
|
||||||
@ -161,7 +160,6 @@ class OCRService(WebService):
|
|||||||
if ".lod" in x:
|
if ".lod" in x:
|
||||||
self.tmp_args[x] = fetch_map[x]
|
self.tmp_args[x] = fetch_map[x]
|
||||||
rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args)
|
rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args)
|
||||||
print("rec_res", rec_res)
|
|
||||||
res = {
|
res = {
|
||||||
"pred_text": [x[0] for x in rec_res],
|
"pred_text": [x[0] for x in rec_res],
|
||||||
"score": [str(x[1]) for x in rec_res]
|
"score": [str(x[1]) for x in rec_res]
|
||||||
|
@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir):
|
|||||||
data = {"feed": [{"image": image}], "fetch": ["res"]}
|
data = {"feed": [{"image": image}], "fetch": ["res"]}
|
||||||
r = requests.post(url=url, headers=headers, data=json.dumps(data))
|
r = requests.post(url=url, headers=headers, data=json.dumps(data))
|
||||||
print(r.json())
|
print(r.json())
|
||||||
break
|
|
||||||
|
@ -33,7 +33,7 @@ from paddle import fluid
|
|||||||
|
|
||||||
class TextClassifier(object):
|
class TextClassifier(object):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
if args.use_serving is False:
|
if args.use_pdserving is False:
|
||||||
self.predictor, self.input_tensor, self.output_tensors = \
|
self.predictor, self.input_tensor, self.output_tensors = \
|
||||||
utility.create_predictor(args, mode="cls")
|
utility.create_predictor(args, mode="cls")
|
||||||
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
|
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
|
||||||
|
@ -75,7 +75,7 @@ class TextDetector(object):
|
|||||||
else:
|
else:
|
||||||
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
|
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
if args.use_gpu is False:
|
if args.use_pdserving is False:
|
||||||
self.predictor, self.input_tensor, self.output_tensors =\
|
self.predictor, self.input_tensor, self.output_tensors =\
|
||||||
utility.create_predictor(args, mode="det")
|
utility.create_predictor(args, mode="det")
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ from ppocr.utils.character import CharacterOps
|
|||||||
|
|
||||||
class TextRecognizer(object):
|
class TextRecognizer(object):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
if args.use_serving is False:
|
if args.use_pdserving is False:
|
||||||
self.predictor, self.input_tensor, self.output_tensors =\
|
self.predictor, self.input_tensor, self.output_tensors =\
|
||||||
utility.create_predictor(args, mode="rec")
|
utility.create_predictor(args, mode="rec")
|
||||||
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
|
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
|
||||||
|
@ -161,7 +161,12 @@ def main(args):
|
|||||||
scores = [rec_res[i][1] for i in range(len(rec_res))]
|
scores = [rec_res[i][1] for i in range(len(rec_res))]
|
||||||
|
|
||||||
draw_img = draw_ocr(
|
draw_img = draw_ocr(
|
||||||
image, boxes, txts, scores, drop_score=drop_score, font_path=font_path)
|
image,
|
||||||
|
boxes,
|
||||||
|
txts,
|
||||||
|
scores,
|
||||||
|
drop_score=drop_score,
|
||||||
|
font_path=font_path)
|
||||||
draw_img_save = "./inference_results/"
|
draw_img_save = "./inference_results/"
|
||||||
if not os.path.exists(draw_img_save):
|
if not os.path.exists(draw_img_save):
|
||||||
os.makedirs(draw_img_save)
|
os.makedirs(draw_img_save)
|
||||||
|
@ -37,7 +37,7 @@ def parse_args():
|
|||||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||||
parser.add_argument("--gpu_mem", type=int, default=8000)
|
parser.add_argument("--gpu_mem", type=int, default=8000)
|
||||||
parser.add_argument("--use_serving", type=str2bool, default=False)
|
parser.add_argument("--use_pdserving", type=str2bool, default=False)
|
||||||
|
|
||||||
# params for text detector
|
# params for text detector
|
||||||
parser.add_argument("--image_dir", type=str)
|
parser.add_argument("--image_dir", type=str)
|
||||||
@ -73,9 +73,7 @@ def parse_args():
|
|||||||
default="./ppocr/utils/ppocr_keys_v1.txt")
|
default="./ppocr/utils/ppocr_keys_v1.txt")
|
||||||
parser.add_argument("--use_space_char", type=str2bool, default=True)
|
parser.add_argument("--use_space_char", type=str2bool, default=True)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vis_font_path",
|
"--vis_font_path", type=str, default="./doc/simfang.ttf")
|
||||||
type=str,
|
|
||||||
default="./doc/simfang.ttf")
|
|
||||||
|
|
||||||
# params for text classifier
|
# params for text classifier
|
||||||
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
||||||
@ -230,8 +228,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
|
|||||||
1])**2)
|
1])**2)
|
||||||
if box_height > 2 * box_width:
|
if box_height > 2 * box_width:
|
||||||
font_size = max(int(box_width * 0.9), 10)
|
font_size = max(int(box_width * 0.9), 10)
|
||||||
font = ImageFont.truetype(
|
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||||
font_path, font_size, encoding="utf-8")
|
|
||||||
cur_y = box[0][1]
|
cur_y = box[0][1]
|
||||||
for c in txt:
|
for c in txt:
|
||||||
char_size = font.getsize(c)
|
char_size = font.getsize(c)
|
||||||
@ -240,8 +237,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
|
|||||||
cur_y += char_size[1]
|
cur_y += char_size[1]
|
||||||
else:
|
else:
|
||||||
font_size = max(int(box_height * 0.8), 10)
|
font_size = max(int(box_height * 0.8), 10)
|
||||||
font = ImageFont.truetype(
|
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||||
font_path, font_size, encoding="utf-8")
|
|
||||||
draw_right.text(
|
draw_right.text(
|
||||||
[box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
|
[box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
|
||||||
img_left = Image.blend(image, img_left, 0.5)
|
img_left = Image.blend(image, img_left, 0.5)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user