# Mirror of https://github.com/alibaba/EasyCV.git (56 lines, 2.0 KiB, Python).
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
"""
|
|
isort:skip_file
|
|
"""
|
|
import json
|
|
import os
|
|
import unittest
|
|
|
|
import cv2
|
|
import torch
|
|
|
|
from easycv.predictors.ocr import OCRDetPredictor, OCRRecPredictor, OCRClsPredictor, OCRPredictor
|
|
|
|
from easycv.utils.test_util import get_tmp_dir
|
|
from tests.ut_config import (PRETRAINED_MODEL_OCRDET, PRETRAINED_MODEL_OCRREC,
|
|
PRETRAINED_MODEL_OCRCLS, TEST_IMAGES_DIR)
|
|
|
|
|
|
class TorchOCRTest(unittest.TestCase):
    """Regression tests for the EasyCV OCR predictors.

    Each test builds a predictor from a pretrained checkpoint path taken from
    ``tests.ut_config``, runs it on a fixture image from ``TEST_IMAGES_DIR``
    and compares the output against values recorded from a reference run.
    """

    def _read_image(self, file_name):
        """Load a fixture image from ``TEST_IMAGES_DIR`` with OpenCV.

        ``cv2.imread`` does not raise on a missing or unreadable file; it
        returns ``None``. Asserting here reports a missing fixture clearly
        instead of letting ``None`` fail later inside a predictor.

        Args:
            file_name: Image file name relative to ``TEST_IMAGES_DIR``.

        Returns:
            The image as returned by ``cv2.imread``.
        """
        path = os.path.join(TEST_IMAGES_DIR, file_name)
        img = cv2.imread(path)
        self.assertIsNotNone(img, 'failed to read test image: ' + path)
        return img

    def test_ocr_det(self):
        """The text detector finds the expected number of boxes in ocr_det.jpg."""
        predictor = OCRDetPredictor(PRETRAINED_MODEL_OCRDET)
        img = self._read_image('ocr_det.jpg')
        dt_boxes = predictor([img])[0]
        # The reference run detected 16 text boxes.
        self.assertEqual(dt_boxes['points'].shape[0], 16)

    def test_ocr_rec(self):
        """The text recognizer reads the expected string from ocr_rec.jpg."""
        predictor = OCRRecPredictor(PRETRAINED_MODEL_OCRREC)
        img = self._read_image('ocr_rec.jpg')
        rec_out = predictor([img])[0]
        # NOTE(review): preds_text appears to be a (text, score) pair, judging
        # by the two assertions below -- confirm against OCRRecPredictor.
        self.assertEqual(rec_out['preds_text'][0], '韩国小馆')
        # The reference run scored 0.9944670796394348.
        self.assertGreater(rec_out['preds_text'][1], 0.9944)

    def test_ocr_direction(self):
        """The direction classifier predicts class 0 for ocr_rec.jpg."""
        predictor = OCRClsPredictor(PRETRAINED_MODEL_OCRCLS)
        img = self._read_image('ocr_rec.jpg')
        cls_out = predictor([img])[0]
        # NOTE(review): class 0 presumably means "not rotated" -- confirm.
        self.assertEqual(int(cls_out['class']), 0)
        # NOTE(review): 'neck'[0] is presumably the class-0 score -- confirm.
        # The reference run scored 0.99987.
        self.assertGreater(float(cls_out['neck'][0]), 0.9998)

    def test_ocr_end2end(self):
        """Full pipeline (detect, direction-classify, recognize) on ocr_det.jpg."""
        predictor = OCRPredictor(
            det_model_path=PRETRAINED_MODEL_OCRDET,
            rec_model_path=PRETRAINED_MODEL_OCRREC,
            cls_model_path=PRETRAINED_MODEL_OCRCLS,
            use_angle_cls=True)
        img = self._read_image('ocr_det.jpg')
        res = predictor([img])
        # Only the first recognized line of the first image is checked.
        # NOTE(review): rec_res entries appear to be (text, score) pairs --
        # confirm against OCRPredictor.
        first_line = res[0]['rec_res'][0]
        self.assertEqual(first_line[0], '纯臻营养护发素')
        self.assertGreater(first_line[1], 0.91)
|
|
|
|
|
if __name__ == '__main__':
    # Executed only when this file is run as a script: collect and run the
    # TestCase classes defined in this module ('__main__' is also the default).
    unittest.main(module='__main__')
|