Mirror of https://github.com/open-mmlab/mmocr.git
Hbsun/end2end demo (#105)
* add end2end demo
* fix typo
* pad box
* fix bug of crnn
* fix polygon
* update docstring
* fix bug of polygon
* update demo api
* fix except
* rename
* fix with comments
This commit is contained in:
parent ce98d23ba2 · commit 43dcb32d4f
demo/ocr_image_demo.py (new file, 121 lines)
@@ -0,0 +1,121 @@
import json
from argparse import ArgumentParser

import mmcv

from mmdet.apis import init_detector
from mmocr.apis.inference import model_inference
from mmocr.core.visualize import det_recog_show_result
from mmocr.datasets.pipelines.crop import crop_img


def write_json(obj, fpath):
    """Write json object to file."""
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '), ensure_ascii=False)


def det_and_recog_inference(args, det_model, recog_model):
    image_path = args.img
    end2end_res = {'filename': image_path}
    end2end_res['result'] = []

    image = mmcv.imread(image_path)
    det_result = model_inference(det_model, image)
    bboxes = det_result['boundary_result']

    for bbox in bboxes:
        box_res = {}
        box_res['box'] = [round(x) for x in bbox[:-1]]
        box_res['box_score'] = float(bbox[-1])
        box = bbox[:8]
        if len(bbox) > 9:
            min_x = min(bbox[0:-1:2])
            min_y = min(bbox[1:-1:2])
            max_x = max(bbox[0:-1:2])
            max_y = max(bbox[1:-1:2])
            box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
        box_img = crop_img(image, box)

        recog_result = model_inference(recog_model, box_img)

        text = recog_result['text']
        text_score = recog_result['score']
        if isinstance(text_score, list):
            text_score = sum(text_score) / max(1, len(text))
        box_res['text'] = text
        box_res['text_score'] = text_score
        end2end_res['result'].append(box_res)

    return end2end_res


def main():
    parser = ArgumentParser()
    parser.add_argument('img', type=str, help='Input image file.')
    parser.add_argument(
        'out_file', type=str, help='Output file name of the visualized image.')
    parser.add_argument(
        '--det-config',
        type=str,
        default='./configs/textdet/psenet/'
        'psenet_r50_fpnf_600e_icdar2015.py',
        help='Text detection config file.')
    parser.add_argument(
        '--det-ckpt',
        type=str,
        default='https://download.openmmlab.com/'
        'mmocr/textdet/psenet/'
        'psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth',
        help='Text detection checkpoint file (local or url).')
    parser.add_argument(
        '--recog-config',
        type=str,
        default='./configs/textrecog/sar/'
        'sar_r31_parallel_decoder_academic.py',
        help='Text recognition config file.')
    parser.add_argument(
        '--recog-ckpt',
        type=str,
        default='https://download.openmmlab.com/'
        'mmocr/textrecog/sar/'
        'sar_r31_parallel_decoder_academic-dba3a4a3.pth',
        help='Text recognition checkpoint file (local or url).')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    parser.add_argument(
        '--imshow',
        action='store_true',
        help='Whether to show the image with OpenCV.')
    args = parser.parse_args()

    # build detect model
    detect_model = init_detector(
        args.det_config, args.det_ckpt, device=args.device)
    if hasattr(detect_model, 'module'):
        detect_model = detect_model.module
    if detect_model.cfg.data.test['type'] == 'ConcatDataset':
        detect_model.cfg.data.test.pipeline = \
            detect_model.cfg.data.test['datasets'][0].pipeline

    # build recog model
    recog_model = init_detector(
        args.recog_config, args.recog_ckpt, device=args.device)
    if hasattr(recog_model, 'module'):
        recog_model = recog_model.module
    if recog_model.cfg.data.test['type'] == 'ConcatDataset':
        recog_model.cfg.data.test.pipeline = \
            recog_model.cfg.data.test['datasets'][0].pipeline

    det_recog_result = det_and_recog_inference(args, detect_model, recog_model)
    print(f'result: {det_recog_result}')
    write_json(det_recog_result, args.out_file + '.json')

    img = det_recog_show_result(args.img, det_recog_result)
    mmcv.imwrite(img, args.out_file)
    if args.imshow:
        mmcv.imshow(img, 'predicted results')


if __name__ == '__main__':
    main()
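The `len(bbox) > 9` branch above collapses polygon outputs (more than four points plus a trailing score) to their axis-aligned quadrangle before cropping. A quick standalone check of that logic, using a hypothetical 6-point polygon:

```python
# Hypothetical detector output: six polygon points, then a confidence score.
bbox = [10, 10, 50, 5, 90, 10, 90, 40, 50, 45, 10, 40, 0.98]

box = bbox[:8]
if len(bbox) > 9:
    # x coords sit at even indices, y at odd; the final score is excluded.
    min_x, min_y = min(bbox[0:-1:2]), min(bbox[1:-1:2])
    max_x, max_y = max(bbox[0:-1:2]), max(bbox[1:-1:2])
    box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]

print(box)  # [10, 5, 90, 5, 90, 45, 10, 45]: clockwise from top-left
```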
@@ -24,6 +24,14 @@ python demo/image_demo.py demo/demo_text_det.jpg configs/xxx.py xxx.pth demo/dem
 The predicted result will be saved as `demo/demo_text_det_pred.jpg`.
 
+To end-to-end test a single image with both text detection and recognition,
+
+```shell
+python demo/ocr_image_demo.py demo/demo_text_det.jpg demo/output.jpg
+```
+
+The predicted result will be saved as `demo/output.jpg`.
+
 ### Test Multiple Images
 
 ```shell
@@ -26,9 +26,9 @@ def model_inference(model, img):
     if isinstance(img, np.ndarray):
         cfg = cfg.copy()
         # set loading pipeline type
-        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
-        cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
+        cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'
 
+    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
     test_pipeline = Compose(cfg.data.test.pipeline)
 
     if isinstance(img, np.ndarray):
@@ -51,14 +51,16 @@ def model_inference(model, img):
             img_metas.data[0] for img_metas in data['img_metas']
         ]
     else:
-        data['img_metas'] = data['img_metas'].data[0]
+        data['img_metas'] = data['img_metas'].data
 
     # process img
-    if isinstance(img, np.ndarray):
+    if isinstance(data['img'], list):
        data['img'] = [img.data[0] for img in data['img']]
        for idx, img in enumerate(data['img']):
            if img.dim() == 3:
                data['img'][idx] = img.unsqueeze(0)
+    else:
+        data['img_metas'] = data['img_metas'][0]
 
     if next(model.parameters()).is_cuda:
         # scatter to specified GPU
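With this change `model_inference` accepts an `np.ndarray` as well as a file path, switching the pipeline head to the new `LoadImageFromNdarray`. A minimal usage sketch (the config is one of the repo's own; `checkpoint=None` gives random weights, so substitute a real `.pth` for meaningful output):

```python
import mmcv
from mmdet.apis import init_detector

from mmocr.apis.inference import model_inference

# checkpoint=None loads random weights; pass a real checkpoint in practice.
model = init_detector(
    'configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py',
    checkpoint=None,
    device='cpu')

result = model_inference(model, 'demo/demo_text_det.jpg')  # path input
img = mmcv.imread('demo/demo_text_det.jpg')                # BGR ndarray
result = model_inference(model, img)                       # ndarray input
```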
@@ -417,3 +417,113 @@ def imshow_edge_node(img,
     mmcv.imwrite(vis_img, out_file)
 
     return vis_img
+
+
+def gen_color():
+    """Generate BGR color schemes."""
+    color_list = [(101, 67, 254), (154, 157, 252), (173, 205, 249),
+                  (123, 151, 138), (187, 200, 178), (148, 137, 69),
+                  (169, 200, 200), (155, 175, 131), (154, 194, 182),
+                  (178, 190, 137), (140, 211, 222), (83, 156, 222)]
+    return color_list
+
+
+def draw_polygons(img, polys):
+    """Draw polygons on image.
+
+    Args:
+        img (np.ndarray): The original image.
+        polys (list[list[float]]): Detected polygons.
+    Return:
+        out_img (np.ndarray): Visualized image.
+    """
+    dst_img = img.copy()
+    color_list = gen_color()
+    out_img = dst_img
+    for idx, poly in enumerate(polys):
+        poly = np.array(poly).reshape((-1, 1, 2)).astype(np.int32)
+        cv2.drawContours(
+            img,
+            np.array([poly]),
+            -1,
+            color_list[idx % len(color_list)],
+            thickness=cv2.FILLED)
+    out_img = cv2.addWeighted(dst_img, 0.5, img, 0.5, 0)
+    return out_img
+
+
+def get_optimal_font_scale(text, width):
+    """Get optimal font scale for cv2.putText.
+
+    Args:
+        text (str): Text in one box.
+        width (int): The box width.
+    """
+    for scale in reversed(range(0, 60, 1)):
+        textSize = cv2.getTextSize(
+            text,
+            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+            fontScale=scale / 10,
+            thickness=1)
+        new_width = textSize[0][0]
+        if (new_width <= width):
+            return scale / 10
+    return 1
+
+
+def draw_texts(img, boxes, texts):
+    """Draw boxes and texts on empty img.
+
+    Args:
+        img (np.ndarray): The original image.
+        boxes (list[list[float]]): Detected bounding boxes.
+        texts (list[str]): Recognized texts.
+    Return:
+        out_img (np.ndarray): Visualized image.
+    """
+    color_list = gen_color()
+    h, w = img.shape[:2]
+    out_img = np.ones((h, w, 3), dtype=np.uint8) * 255
+    for idx, (box, text) in enumerate(zip(boxes, texts)):
+        new_box = [[x, y] for x, y in zip(box[0::2], box[1::2])]
+        Pts = np.array([new_box], np.int32)
+        cv2.polylines(
+            out_img, [Pts.reshape((-1, 1, 2))],
+            True,
+            color=color_list[idx % len(color_list)],
+            thickness=1)
+        min_x = int(min(box[0::2]))
+        max_y = int(
+            np.mean(np.array(box[1::2])) + 0.2 *
+            (max(box[1::2]) - min(box[1::2])))
+        font_scale = get_optimal_font_scale(
+            text, int(max(box[0::2]) - min(box[0::2])))
+        cv2.putText(out_img, text, (min_x, max_y), cv2.FONT_HERSHEY_SIMPLEX,
+                    font_scale, (0, 0, 0), 1)
+
+    return out_img
+
+
+def det_recog_show_result(img, end2end_res):
+    """Draw `result` (boxes and texts) on `img`.
+
+    Args:
+        img (str or np.ndarray): The image to be displayed.
+        end2end_res (dict): Text detect and recognize results.
+
+    Return:
+        out_img (np.ndarray): Visualized image.
+    """
+    img = mmcv.imread(img)
+    boxes, texts = [], []
+    for res in end2end_res['result']:
+        boxes.append(res['box'])
+        texts.append(res['text'])
+    box_vis_img = draw_polygons(img, boxes)
+    text_vis_img = draw_texts(img, boxes, texts)
+
+    h, w = img.shape[:2]
+    out_img = np.ones((h, w * 2, 3), dtype=np.uint8)
+    out_img[:, :w, :] = box_vis_img
+    out_img[:, w:, :] = text_vis_img
+
+    return out_img
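A minimal sketch exercising the new visualizer with a hypothetical result dict in the demo's output format (compare the unit test added further down):

```python
import numpy as np

from mmocr.core.visualize import det_recog_show_result

# Hypothetical end-to-end result, same shape as the demo's JSON output.
end2end_res = {
    'result': [{
        'box': [10, 40, 10, 10, 110, 10, 110, 40],
        'box_score': 0.95,
        'text': 'hello',
        'text_score': 0.91,
    }]
}

img = np.full((60, 160, 3), 255, dtype=np.uint8)
vis = det_recog_show_result(img, end2end_res)
# Polygons land on the left half, recognized text on the right half.
assert vis.shape == (60, 320, 3)
```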
@@ -2,7 +2,7 @@ from .box_utils import sort_vertex
 from .custom_format_bundle import CustomFormatBundle
 from .dbnet_transforms import EastRandomCrop, ImgAug
 from .kie_transforms import KIEFormatBundle
-from .loading import LoadTextAnnotations
+from .loading import LoadImageFromNdarray, LoadTextAnnotations
 from .ocr_seg_targets import OCRSegTargets
 from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
                              OpencvToPil, PilToOpencv, RandomPaddingOCR,
@@ -22,5 +22,5 @@ __all__ = [
     'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR',
     'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil',
     'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
-    'sort_vertex'
+    'sort_vertex', 'LoadImageFromNdarray'
 ]
@@ -83,24 +83,41 @@ def warp_img(src_img,
     return dst_img
 
 
-def crop_img(src_img, box):
-    """Crop box area to rectangle.
+def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
+    """Crop text region with its bounding box.
 
     Args:
-        src_img (np.array): Image before crop.
+        src_img (np.array): The original image.
         box (list[float | int]): Points of quadrangle.
+        long_edge_pad_ratio (float): Box pad ratio for long edge
+            corresponding to font size.
+        short_edge_pad_ratio (float): Box pad ratio for short edge
+            corresponding to font size.
     """
     assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
     assert len(box) == 8
+    assert 0. <= long_edge_pad_ratio < 1.0
+    assert 0. <= short_edge_pad_ratio < 1.0
 
     h, w = src_img.shape[:2]
-    points_x = [min(max(x, 0), w) for x in box[0:8:2]]
-    points_y = [min(max(y, 0), h) for y in box[1:9:2]]
-
-    left = int(min(points_x))
-    top = int(min(points_y))
-    right = int(max(points_x))
-    bottom = int(max(points_y))
+    points_x = np.clip(np.array(box[0::2]), 0, w)
+    points_y = np.clip(np.array(box[1::2]), 0, h)
+
+    box_width = np.max(points_x) - np.min(points_x)
+    box_height = np.max(points_y) - np.min(points_y)
+    font_size = min(box_height, box_width)
+
+    if box_height < box_width:
+        horizontal_pad = long_edge_pad_ratio * font_size
+        vertical_pad = short_edge_pad_ratio * font_size
+    else:
+        horizontal_pad = short_edge_pad_ratio * font_size
+        vertical_pad = long_edge_pad_ratio * font_size
+
+    left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w)
+    top = np.clip(int(np.min(points_y) - vertical_pad), 0, h)
+    right = np.clip(int(np.max(points_x) + horizontal_pad), 0, w)
+    bottom = np.clip(int(np.max(points_y) + vertical_pad), 0, h)
+
     dst_img = src_img[top:bottom, left:right]
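Worked through on the dummy box from the updated crop test near the end of this diff, the default ratios behave as follows (a sketch, not part of the change):

```python
import numpy as np

box = [20, 20, 120, 20, 120, 40, 20, 40]      # 100 wide, 20 tall
points_x, points_y = np.array(box[0::2]), np.array(box[1::2])

box_width = points_x.max() - points_x.min()    # 100
box_height = points_y.max() - points_y.min()   # 20
font_size = min(box_height, box_width)         # 20

# Wider than tall, so the long edge is horizontal:
horizontal_pad = 0.4 * font_size               # 8.0
vertical_pad = 0.2 * font_size                 # 4.0

# Padded crop spans x in [12, 128] and y in [16, 44]: a 28 x 116 patch.
# With both ratios forced to 0., as in the test, it stays 20 x 100.
```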
@@ -1,8 +1,9 @@
+import mmcv
 import numpy as np
 
 from mmdet.core import BitmapMasks, PolygonMasks
 from mmdet.datasets.builder import PIPELINES
-from mmdet.datasets.pipelines.loading import LoadAnnotations
+from mmdet.datasets.pipelines.loading import LoadAnnotations, LoadImageFromFile
 
 
 @PIPELINES.register_module()
@@ -66,3 +67,40 @@ class LoadTextAnnotations(LoadAnnotations):
         results['gt_masks'] = gt_masks
         results['mask_fields'].append('gt_masks')
         return results
+
+
+@PIPELINES.register_module()
+class LoadImageFromNdarray(LoadImageFromFile):
+    """Load an image from np.ndarray.
+
+    Similar to :obj:`LoadImageFromFile`, but the image is read from
+    ``results['img']``, which is an np.ndarray.
+    """
+
+    def __call__(self, results):
+        """Call functions to add image meta information.
+
+        Args:
+            results (dict): Result dict with the image array in
+                ``results['img']``.
+
+        Returns:
+            dict: The dict contains loaded image and meta information.
+        """
+        assert results['img'].dtype == 'uint8'
+
+        img = results['img']
+        if self.color_type == 'grayscale' and img.shape[2] == 3:
+            img = mmcv.bgr2gray(img, keepdim=True)
+        if self.color_type == 'color' and img.shape[2] == 1:
+            img = mmcv.gray2bgr(img)
+        if self.to_float32:
+            img = img.astype(np.float32)
+
+        results['filename'] = None
+        results['ori_filename'] = None
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['ori_shape'] = img.shape
+        results['img_fields'] = ['img']
+        return results
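The transform can be exercised in isolation; this mirrors the new unit test at the bottom of this diff:

```python
import numpy as np

from mmocr.datasets.pipelines import LoadImageFromNdarray

load = LoadImageFromNdarray(color_type='color')
results = load({'img': np.ones((32, 100, 3), dtype=np.uint8)})

assert results['img'].shape == (32, 100, 3)
assert results['filename'] is None  # no backing file on disk
assert results['img_fields'] == ['img']
```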
@@ -1,6 +1,4 @@
 import os
-import shutil
-import urllib
 
 import pytest
 from mmcv.image import imread
@@ -9,85 +7,29 @@ from mmdet.apis import init_detector
 from mmocr.apis.inference import model_inference
 
 
-@pytest.fixture
-def project_dir():
-    return os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
-
-
-@pytest.fixture
-def sample_img_path(project_dir):
-    return os.path.join(project_dir, '../demo/demo_text_recog.jpg')
-
-
-@pytest.fixture
-def sample_det_img_path(project_dir):
-    return os.path.join(project_dir, '../demo/demo_text_det.jpg')
-
-
-@pytest.fixture
-def sarnet_model(project_dir):
-    print(project_dir)
-    config_file = os.path.join(
-        project_dir,
-        '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py')
-    checkpoint_file = os.path.join(
-        project_dir,
-        '../checkpoints/sar_r31_parallel_decoder_academic-dba3a4a3.pth')
-
-    if not os.path.exists(checkpoint_file):
-        url = ('https://download.openmmlab.com/mmocr'
-               '/textrecog/sar/'
-               'sar_r31_parallel_decoder_academic-dba3a4a3.pth')
-        print(f'Downloading {url} ...')
-        local_filename, _ = urllib.request.urlretrieve(url)
-        os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
-        shutil.move(local_filename, checkpoint_file)
-        print(f'Saved as {checkpoint_file}')
-    else:
-        print(f'Using existing checkpoint {checkpoint_file}')
-
-    device = 'cpu'
-    model = init_detector(
-        config_file, checkpoint=checkpoint_file, device=device)
-    if model.cfg.data.test['type'] == 'ConcatDataset':
-        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
-            0].pipeline
-
-    return model
-
-
-@pytest.fixture
-def psenet_model(project_dir):
-    config_file = os.path.join(
-        project_dir,
-        '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py')
-
-    device = 'cpu'
-    model = init_detector(config_file, checkpoint=None, device=device)
-    if model.cfg.data.test['type'] == 'ConcatDataset':
-        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
-            0].pipeline
-
-    return model
-
-
-def test_model_inference_image_path(sample_img_path, sarnet_model):
-
-    with pytest.raises(AssertionError):
-        model_inference(sarnet_model, 1)
-
-    model_inference(sarnet_model, sample_img_path)
-
-
-def test_model_inference_image_path_det(sample_det_img_path, psenet_model):
-    model_inference(psenet_model, sample_det_img_path)
-
-
-def test_model_inference_numpy_ndarray(sample_img_path, sarnet_model):
-    img = imread(sample_img_path)
-    model_inference(sarnet_model, img)
-
-
-def test_model_inference_numpy_ndarray_det(sample_det_img_path, psenet_model):
-    det_img = imread(sample_det_img_path)
-    model_inference(psenet_model, det_img)
+@pytest.mark.parametrize('cfg_file', [
+    '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
+    '../configs/textrecog/crnn/crnn_academic_dataset.py',
+    '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
+    '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
+    '../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py',
+    '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'
+])
+def test_model_inference(cfg_file):
+    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+    config_file = os.path.join(tmp_dir, cfg_file)
+
+    device = 'cpu'
+    model = init_detector(config_file, checkpoint=None, device=device)
+    if model.cfg.data.test['type'] == 'ConcatDataset':
+        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
+            0].pipeline
+
+    with pytest.raises(AssertionError):
+        model_inference(model, 1)
+
+    sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
+    model_inference(model, sample_img_path)
+
+    # numpy inference
+    img = imread(sample_img_path)
+
+    model_inference(model, img)
tests/test_core/test_end2end_vis.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import numpy as np

from mmocr.core import det_recog_show_result


def test_det_recog_show_result():
    img = np.ones((100, 100, 3), dtype=np.uint8) * 255
    det_recog_res = {
        'result': [{
            'box': [51, 88, 51, 62, 85, 62, 85, 88],
            'box_score': 0.9417,
            'text': 'hell',
            'text_score': 0.8834
        }]
    }

    vis_img = det_recog_show_result(img, det_recog_res)

    assert vis_img.shape[0] == 100
    assert vis_img.shape[1] == 200
    assert vis_img.shape[2] == 3
@@ -85,12 +85,21 @@ def test_min_rect_crop():
     dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
     dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]
 
-    cropped_img = crop_img(dummy_img, dummy_box)
+    cropped_img = crop_img(
+        dummy_img,
+        dummy_box,
+        0.,
+        0.,
+    )
 
     with pytest.raises(AssertionError):
         crop_img(dummy_img, [])
     with pytest.raises(AssertionError):
         crop_img(dummy_img, [20, 40, 40, 20])
+    with pytest.raises(AssertionError):
+        crop_img(dummy_img, dummy_box, 4, 0.2)
+    with pytest.raises(AssertionError):
+        crop_img(dummy_img, dummy_box, 0.4, 1.2)
 
     assert math.isclose(cropped_img.shape[0], 20)
     assert math.isclose(cropped_img.shape[1], 100)
@@ -1,6 +1,6 @@
 import numpy as np
 
-from mmocr.datasets.pipelines import LoadTextAnnotations
+from mmocr.datasets.pipelines import LoadImageFromNdarray, LoadTextAnnotations
 
 
 def _create_dummy_ann():
@@ -36,3 +36,13 @@ def test_loadtextannotation():
     assert len(output['gt_masks_ignore']) == 4
     assert np.allclose(output['gt_masks_ignore'].masks[0],
                        [[499, 94, 531, 94, 531, 124, 499, 124]])
+
+
+def test_load_img_from_numpy():
+    result = {'img': np.ones((32, 100, 3), dtype=np.uint8)}
+
+    load = LoadImageFromNdarray(color_type='color')
+    output = load(result)
+
+    assert output['img'].shape[2] == 3
+    assert len(output['img'].shape) == 3