mirror of https://github.com/open-mmlab/mmocr.git
Hbsun/end2end demo (#105)
* add end2end demo
* fix typo
* pad box
* fix bug of crnn
* fix polygon
* update docstring
* fix bug of polygon
* update demo api
* fix except
* rename
* fix with comments
parent ce98d23ba2
commit 43dcb32d4f
demo/ocr_image_demo.py
@@ -0,0 +1,121 @@
+import json
+from argparse import ArgumentParser
+
+import mmcv
+
+from mmdet.apis import init_detector
+from mmocr.apis.inference import model_inference
+from mmocr.core.visualize import det_recog_show_result
+from mmocr.datasets.pipelines.crop import crop_img
+
+
+def write_json(obj, fpath):
+    """Write json object to file."""
+    with open(fpath, 'w') as f:
+        json.dump(obj, f, indent=4, separators=(',', ': '), ensure_ascii=False)
+
+
+def det_and_recog_inference(args, det_model, recog_model):
+    image_path = args.img
+    end2end_res = {'filename': image_path}
+    end2end_res['result'] = []
+
+    image = mmcv.imread(image_path)
+    det_result = model_inference(det_model, image)
+    bboxes = det_result['boundary_result']
+
+    for bbox in bboxes:
+        box_res = {}
+        box_res['box'] = [round(x) for x in bbox[:-1]]
+        box_res['box_score'] = float(bbox[-1])
+        box = bbox[:8]
+        if len(bbox) > 9:
+            min_x = min(bbox[0:-1:2])
+            min_y = min(bbox[1:-1:2])
+            max_x = max(bbox[0:-1:2])
+            max_y = max(bbox[1:-1:2])
+            box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
+        box_img = crop_img(image, box)
+
+        recog_result = model_inference(recog_model, box_img)
+
+        text = recog_result['text']
+        text_score = recog_result['score']
+        if isinstance(text_score, list):
+            text_score = sum(text_score) / max(1, len(text))
+        box_res['text'] = text
+        box_res['text_score'] = text_score
+        end2end_res['result'].append(box_res)
+
+    return end2end_res
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('img', type=str, help='Input Image file.')
+    parser.add_argument(
+        'out_file', type=str, help='Output file name of the visualized image.')
+    parser.add_argument(
+        '--det-config',
+        type=str,
+        default='./configs/textdet/psenet/'
+        'psenet_r50_fpnf_600e_icdar2015.py',
+        help='Text detection config file.')
+    parser.add_argument(
+        '--det-ckpt',
+        type=str,
+        default='https://download.openmmlab.com/'
+        'mmocr/textdet/psenet/'
+        'psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth',
+        help='Text detection checkpoint file (local or url).')
+    parser.add_argument(
+        '--recog-config',
+        type=str,
+        default='./configs/textrecog/sar/'
+        'sar_r31_parallel_decoder_academic.py',
+        help='Text recognition config file.')
+    parser.add_argument(
+        '--recog-ckpt',
+        type=str,
+        default='https://download.openmmlab.com/'
+        'mmocr/textrecog/sar/'
+        'sar_r31_parallel_decoder_academic-dba3a4a3.pth',
+        help='Text recognition checkpoint file (local or url).')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference.')
+    parser.add_argument(
+        '--imshow',
+        action='store_true',
+        help='Whether to show the image with OpenCV.')
+    args = parser.parse_args()
+
+    # build detect model
+    detect_model = init_detector(
+        args.det_config, args.det_ckpt, device=args.device)
+    if hasattr(detect_model, 'module'):
+        detect_model = detect_model.module
+    if detect_model.cfg.data.test['type'] == 'ConcatDataset':
+        detect_model.cfg.data.test.pipeline = \
+            detect_model.cfg.data.test['datasets'][0].pipeline
+
+    # build recog model
+    recog_model = init_detector(
+        args.recog_config, args.recog_ckpt, device=args.device)
+    if hasattr(recog_model, 'module'):
+        recog_model = recog_model.module
+    if recog_model.cfg.data.test['type'] == 'ConcatDataset':
+        recog_model.cfg.data.test.pipeline = \
+            recog_model.cfg.data.test['datasets'][0].pipeline
+
+    det_recog_result = det_and_recog_inference(args, detect_model, recog_model)
+    print(f'result: {det_recog_result}')
+    write_json(det_recog_result, args.out_file + '.json')
+
+    img = det_recog_show_result(args.img, det_recog_result)
+    mmcv.imwrite(img, args.out_file)
+    if args.imshow:
+        mmcv.imshow(img, 'predicted results')
+
+
+if __name__ == '__main__':
+    main()
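For reference, the result dict that `det_and_recog_inference` assembles and `write_json` dumps to `<out_file>.json` looks like this; the values are illustrative, the schema is read off the function above:

```python
end2end_res = {
    'filename': 'demo/demo_text_det.jpg',
    'result': [{
        'box': [51, 88, 51, 62, 85, 62, 85, 88],  # polygon x, y pairs
        'box_score': 0.94,                        # detection confidence
        'text': 'hello',                          # recognized transcription
        'text_score': 0.88,                       # mean character score
    }]
}
```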
demo/README.md
@@ -24,6 +24,14 @@ python demo/image_demo.py demo/demo_text_det.jpg configs/xxx.py xxx.pth demo/demo_text_det_pred.jpg
 
 The predicted result will be saved as `demo/demo_text_det_pred.jpg`.
 
+To end-to-end test a single image with both text detection and recognition:
+
+```shell
+python demo/ocr_image_demo.py demo/demo_text_det.jpg demo/output.jpg
+```
+
+The predicted result will be saved as `demo/output.jpg`.
+
 ### Test Multiple Images
 
 ```shell
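Since the demo script passes `args.out_file + '.json'` to `write_json`, the raw predictions additionally land next to the visualization, as `demo/output.jpg.json` for the command above.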
mmocr/apis/inference.py
@@ -26,9 +26,9 @@ def model_inference(model, img):
     if isinstance(img, np.ndarray):
         cfg = cfg.copy()
         # set loading pipeline type
-        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
-        cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
+        cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'
+
+    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
     test_pipeline = Compose(cfg.data.test.pipeline)
 
     if isinstance(img, np.ndarray):
@@ -51,14 +51,16 @@ def model_inference(model, img):
             img_metas.data[0] for img_metas in data['img_metas']
         ]
     else:
-        data['img_metas'] = data['img_metas'].data[0]
+        data['img_metas'] = data['img_metas'].data
 
     # process img
-    if isinstance(img, np.ndarray):
+    if isinstance(data['img'], list):
         data['img'] = [img.data[0] for img in data['img']]
         for idx, img in enumerate(data['img']):
             if img.dim() == 3:
                 data['img'][idx] = img.unsqueeze(0)
+    else:
+        data['img_metas'] = data['img_metas'][0]
 
     if next(model.parameters()).is_cuda:
        # scatter to specified GPU
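For orientation, a minimal sketch of calling the reworked `model_inference` with both input types; the config path and image paths here are placeholders, not part of the diff:

```python
import mmcv
from mmdet.apis import init_detector

from mmocr.apis.inference import model_inference

# Placeholder config; any MMOCR test config with a standard pipeline works.
model = init_detector('configs/some_config.py', checkpoint=None, device='cpu')

result = model_inference(model, 'demo/demo_text_recog.jpg')  # from a path
img = mmcv.imread('demo/demo_text_recog.jpg')
result = model_inference(model, img)  # from np.ndarray, via LoadImageFromNdarray
```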
mmocr/core/visualize.py
@@ -417,3 +417,113 @@ def imshow_edge_node(img,
         mmcv.imwrite(vis_img, out_file)
 
     return vis_img
+
+
+def gen_color():
+    """Generate BGR color schemes."""
+    color_list = [(101, 67, 254), (154, 157, 252), (173, 205, 249),
+                  (123, 151, 138), (187, 200, 178), (148, 137, 69),
+                  (169, 200, 200), (155, 175, 131), (154, 194, 182),
+                  (178, 190, 137), (140, 211, 222), (83, 156, 222)]
+    return color_list
+
+
+def draw_polygons(img, polys):
+    """Draw polygons on image.
+
+    Args:
+        img (np.ndarray): The original image.
+        polys (list[list[float]]): Detected polygons.
+    Return:
+        out_img (np.ndarray): Visualized image.
+    """
+    dst_img = img.copy()
+    color_list = gen_color()
+    out_img = dst_img
+    for idx, poly in enumerate(polys):
+        poly = np.array(poly).reshape((-1, 1, 2)).astype(np.int32)
+        cv2.drawContours(
+            img,
+            np.array([poly]),
+            -1,
+            color_list[idx % len(color_list)],
+            thickness=cv2.FILLED)
+    out_img = cv2.addWeighted(dst_img, 0.5, img, 0.5, 0)
+    return out_img
+
+
+def get_optimal_font_scale(text, width):
+    """Get optimal font scale for cv2.putText.
+
+    Args:
+        text (str): Text in one box.
+        width (int): The box width.
+    """
+    for scale in reversed(range(0, 60, 1)):
+        textSize = cv2.getTextSize(
+            text,
+            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+            fontScale=scale / 10,
+            thickness=1)
+        new_width = textSize[0][0]
+        if new_width <= width:
+            return scale / 10
+    return 1
+
+
+def draw_texts(img, boxes, texts):
+    """Draw boxes and texts on an empty image.
+
+    Args:
+        img (np.ndarray): The original image.
+        boxes (list[list[float]]): Detected bounding boxes.
+        texts (list[str]): Recognized texts.
+    Return:
+        out_img (np.ndarray): Visualized image.
+    """
+    color_list = gen_color()
+    h, w = img.shape[:2]
+    out_img = np.ones((h, w, 3), dtype=np.uint8) * 255
+    for idx, (box, text) in enumerate(zip(boxes, texts)):
+        new_box = [[x, y] for x, y in zip(box[0::2], box[1::2])]
+        Pts = np.array([new_box], np.int32)
+        cv2.polylines(
+            out_img, [Pts.reshape((-1, 1, 2))],
+            True,
+            color=color_list[idx % len(color_list)],
+            thickness=1)
+        min_x = int(min(box[0::2]))
+        max_y = int(
+            np.mean(np.array(box[1::2])) + 0.2 *
+            (max(box[1::2]) - min(box[1::2])))
+        font_scale = get_optimal_font_scale(
+            text, int(max(box[0::2]) - min(box[0::2])))
+        cv2.putText(out_img, text, (min_x, max_y), cv2.FONT_HERSHEY_SIMPLEX,
+                    font_scale, (0, 0, 0), 1)
+
+    return out_img
+
+
+def det_recog_show_result(img, end2end_res):
+    """Draw `result` (boxes and texts) on `img`.
+
+    Args:
+        img (str or np.ndarray): The image to be displayed.
+        end2end_res (dict): Text detect and recognize results.
+
+    Return:
+        out_img (np.ndarray): Visualized image.
+    """
+    img = mmcv.imread(img)
+    boxes, texts = [], []
+    for res in end2end_res['result']:
+        boxes.append(res['box'])
+        texts.append(res['text'])
+    box_vis_img = draw_polygons(img, boxes)
+    text_vis_img = draw_texts(img, boxes, texts)
+
+    h, w = img.shape[:2]
+    out_img = np.ones((h, w * 2, 3), dtype=np.uint8)
+    out_img[:, :w, :] = box_vis_img
+    out_img[:, w:, :] = text_vis_img
+
+    return out_img
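A short usage sketch for these helpers, mirroring the unit test later in this diff; the input and output paths are placeholders:

```python
import mmcv

from mmocr.core.visualize import det_recog_show_result

end2end_res = {
    'result': [{
        'box': [51, 88, 51, 62, 85, 62, 85, 88],
        'box_score': 0.94,
        'text': 'hello',
        'text_score': 0.88,
    }]
}
# The output is twice as wide as the input: detected polygons blended onto
# the left half, recognized texts drawn on a white canvas on the right.
vis = det_recog_show_result('demo/demo_text_det.jpg', end2end_res)
mmcv.imwrite(vis, 'demo/vis.jpg')
```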
mmocr/datasets/pipelines/__init__.py
@@ -2,7 +2,7 @@ from .box_utils import sort_vertex
 from .custom_format_bundle import CustomFormatBundle
 from .dbnet_transforms import EastRandomCrop, ImgAug
 from .kie_transforms import KIEFormatBundle
-from .loading import LoadTextAnnotations
+from .loading import LoadImageFromNdarray, LoadTextAnnotations
 from .ocr_seg_targets import OCRSegTargets
 from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
                              OpencvToPil, PilToOpencv, RandomPaddingOCR,
@@ -22,5 +22,5 @@ __all__ = [
     'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR',
     'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil',
     'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
-    'sort_vertex'
+    'sort_vertex', 'LoadImageFromNdarray'
 ]
mmocr/datasets/pipelines/crop.py
@@ -83,24 +83,41 @@ def warp_img(src_img,
     return dst_img
 
 
-def crop_img(src_img, box):
-    """Crop box area to rectangle.
+def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
+    """Crop text region given its bounding box.
 
     Args:
-        src_img (np.array): Image before crop.
+        src_img (np.array): The original image.
         box (list[float | int]): Points of quadrangle.
+        long_edge_pad_ratio (float): Box pad ratio for long edge
+            corresponding to font size.
+        short_edge_pad_ratio (float): Box pad ratio for short edge
+            corresponding to font size.
     """
     assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
     assert len(box) == 8
+    assert 0. <= long_edge_pad_ratio < 1.0
+    assert 0. <= short_edge_pad_ratio < 1.0
 
     h, w = src_img.shape[:2]
-    points_x = [min(max(x, 0), w) for x in box[0:8:2]]
-    points_y = [min(max(y, 0), h) for y in box[1:9:2]]
+    points_x = np.clip(np.array(box[0::2]), 0, w)
+    points_y = np.clip(np.array(box[1::2]), 0, h)
 
-    left = int(min(points_x))
-    top = int(min(points_y))
-    right = int(max(points_x))
-    bottom = int(max(points_y))
+    box_width = np.max(points_x) - np.min(points_x)
+    box_height = np.max(points_y) - np.min(points_y)
+    font_size = min(box_height, box_width)
+
+    if box_height < box_width:
+        horizontal_pad = long_edge_pad_ratio * font_size
+        vertical_pad = short_edge_pad_ratio * font_size
+    else:
+        horizontal_pad = short_edge_pad_ratio * font_size
+        vertical_pad = long_edge_pad_ratio * font_size
+
+    left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w)
+    top = np.clip(int(np.min(points_y) - vertical_pad), 0, h)
+    right = np.clip(int(np.max(points_x) + horizontal_pad), 0, w)
+    bottom = np.clip(int(np.max(points_y) + vertical_pad), 0, h)
 
     dst_img = src_img[top:bottom, left:right]
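To make the padding arithmetic concrete, a small worked example (sizes chosen for illustration): for a 100×20 box, `font_size = min(20, 100) = 20`, so the default ratios pad each long (horizontal) edge by 0.4 × 20 = 8 px and each short (vertical) edge by 0.2 × 20 = 4 px.

```python
import numpy as np

from mmocr.datasets.pipelines.crop import crop_img

img = np.ones((600, 600, 3), dtype=np.uint8)
box = [20, 20, 120, 20, 120, 40, 20, 40]  # 100 x 20 quadrangle

patch = crop_img(img, box)  # defaults: long_edge 0.4, short_edge 0.2
print(patch.shape)  # (28, 116, 3) == (20 + 2 * 4, 100 + 2 * 8, 3)
```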
mmocr/datasets/pipelines/loading.py
@@ -1,8 +1,9 @@
+import mmcv
 import numpy as np
 
 from mmdet.core import BitmapMasks, PolygonMasks
 from mmdet.datasets.builder import PIPELINES
-from mmdet.datasets.pipelines.loading import LoadAnnotations
+from mmdet.datasets.pipelines.loading import LoadAnnotations, LoadImageFromFile
 
 
 @PIPELINES.register_module()
@@ -66,3 +67,40 @@ class LoadTextAnnotations(LoadAnnotations):
         results['gt_masks'] = gt_masks
         results['mask_fields'].append('gt_masks')
         return results
+
+
+@PIPELINES.register_module()
+class LoadImageFromNdarray(LoadImageFromFile):
+    """Load an image from np.ndarray.
+
+    Similar to :obj:`LoadImageFromFile`, but the image is read from
+    ``results['img']``, which is an np.ndarray.
+    """
+
+    def __call__(self, results):
+        """Call function to add image meta information.
+
+        Args:
+            results (dict): Result dict with an in-memory image in
+                ``results['img']``.
+
+        Returns:
+            dict: The dict contains loaded image and meta information.
+        """
+        assert results['img'].dtype == 'uint8'
+
+        img = results['img']
+        if self.color_type == 'grayscale' and img.shape[2] == 3:
+            img = mmcv.bgr2gray(img, keepdim=True)
+        if self.color_type == 'color' and img.shape[2] == 1:
+            img = mmcv.gray2bgr(img)
+        if self.to_float32:
+            img = img.astype(np.float32)
+
+        results['filename'] = None
+        results['ori_filename'] = None
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['ori_shape'] = img.shape
+        results['img_fields'] = ['img']
+        return results
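For context, a quick sketch of invoking the new transform directly, in the spirit of the unit test near the end of this diff; the array values are illustrative:

```python
import numpy as np

from mmocr.datasets.pipelines import LoadImageFromNdarray

# model_inference swaps this transform in as the first pipeline step
# whenever it receives an in-memory np.ndarray instead of a file path.
results = {'img': np.zeros((32, 100, 3), dtype=np.uint8)}
results = LoadImageFromNdarray(color_type='color')(results)
print(results['img_shape'], results['filename'])  # (32, 100, 3) None
```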
@@ -1,6 +1,4 @@
 import os
-import shutil
-import urllib
 
 import pytest
 from mmcv.image import imread
@@ -9,85 +7,29 @@ from mmdet.apis import init_detector
 from mmocr.apis.inference import model_inference
 
 
-@pytest.fixture
-def project_dir():
-    return os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
-
-
-@pytest.fixture
-def sample_img_path(project_dir):
-    return os.path.join(project_dir, '../demo/demo_text_recog.jpg')
-
-
-@pytest.fixture
-def sample_det_img_path(project_dir):
-    return os.path.join(project_dir, '../demo/demo_text_det.jpg')
-
-
-@pytest.fixture
-def sarnet_model(project_dir):
-    print(project_dir)
-    config_file = os.path.join(
-        project_dir,
-        '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py')
-    checkpoint_file = os.path.join(
-        project_dir,
-        '../checkpoints/sar_r31_parallel_decoder_academic-dba3a4a3.pth')
-
-    if not os.path.exists(checkpoint_file):
-        url = ('https://download.openmmlab.com/mmocr'
-               '/textrecog/sar/'
-               'sar_r31_parallel_decoder_academic-dba3a4a3.pth')
-        print(f'Downloading {url} ...')
-        local_filename, _ = urllib.request.urlretrieve(url)
-        os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
-        shutil.move(local_filename, checkpoint_file)
-        print(f'Saved as {checkpoint_file}')
-    else:
-        print(f'Using existing checkpoint {checkpoint_file}')
-
-    device = 'cpu'
-    model = init_detector(
-        config_file, checkpoint=checkpoint_file, device=device)
-    if model.cfg.data.test['type'] == 'ConcatDataset':
-        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
-            0].pipeline
-
-    return model
-
-
-@pytest.fixture
-def psenet_model(project_dir):
-    config_file = os.path.join(
-        project_dir,
-        '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py')
-
+@pytest.mark.parametrize('cfg_file', [
+    '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
+    '../configs/textrecog/crnn/crnn_academic_dataset.py',
+    '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
+    '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
+    '../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py',
+    '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'
+])
+def test_model_inference(cfg_file):
+    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+    config_file = os.path.join(tmp_dir, cfg_file)
     device = 'cpu'
     model = init_detector(config_file, checkpoint=None, device=device)
     if model.cfg.data.test['type'] == 'ConcatDataset':
         model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
             0].pipeline
 
-    return model
-
-
-def test_model_inference_image_path(sample_img_path, sarnet_model):
-
     with pytest.raises(AssertionError):
-        model_inference(sarnet_model, 1)
+        model_inference(model, 1)
 
-    model_inference(sarnet_model, sample_img_path)
+    sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
+    model_inference(model, sample_img_path)
 
-
-def test_model_inference_image_path_det(sample_det_img_path, psenet_model):
-    model_inference(psenet_model, sample_det_img_path)
-
-
-def test_model_inference_numpy_ndarray(sample_img_path, sarnet_model):
     # numpy inference
     img = imread(sample_img_path)
-    model_inference(sarnet_model, img)
-
-
-def test_model_inference_numpy_ndarray_det(sample_det_img_path, psenet_model):
-    det_img = imread(sample_det_img_path)
-    model_inference(psenet_model, det_img)
+    model_inference(model, img)
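The rewritten test builds every model with `checkpoint=None` on the CPU, so it runs without downloading pretrained weights; locally it can be selected with, e.g., `pytest -k model_inference` from the repository root (the exact test-file path is not shown in this diff).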
@@ -0,0 +1,20 @@
+import numpy as np
+
+from mmocr.core import det_recog_show_result
+
+
+def test_det_recog_show_result():
+    img = np.ones((100, 100, 3), dtype=np.uint8) * 255
+    det_recog_res = {
+        'result': [{
+            'box': [51, 88, 51, 62, 85, 62, 85, 88],
+            'box_score': 0.9417,
+            'text': 'hell',
+            'text_score': 0.8834
+        }]
+    }
+
+    vis_img = det_recog_show_result(img, det_recog_res)
+    assert vis_img.shape[0] == 100
+    assert vis_img.shape[1] == 200
+    assert vis_img.shape[2] == 3
@@ -85,12 +85,21 @@ def test_min_rect_crop():
     dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
     dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]
 
-    cropped_img = crop_img(dummy_img, dummy_box)
+    cropped_img = crop_img(
+        dummy_img,
+        dummy_box,
+        0.,
+        0.,
+    )
 
     with pytest.raises(AssertionError):
         crop_img(dummy_img, [])
     with pytest.raises(AssertionError):
         crop_img(dummy_img, [20, 40, 40, 20])
+    with pytest.raises(AssertionError):
+        crop_img(dummy_img, dummy_box, 4, 0.2)
+    with pytest.raises(AssertionError):
+        crop_img(dummy_img, dummy_box, 0.4, 1.2)
 
     assert math.isclose(cropped_img.shape[0], 20)
     assert math.isclose(cropped_img.shape[1], 100)
@@ -1,6 +1,6 @@
 import numpy as np
 
-from mmocr.datasets.pipelines import LoadTextAnnotations
+from mmocr.datasets.pipelines import LoadImageFromNdarray, LoadTextAnnotations
 
 
 def _create_dummy_ann():
@@ -36,3 +36,13 @@ def test_loadtextannotation():
     assert len(output['gt_masks_ignore']) == 4
     assert np.allclose(output['gt_masks_ignore'].masks[0],
                        [[499, 94, 531, 94, 531, 124, 499, 124]])
+
+
+def test_load_img_from_numpy():
+    result = {'img': np.ones((32, 100, 3), dtype=np.uint8)}
+
+    load = LoadImageFromNdarray(color_type='color')
+    output = load(result)
+
+    assert output['img'].shape[2] == 3
+    assert len(output['img'].shape) == 3