mirror of https://github.com/open-mmlab/mmocr.git
Remove useless
parent: d4dbad56ee
commit: de78a8839f
@@ -10,16 +10,6 @@ mmocr/models/textrecog/recognizers/encode_decode_recognizer.py
mmocr/datasets/pipelines/transforms.py
mmocr/datasets/pipelines/dbnet_transforms.py

# will be deleted
mmocr/models/textdet/heads/head_mixin.py

# They will be removed after all det models have been refactored
mmocr/models/common/detectors/single_stage.py
mmocr/models/textdet/detectors/text_detector_mixin.py

# It will be covered by tests of any det model implemented in future
mmocr/models/textdet/detectors/single_stage_text_detector.py

# It will be removed after all utils are moved to mmocr.utils
mmocr/core/evaluation/utils.py
mmocr/models/textdet/postprocessors/utils.py
@@ -1,8 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .deploy_utils import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
                           TensorRTDetector, TensorRTRecognizer)

__all__ = [
    'ONNXRuntimeRecognizer', 'ONNXRuntimeDetector', 'TensorRTDetector',
    'TensorRTRecognizer'
]
@@ -1,328 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Any, Iterable

import numpy as np
import torch

from mmocr.models.textdet.detectors.single_stage_text_detector import \
    SingleStageTextDetector
from mmocr.models.textdet.detectors.text_detector_mixin import \
    TextDetectorMixin
from mmocr.models.textrecog.recognizers.encode_decode_recognizer import \
    EncodeDecodeRecognizer
from mmocr.registry import MODELS


def inference_with_session(sess, io_binding, input_name, output_names,
                           input_tensor):
    device_type = input_tensor.device.type
    device_id = input_tensor.device.index
    device_id = 0 if device_id is None else device_id
    io_binding.bind_input(
        name=input_name,
        device_type=device_type,
        device_id=device_id,
        element_type=np.float32,
        shape=input_tensor.shape,
        buffer_ptr=input_tensor.data_ptr())
    for name in output_names:
        io_binding.bind_output(name)
    sess.run_with_iobinding(io_binding)
    pred = io_binding.copy_outputs_to_cpu()
    return pred


@MODELS.register_module()
class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
    """The class for evaluating onnx file of detection."""

    def __init__(self,
                 onnx_file: str,
                 cfg: Any,
                 device_id: int,
                 show_score: bool = False):
        if 'type' in cfg.model:
            cfg.model.pop('type')
        SingleStageTextDetector.__init__(self, **(cfg.model))
        TextDetectorMixin.__init__(self, show_score)
        import onnxruntime as ort

        # get the custom op path
        ort_custom_op_path = ''
        try:
            from mmcv.ops import get_onnxruntime_op_path
            ort_custom_op_path = get_onnxruntime_op_path()
        except (ImportError, ModuleNotFoundError):
            warnings.warn('If input model has custom op from mmcv, \
                you may have to build mmcv with ONNXRuntime from source.')
        session_options = ort.SessionOptions()
        # register custom op for onnxruntime
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = ort.InferenceSession(onnx_file, session_options)
        providers = ['CPUExecutionProvider']
        options = [{}]
        is_cuda_available = ort.get_device() == 'GPU'
        if is_cuda_available:
            providers.insert(0, 'CUDAExecutionProvider')
            options.insert(0, {'device_id': device_id})

        sess.set_providers(providers, options)

        self.sess = sess
        self.device_id = device_id
        self.io_binding = sess.io_binding()
        self.output_names = [_.name for _ in sess.get_outputs()]
        for name in self.output_names:
            self.io_binding.bind_output(name)
        self.cfg = cfg

    def forward_train(self, img, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def simple_test(self,
                    img: torch.Tensor,
                    img_metas: Iterable,
                    rescale: bool = False):
        onnx_pred = inference_with_session(self.sess, self.io_binding, 'input',
                                           self.output_names, img)
        onnx_pred = torch.from_numpy(onnx_pred[0])
        if len(img_metas) > 1:
            boundaries = [
                self.bbox_head.get_boundary(*(onnx_pred[i].unsqueeze(0)),
                                            [img_metas[i]], rescale)
                for i in range(len(img_metas))
            ]

        else:
            boundaries = [
                self.bbox_head.get_boundary(*onnx_pred, img_metas, rescale)
            ]

        return boundaries


@MODELS.register_module()
class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
    """The class for evaluating onnx file of recognition."""

    def __init__(self,
                 onnx_file: str,
                 cfg: Any,
                 device_id: int,
                 show_score: bool = False):
        if 'type' in cfg.model:
            cfg.model.pop('type')
        EncodeDecodeRecognizer.__init__(self, **(cfg.model))
        import onnxruntime as ort

        # get the custom op path
        ort_custom_op_path = ''
        try:
            from mmcv.ops import get_onnxruntime_op_path
            ort_custom_op_path = get_onnxruntime_op_path()
        except (ImportError, ModuleNotFoundError):
            warnings.warn('If input model has custom op from mmcv, \
                you may have to build mmcv with ONNXRuntime from source.')
        session_options = ort.SessionOptions()
        # register custom op for onnxruntime
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = ort.InferenceSession(onnx_file, session_options)
        providers = ['CPUExecutionProvider']
        options = [{}]
        is_cuda_available = ort.get_device() == 'GPU'
        if is_cuda_available:
            providers.insert(0, 'CUDAExecutionProvider')
            options.insert(0, {'device_id': device_id})

        sess.set_providers(providers, options)

        self.sess = sess
        self.device_id = device_id
        self.io_binding = sess.io_binding()
        self.output_names = [_.name for _ in sess.get_outputs()]
        for name in self.output_names:
            self.io_binding.bind_output(name)
        self.cfg = cfg

    def forward_train(self, img, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def aug_test(self, imgs, img_metas, **kwargs):
        if isinstance(imgs, list):
            for idx, each_img in enumerate(imgs):
                if each_img.dim() == 3:
                    imgs[idx] = each_img.unsqueeze(0)
            imgs = imgs[0]  # avoid aug_test
            img_metas = img_metas[0]
        else:
            if len(img_metas) == 1 and isinstance(img_metas[0], list):
                img_metas = img_metas[0]
        return self.simple_test(imgs, img_metas=img_metas)

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def simple_test(self,
                    img: torch.Tensor,
                    img_metas: Iterable,
                    rescale: bool = False):
        """Test function.

        Args:
            imgs (torch.Tensor): Image input tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            list[str]: Text label result of each image.
        """
        onnx_pred = inference_with_session(self.sess, self.io_binding, 'input',
                                           self.output_names, img)
        onnx_pred = torch.from_numpy(onnx_pred[0])

        label_indexes, label_scores = self.label_convertor.tensor2idx(
            onnx_pred, img_metas)
        label_strings = self.label_convertor.idx2str(label_indexes)

        # flatten batch results
        results = []
        for string, score in zip(label_strings, label_scores):
            results.append(dict(text=string, score=score))

        return results


@MODELS.register_module()
class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
    """The class for evaluating TensorRT file of detection."""

    def __init__(self,
                 trt_file: str,
                 cfg: Any,
                 device_id: int,
                 show_score: bool = False):
        if 'type' in cfg.model:
            cfg.model.pop('type')
        SingleStageTextDetector.__init__(self, **(cfg.model))
        TextDetectorMixin.__init__(self, show_score)
        from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
        try:
            load_tensorrt_plugin()
        except (ImportError, ModuleNotFoundError):
            warnings.warn('If input model has custom op from mmcv, \
                you may have to build mmcv with TensorRT from source.')
        model = TRTWrapper(
            trt_file, input_names=['input'], output_names=['output'])

        self.model = model
        self.device_id = device_id
        self.cfg = cfg

    def forward_train(self, img, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def simple_test(self,
                    img: torch.Tensor,
                    img_metas: Iterable,
                    rescale: bool = False):
        with torch.cuda.device(self.device_id), torch.no_grad():
            trt_pred = self.model({'input': img})['output']
        if len(img_metas) > 1:
            boundaries = [
                self.bbox_head.get_boundary(*(trt_pred[i].unsqueeze(0)),
                                            [img_metas[i]], rescale)
                for i in range(len(img_metas))
            ]

        else:
            boundaries = [
                self.bbox_head.get_boundary(*trt_pred, img_metas, rescale)
            ]

        return boundaries


@MODELS.register_module()
class TensorRTRecognizer(EncodeDecodeRecognizer):
    """The class for evaluating TensorRT file of recognition."""

    def __init__(self,
                 trt_file: str,
                 cfg: Any,
                 device_id: int,
                 show_score: bool = False):
        if 'type' in cfg.model:
            cfg.model.pop('type')
        EncodeDecodeRecognizer.__init__(self, **(cfg.model))
        from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
        try:
            load_tensorrt_plugin()
        except (ImportError, ModuleNotFoundError):
            warnings.warn('If input model has custom op from mmcv, \
                you may have to build mmcv with TensorRT from source.')
        model = TRTWrapper(
            trt_file, input_names=['input'], output_names=['output'])

        self.model = model
        self.device_id = device_id
        self.cfg = cfg

    def forward_train(self, img, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def aug_test(self, imgs, img_metas, **kwargs):
        if isinstance(imgs, list):
            for idx, each_img in enumerate(imgs):
                if each_img.dim() == 3:
                    imgs[idx] = each_img.unsqueeze(0)
            imgs = imgs[0]  # avoid aug_test
            img_metas = img_metas[0]
        else:
            if len(img_metas) == 1 and isinstance(img_metas[0], list):
                img_metas = img_metas[0]
        return self.simple_test(imgs, img_metas=img_metas)

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def simple_test(self,
                    img: torch.Tensor,
                    img_metas: Iterable,
                    rescale: bool = False):
        """Test function.

        Args:
            imgs (torch.Tensor): Image input tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            list[str]: Text label result of each image.
        """
        with torch.cuda.device(self.device_id), torch.no_grad():
            trt_pred = self.model({'input': img})['output']

        label_indexes, label_scores = self.label_convertor.tensor2idx(
            trt_pred, img_metas)
        label_strings = self.label_convertor.idx2str(label_indexes)

        # flatten batch results
        results = []
        for string, score in zip(label_strings, label_scores):
            results.append(dict(text=string, score=score))

        return results
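For context, a minimal usage sketch of the ONNX Runtime wrapper deleted above; the config path, ONNX file name, input size, and the img_metas content are illustrative assumptions, not part of this commit:

    import numpy as np
    import torch
    from mmcv import Config

    cfg = Config.fromfile('dbnet_config.py')  # hypothetical detection config
    detector = ONNXRuntimeDetector('dbnet.onnx', cfg, device_id=0)
    img = torch.rand(1, 3, 736, 736)  # dummy NCHW input batch
    # minimal metadata; a real run needs the full set of pipeline keys
    img_metas = [dict(filename='demo.jpg', scale_factor=np.ones(4))]
    boundaries = detector.simple_test(img, img_metas, rescale=True)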
@@ -1,4 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .single_stage import SingleStageDetector

__all__ = ['SingleStageDetector']
@@ -1,39 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmdet.models.detectors import \
    SingleStageDetector as MMDET_SingleStageDetector

from mmocr.registry import MODELS


@MODELS.register_module()
class SingleStageDetector(MMDET_SingleStageDetector):
    """Base class for single-stage detectors.

    Single-stage detectors directly and densely predict bounding boxes on the
    output features of the backbone+neck.
    """

    # TODO: Remove this class as SDMGR has been refactored
    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(MMDET_SingleStageDetector, self).__init__(init_cfg=init_cfg)
        if pretrained:
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            backbone.pretrained = pretrained
        self.backbone = MODELS.build(backbone)
        if neck is not None:
            self.neck = MODELS.build(neck)
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = MODELS.build(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
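A hypothetical registry-style construction of this class, mirroring how MODELS.build consumes the config dicts above; module names and channel numbers are illustrative, not taken from this commit:

    detector = SingleStageDetector(
        backbone=dict(type='ResNet', depth=18),
        neck=dict(type='FPN', in_channels=[64, 128, 256, 512],
                  out_channels=256),
        bbox_head=dict(type='DBHead', in_channels=256),
        train_cfg=None,
        test_cfg=None)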
@@ -2,14 +2,12 @@
 from .dbnet import DBNet
 from .drrg import DRRG
 from .fcenet import FCENet
-from .ocr_mask_rcnn import OCRMaskRCNN
 from .panet import PANet
 from .psenet import PSENet
 from .single_stage_text_detector import SingleStageTextDetector
-from .text_detector_mixin import TextDetectorMixin
 from .textsnake import TextSnake

 __all__ = [
-    'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
-    'PANet', 'PSENet', 'TextSnake', 'FCENet', 'DRRG'
+    'SingleStageTextDetector', 'DBNet', 'PANet', 'PSENet', 'TextSnake',
+    'FCENet', 'DRRG'
 ]
@@ -1,69 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.detectors import MaskRCNN

from mmocr.core import seg2boundary
from mmocr.registry import MODELS
from .text_detector_mixin import TextDetectorMixin


@MODELS.register_module()
class OCRMaskRCNN(TextDetectorMixin, MaskRCNN):
    """Mask RCNN tailored for OCR."""

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None,
                 text_repr_type='quad',
                 show_score=False,
                 init_cfg=None):
        TextDetectorMixin.__init__(self, show_score)
        MaskRCNN.__init__(
            self,
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
        assert text_repr_type in ['quad', 'poly']
        self.text_repr_type = text_repr_type

    def get_boundary(self, results):
        """Convert segmentation into text boundaries.

        Args:
            results (tuple): The result tuple. The first element holds the
                detection results (whose last column is the score) while the
                second holds the segmentation.
        Returns:
            dict: A result dict containing 'boundary_result'.
        """

        assert isinstance(results, tuple)

        instance_num = len(results[1][0])
        boundaries = []
        for i in range(instance_num):
            seg = results[1][0][i]
            score = results[0][0][i][-1]
            boundary = seg2boundary(seg, self.text_repr_type, score)
            if boundary is not None:
                boundaries.append(boundary)

        results = dict(boundary_result=boundaries)
        return results

    def simple_test(self, img, img_metas, proposals=None, rescale=False):

        results = super().simple_test(img, img_metas, proposals, rescale)

        boundaries = self.get_boundary(results[0])
        boundaries = boundaries if isinstance(boundaries,
                                              list) else [boundaries]
        return boundaries
@@ -1,82 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv

from mmocr.core import imshow_pred_boundary


# TODO: delete this
class TextDetectorMixin:
    """Base class for text detector, only to show results.

    Args:
        show_score (bool): Whether to show text instance score.
    """

    def __init__(self, show_score):
        self.show_score = show_score

    def show_result(self,
                    img,
                    result,
                    score_thr=0.5,
                    bbox_color='green',
                    text_color='green',
                    thickness=1,
                    font_scale=0.5,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or Tensor): The image to be displayed.
            result (dict): The results to draw over `img`.
            score_thr (float, optional): Minimum score of bboxes to be shown.
                Default: 0.5.
            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
            text_color (str or tuple or :obj:`Color`): Color of texts.
            thickness (int): Thickness of lines.
            font_scale (float): Font scales of texts.
            win_name (str): The window name.
            wait_time (int): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image.
                Default: False.
            out_file (str or None): The filename to write the image.
                Default: None.
        """
        img = mmcv.imread(img)
        img = img.copy()
        boundaries = None
        labels = None
        if 'boundary_result' in result.keys():
            boundaries = result['boundary_result']
            labels = [0] * len(boundaries)

        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False
        # draw bounding boxes
        if boundaries is not None:
            imshow_pred_boundary(
                img,
                boundaries,
                labels,
                score_thr=score_thr,
                boundary_color=bbox_color,
                text_color=text_color,
                thickness=thickness,
                font_scale=font_scale,
                win_name=win_name,
                show=show,
                wait_time=wait_time,
                out_file=out_file,
                show_score=self.show_score)

        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, '
                          'result image will be returned')
        return img
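A minimal, hypothetical call to the show_result method above; the image path and boundary values are made up, following the 2k+1 boundary format (eight coordinates plus a trailing score):

    det.show_result(
        'demo.jpg',
        dict(boundary_result=[[0, 0, 100, 0, 100, 40, 0, 40, 0.95]]),
        score_thr=0.5,
        out_file='vis.jpg')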
@@ -3,12 +3,11 @@ from .base_textdet_head import BaseTextDetHead
 from .db_head import DBHead
 from .drrg_head import DRRGHead
 from .fce_head import FCEHead
-from .head_mixin import HeadMixin
 from .pan_head import PANHead
 from .pse_head import PSEHead
 from .textsnake_head import TextSnakeHead

 __all__ = [
-    'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'TextSnakeHead', 'DRRGHead',
-    'HeadMixin', 'BaseTextDetHead'
+    'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'TextSnakeHead', 'DRRGHead',
+    'BaseTextDetHead'
 ]
@@ -1,92 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np

from mmocr.registry import MODELS
from mmocr.utils import check_argument


# TODO: del this
@MODELS.register_module()
class HeadMixin:
    """Base head class for text detection, including loss calculation and
    postprocess.

    Args:
        loss (dict): Config to build loss.
        postprocessor (dict): Config to build postprocessor.
    """

    def __init__(self, loss, postprocessor):
        assert isinstance(loss, dict)
        assert isinstance(postprocessor, dict)

        self.loss_module = MODELS.build(loss)
        self.postprocessor = MODELS.build(postprocessor)

    def resize_boundary(self, boundaries, scale_factor):
        """Rescale boundaries via scale_factor.

        Args:
            boundaries (list[list[float]]): The boundary list. Each boundary
                has :math:`2k+1` elements with :math:`k>=4`.
            scale_factor (ndarray): The scale factor of size :math:`(4,)`.

        Returns:
            list[list[float]]: The scaled boundaries.
        """
        assert check_argument.is_2dlist(boundaries)
        assert isinstance(scale_factor, np.ndarray)
        assert scale_factor.shape[0] == 4

        for b in boundaries:
            sz = len(b)
            check_argument.valid_boundary(b, True)
            b[:sz -
              1] = (np.array(b[:sz - 1]) *
                    (np.tile(scale_factor[:2], int(
                        (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
        return boundaries

    def get_boundary(self, score_maps, img_metas, rescale):
        """Compute text boundaries via post processing.

        Args:
            score_maps (Tensor): The text score map.
            img_metas (dict): The image meta info.
            rescale (bool): Rescale boundaries to the original image resolution
                if true, and keep the score_maps resolution if false.

        Returns:
            dict: A dict where boundary results are stored in
            ``boundary_result``.
        """

        assert check_argument.is_type_list(img_metas, dict)
        assert isinstance(rescale, bool)

        score_maps = score_maps.squeeze()
        boundaries = self.postprocessor(score_maps)

        if rescale:
            boundaries = self.resize_boundary(
                boundaries,
                1.0 / self.downsample_ratio / img_metas[0]['scale_factor'])

        results = dict(
            boundary_result=boundaries, filename=img_metas[0]['filename'])

        return results

    def loss(self, pred_maps, **kwargs):
        """Compute the loss for scene text detection.

        Args:
            pred_maps (Tensor): The input score maps of shape
                :math:`(NxCxHxW)`.

        Returns:
            dict: The dict for losses.
        """
        losses = self.loss_module(pred_maps, self.downsample_ratio, **kwargs)

        return losses
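A worked example of the rescaling performed by resize_boundary above, with made-up numbers (a sketch, not part of the original file):

    import numpy as np

    # One boundary: four corner points plus a trailing confidence score.
    b = [10.0, 20.0, 50.0, 20.0, 50.0, 60.0, 10.0, 60.0, 0.9]
    # (w_scale, h_scale, w_scale, h_scale), as scale factors usually come
    # out of the mmcv/mmdet resize pipeline.
    scale_factor = np.array([0.5, 0.25, 0.5, 0.25])

    sz = len(b)
    # Tile (w_scale, h_scale) across the (x, y, x, y, ...) coordinates;
    # the score at b[-1] is left untouched.
    scaled = np.array(b[:sz - 1]) * np.tile(scale_factor[:2], (sz - 1) // 2)
    print(scaled.tolist())  # [5.0, 5.0, 25.0, 5.0, 25.0, 15.0, 5.0, 15.0]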
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .base_postprocessor import BasePostprocessor, BaseTextDetPostProcessor
+from .base_postprocessor import BaseTextDetPostProcessor
 from .db_postprocessor import DBPostprocessor
 from .drrg_postprocessor import DRRGPostprocessor
 from .fce_postprocessor import FCEPostprocessor
@@ -8,7 +8,7 @@ from .pse_postprocessor import PSEPostprocessor
 from .textsnake_postprocessor import TextSnakePostprocessor

 __all__ = [
-    'BasePostprocessor', 'PSEPostprocessor', 'PANPostprocessor',
-    'DBPostprocessor', 'DRRGPostprocessor', 'FCEPostprocessor',
-    'TextSnakePostprocessor', 'BaseTextDetPostProcessor'
+    'PSEPostprocessor', 'PANPostprocessor', 'DBPostprocessor',
+    'DRRGPostprocessor', 'FCEPostprocessor', 'TextSnakePostprocessor',
+    'BaseTextDetPostProcessor'
 ]
@@ -8,26 +8,6 @@ from mmocr.core import TextDetDataSample
 from mmocr.utils import boundary_iou, is_type_list, rescale_polygons


-class BasePostprocessor:
-    """Deprecated.
-
-    TODO: remove this class when all det postprocessors are
-    refactored
-    """
-
-    def __init__(self, text_repr_type='poly'):
-        assert text_repr_type in ['poly', 'quad'
-                                  ], f'Invalid text repr type {text_repr_type}'
-
-        self.text_repr_type = text_repr_type
-
-    def is_valid_instance(self, area, confidence, area_thresh,
-                          confidence_thresh):
-        """If the area is a valid instance."""
-
-        return bool(area >= area_thresh and confidence > confidence_thresh)
-
-
 class BaseTextDetPostProcessor:
     """Base postprocessor for text detection models.
@@ -1,110 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import warnings

from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import get_dist_info
from mmdet.apis import single_gpu_test

from mmocr.apis.inference import disable_text_recog_aug_test
from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
                                   TensorRTDetector, TensorRTRecognizer)
from mmocr.datasets import build_dataloader
from mmocr.registry import DATASETS


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMOCR test (and eval) an ONNX or TensorRT model.')
    parser.add_argument('model_config', type=str, help='Config file.')
    parser.add_argument(
        'model_file', type=str, help='Input file name for evaluation.')
    parser.add_argument(
        'model_type',
        type=str,
        help='Detection or recognition model to deploy.',
        choices=['recog', 'det'])
    parser.add_argument(
        'backend',
        type=str,
        help='Which backend to test, TensorRT or ONNXRuntime.',
        choices=['TensorRT', 'ONNXRuntime'])
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='The evaluation metrics, which depends on the dataset, e.g.,'
        '"bbox", "seg", "proposal" for COCO, and "mAP", "recall" for'
        'PASCAL VOC.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')

    args = parser.parse_args()

    return args


def main():
    args = parse_args()

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)

    if args.device == 'cpu':
        args.device = None

    cfg = Config.fromfile(args.model_config)

    # build the model
    if args.model_type == 'det':
        if args.backend == 'TensorRT':
            model = TensorRTDetector(args.model_file, cfg, 0)
        else:
            model = ONNXRuntimeDetector(args.model_file, cfg, 0)
    else:
        if args.backend == 'TensorRT':
            model = TensorRTRecognizer(args.model_file, cfg, 0)
        else:
            model = ONNXRuntimeRecognizer(args.model_file, cfg, 0)

    # build the dataloader
    samples_per_gpu = 1
    cfg = disable_text_recog_aug_test(cfg)
    dataset = DATASETS.build(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False)

    model = MMDataParallel(model, device_ids=[0])
    outputs = single_gpu_test(model, data_loader)

    rank, _ = get_dist_info()
    if rank == 0:
        kwargs = {}
        if args.eval:
            eval_kwargs = cfg.get('evaluation', {}).copy()
            # hard-code way to remove EvalHook args
            for key in [
                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                    'rule'
            ]:
                eval_kwargs.pop(key, None)
            eval_kwargs.update(dict(metric=args.eval, **kwargs))
            print(dataset.evaluate(outputs, **eval_kwargs))


if __name__ == '__main__':
    main()
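Grounded in the argparse definition above, and assuming the script lived at tools/deployment/deploy_test.py, an invocation would have looked like this (file names and the metric are illustrative):

    python tools/deployment/deploy_test.py dbnet_config.py dbnet.onnx det ONNXRuntime --eval hmean-iou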
@@ -1,110 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory

import mmcv

try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    package_model = None


def mmocr2torchserve(
    config_file: str,
    checkpoint_file: str,
    output_folder: str,
    model_name: str,
    model_version: str = '1.0',
    force: bool = False,
):
    """Converts MMOCR model (config + checkpoint) to TorchServe `.mar`.

    Args:
        config_file:
            In MMOCR config format.
            The contents vary for each task repository.
        checkpoint_file:
            In MMOCR checkpoint format.
            The contents vary for each task repository.
        output_folder:
            Folder where `{model_name}.mar` will be created.
            The file created will be in TorchServe archive format.
        model_name:
            If not None, used for naming the `{model_name}.mar` file
            that will be created under `output_folder`.
            If None, `{Path(checkpoint_file).stem}` will be used.
        model_version:
            Model's version.
        force:
            If True, if there is an existing `{model_name}.mar`
            file under `output_folder` it will be overwritten.
    """
    mmcv.mkdir_or_exist(output_folder)

    config = mmcv.Config.fromfile(config_file)

    with TemporaryDirectory() as tmpdir:
        config.dump(f'{tmpdir}/config.py')

        args = Namespace(
            **{
                'model_file': f'{tmpdir}/config.py',
                'serialized_file': checkpoint_file,
                'handler': f'{Path(__file__).parent}/mmocr_handler.py',
                'model_name': model_name or Path(checkpoint_file).stem,
                'version': model_version,
                'export_path': output_folder,
                'force': force,
                'requirements_file': None,
                'extra_files': None,
                'runtime': 'python',
                'archive_format': 'default'
            })
        manifest = ModelExportUtils.generate_manifest_json(args)
        package_model(args, manifest)


def parse_args():
    parser = ArgumentParser(
        description='Convert MMOCR models to TorchServe `.mar` format.')
    parser.add_argument('config', type=str, help='config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
    parser.add_argument(
        '--output-folder',
        type=str,
        required=True,
        help='Folder where `{model_name}.mar` will be created.')
    parser.add_argument(
        '--model-name',
        type=str,
        default=None,
        help='If not None, used for naming the `{model_name}.mar`'
        'file that will be created under `output_folder`.'
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    parser.add_argument(
        '--model-version',
        type=str,
        default='1.0',
        help='Number used for versioning.')
    parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='overwrite the existing `{model_name}.mar`')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = parse_args()

    if package_model is None:
        raise ImportError('`torch-model-archiver` is required.'
                          'Try: pip install torch-model-archiver')

    mmocr2torchserve(args.config, args.checkpoint, args.output_folder,
                     args.model_name, args.model_version, args.force)
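A hypothetical direct call to the mmocr2torchserve function above; all paths and names are placeholders:

    mmocr2torchserve(
        config_file='crnn_config.py',
        checkpoint_file='crnn.pth',
        output_folder='model_store',
        model_name='crnn',
        force=True)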
@@ -1,51 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import base64
import os

import mmcv
import torch
from ts.torch_handler.base_handler import BaseHandler

from mmocr.apis import init_detector, model_inference
from mmocr.datasets.pipelines import *  # NOQA


class MMOCRHandler(BaseHandler):
    threshold = 0.5

    def initialize(self, context):
        properties = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(self.map_location + ':' +
                                   str(properties.get('gpu_id')) if torch.cuda.
                                   is_available() else self.map_location)
        self.manifest = context.manifest

        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        checkpoint = os.path.join(model_dir, serialized_file)
        self.config_file = os.path.join(model_dir, 'config.py')

        self.model = init_detector(self.config_file, checkpoint, self.device)
        self.initialized = True

    def preprocess(self, data):
        images = []

        for row in data:
            image = row.get('data') or row.get('body')
            if isinstance(image, str):
                image = base64.b64decode(image)
            image = mmcv.imfrombytes(image)
            images.append(image)

        return images

    def inference(self, data, *args, **kwargs):

        results = model_inference(self.model, data)
        return results

    def postprocess(self, data):
        # Format output following the example OCRHandler format
        return data
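A client-side sketch matching the handler's preprocess above, which accepts raw bytes or a base64-encoded string; the server address and served model name are assumptions:

    import base64

    import requests

    # Encode the image; the handler base64-decodes string payloads itself.
    with open('demo.jpg', 'rb') as f:
        payload = base64.b64encode(f.read())
    resp = requests.post('http://127.0.0.1:8080/predictions/crnn', data=payload)
    print(resp.json())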
@@ -1,294 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import warnings
from typing import Iterable

import cv2
import mmcv
import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.tensorrt import is_tensorrt_plugin_loaded, onnx2trt, save_trt_engine
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose

from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
                                   TensorRTDetector, TensorRTRecognizer)
from mmocr.datasets.pipelines.crop import crop_img  # noqa: F401
from mmocr.utils import is_2dlist


def get_GiB(x: int):
    """return x GiB."""
    return x * (1 << 30)


def _prepare_input_img(imgs, test_pipeline: Iterable[dict]):
    """Inference image(s) with the detector.

    Args:
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.
        test_pipeline (Iterable[dict]): Test pipeline of configuration.
    Returns:
        result (dict): Predicted results.
    """
    if isinstance(imgs, (list, tuple)):
        if not isinstance(imgs[0], (np.ndarray, str)):
            raise AssertionError('imgs must be strings or numpy arrays')

    elif isinstance(imgs, (np.ndarray, str)):
        imgs = [imgs]
    else:
        raise AssertionError('imgs must be strings or numpy arrays')

    test_pipeline = replace_ImageToTensor(test_pipeline)
    test_pipeline = Compose(test_pipeline)

    data = []
    for img in imgs:
        # prepare data
        # add information into dict
        datum = dict(img_info=dict(filename=img), img_prefix=None)

        # build the data pipeline
        datum = test_pipeline(datum)
        # get tensor from list to stack for batch mode (text detection)
        data.append(datum)

    if isinstance(data[0]['img'], list) and len(data) > 1:
        raise Exception('aug test does not support '
                        f'inference with batch size '
                        f'{len(data)}')

    data = collate(data, samples_per_gpu=len(imgs))

    # process img_metas
    if isinstance(data['img_metas'], list):
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
    else:
        data['img_metas'] = data['img_metas'].data

    if isinstance(data['img'], list):
        data['img'] = [img.data for img in data['img']]
        if isinstance(data['img'][0], list):
            data['img'] = [img[0] for img in data['img']]
    else:
        data['img'] = data['img'].data
    return data


def onnx2tensorrt(onnx_file: str,
                  model_type: str,
                  trt_file: str,
                  config: dict,
                  input_config: dict,
                  fp16: bool = False,
                  verify: bool = False,
                  show: bool = False,
                  workspace_size: int = 1,
                  verbose: bool = False):
    import tensorrt as trt
    min_shape = input_config['min_shape']
    max_shape = input_config['max_shape']
    # create trt engine and wrapper
    opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_file,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if verify:
        mm_inputs = _prepare_input_img(input_config['input_path'],
                                       config.data.test.pipeline)

        imgs = mm_inputs.pop('img')
        img_metas = mm_inputs.pop('img_metas')

        if isinstance(imgs, list):
            imgs = imgs[0]

        img_list = [img[None, :] for img in imgs]

        # Get results from ONNXRuntime
        if model_type == 'det':
            onnx_model = ONNXRuntimeDetector(onnx_file, config, 0)
        else:
            onnx_model = ONNXRuntimeRecognizer(onnx_file, config, 0)
        onnx_out = onnx_model.simple_test(
            img_list[0], img_metas[0], rescale=True)

        # Get results from TensorRT
        if model_type == 'det':
            trt_model = TensorRTDetector(trt_file, config, 0)
        else:
            trt_model = TensorRTRecognizer(trt_file, config, 0)
        img_list[0] = img_list[0].to(torch.device('cuda:0'))
        trt_out = trt_model.simple_test(
            img_list[0], img_metas[0], rescale=True)

        # compare results
        same_diff = 'same'
        if model_type == 'recog':
            for onnx_result, trt_result in zip(onnx_out, trt_out):
                if onnx_result['text'] != trt_result['text'] or \
                        not np.allclose(
                            np.array(onnx_result['score']),
                            np.array(trt_result['score']),
                            rtol=1e-4,
                            atol=1e-4):
                    same_diff = 'different'
                    break
        else:
            for onnx_result, trt_result in zip(onnx_out[0]['boundary_result'],
                                               trt_out[0]['boundary_result']):
                if not np.allclose(
                        np.array(onnx_result),
                        np.array(trt_result),
                        rtol=1e-4,
                        atol=1e-4):
                    same_diff = 'different'
                    break
        print(f'The outputs are {same_diff} between TensorRT and ONNX')

        if show:
            onnx_img = onnx_model.show_result(
                input_config['input_path'],
                onnx_out[0],
                out_file='onnx.jpg',
                show=False)
            trt_img = trt_model.show_result(
                input_config['input_path'],
                trt_out[0],
                out_file='tensorrt.jpg',
                show=False)
            if onnx_img is None:
                onnx_img = cv2.imread(input_config['input_path'])
            if trt_img is None:
                trt_img = cv2.imread(input_config['input_path'])

            cv2.imshow('TensorRT', trt_img)
            cv2.imshow('ONNXRuntime', onnx_img)
            cv2.waitKey()
    return


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMOCR models from ONNX to TensorRT')
    parser.add_argument('model_config', help='Config file of the model')
    parser.add_argument(
        'model_type',
        type=str,
        help='Detection or recognition model to deploy.',
        choices=['recog', 'det'])
    parser.add_argument('image_path', type=str, help='Image for test')
    parser.add_argument('onnx_file', help='Path to the input ONNX model')
    parser.add_argument(
        '--trt-file',
        type=str,
        help='Path to the output TensorRT engine',
        default='tmp.trt')
    parser.add_argument(
        '--max-shape',
        type=int,
        nargs=4,
        default=[1, 3, 400, 600],
        help='Maximum shape of model input.')
    parser.add_argument(
        '--min-shape',
        type=int,
        nargs=4,
        default=[1, 3, 400, 600],
        help='Minimum shape of model input.')
    parser.add_argument(
        '--workspace-size',
        type=int,
        default=1,
        help='Max workspace size in GiB.')
    parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Whether to verify the outputs of ONNXRuntime and TensorRT.',
        default=True)
    parser.add_argument(
        '--show',
        action='store_true',
        help='Whether to visualize outputs of ONNXRuntime and TensorRT.',
        default=True)
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether to verbose logging messages while creating \
            TensorRT engine.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':

    assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
    args = parse_args()

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)

    # check arguments
    assert osp.exists(args.model_config), 'Config {} not found.'.format(
        args.model_config)
    assert osp.exists(args.onnx_file), \
        f'ONNX model {args.onnx_file} not found.'
    assert args.workspace_size >= 0, 'Workspace size less than 0.'
    for max_value, min_value in zip(args.max_shape, args.min_shape):
        assert max_value >= min_value, \
            'max_shape should be larger than min shape'

    input_config = {
        'min_shape': args.min_shape,
        'max_shape': args.max_shape,
        'input_path': args.image_path
    }

    cfg = mmcv.Config.fromfile(args.model_config)
    if cfg.data.test.get('pipeline', None) is None:
        if is_2dlist(cfg.data.test.datasets):
            cfg.data.test.pipeline = \
                cfg.data.test.datasets[0][0].pipeline
        else:
            cfg.data.test.pipeline = \
                cfg.data.test['datasets'][0].pipeline
    if is_2dlist(cfg.data.test.pipeline):
        cfg.data.test.pipeline = cfg.data.test.pipeline[0]
    onnx2tensorrt(
        args.onnx_file,
        args.model_type,
        args.trt_file,
        cfg,
        input_config,
        fp16=args.fp16,
        verify=args.verify,
        show=args.show,
        workspace_size=args.workspace_size,
        verbose=args.verbose)
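Matching the argparse definition above, and assuming the script lived at tools/deployment/onnx2tensorrt.py, a typical invocation (file names are illustrative):

    python tools/deployment/onnx2tensorrt.py dbnet_config.py det demo.jpg dbnet.onnx --trt-file dbnet.trt --fp16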
@@ -1,368 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from argparse import ArgumentParser
from functools import partial

import cv2
import numpy as np
import torch
from mmcv.onnx import register_extra_symbolics
from mmcv.parallel import collate
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from torch import nn

from mmocr.apis import init_detector
from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer
from mmocr.datasets.pipelines.crop import crop_img  # noqa: F401
from mmocr.utils import is_2dlist


def _convert_batchnorm(module):
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def _prepare_data(cfg, imgs):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.
    Returns:
        result (dict): Predicted results.
    """
    if isinstance(imgs, (list, tuple)):
        if not isinstance(imgs[0], (np.ndarray, str)):
            raise AssertionError('imgs must be strings or numpy arrays')

    elif isinstance(imgs, (np.ndarray, str)):
        imgs = [imgs]
    else:
        raise AssertionError('imgs must be strings or numpy arrays')

    is_ndarray = isinstance(imgs[0], np.ndarray)

    if is_ndarray:
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    data = []
    for img in imgs:
        # prepare data
        if is_ndarray:
            # directly add img
            datum = dict(img=img)
        else:
            # add information into dict
            datum = dict(img_info=dict(filename=img), img_prefix=None)

        # build the data pipeline
        datum = test_pipeline(datum)
        # get tensor from list to stack for batch mode (text detection)
        data.append(datum)

    if isinstance(data[0]['img'], list) and len(data) > 1:
        raise Exception('aug test does not support '
                        f'inference with batch size '
                        f'{len(data)}')

    data = collate(data, samples_per_gpu=len(imgs))

    # process img_metas
    if isinstance(data['img_metas'], list):
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
    else:
        data['img_metas'] = data['img_metas'].data

    if isinstance(data['img'], list):
        data['img'] = [img.data for img in data['img']]
        if isinstance(data['img'][0], list):
            data['img'] = [img[0] for img in data['img']]
    else:
        data['img'] = data['img'].data
    return data


def pytorch2onnx(model: nn.Module,
                 model_type: str,
                 img_path: str,
                 verbose: bool = False,
                 show: bool = False,
                 opset_version: int = 11,
                 output_file: str = 'tmp.onnx',
                 verify: bool = False,
                 dynamic_export: bool = False,
                 device_id: int = 0):
    """Export PyTorch model to ONNX model and verify the outputs are same
    between PyTorch and ONNX.

    Args:
        model (nn.Module): PyTorch model we want to export.
        model_type (str): Model type, detection or recognition model.
        img_path (str): We need to use this input to execute the model.
        opset_version (int): The onnx op version. Default: 11.
        verbose (bool): Whether to print the computation graph. Default: False.
        show (bool): Whether to visualize final results. Default: False.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs between PyTorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether to apply dynamic export.
            Default: False.
        device_id (int): Device id to place model and data.
            Default: 0
    """
    device = torch.device(type='cuda', index=device_id)
    model.to(device).eval()
    _convert_batchnorm(model)

    # prepare inputs
    mm_inputs = _prepare_data(cfg=model.cfg, imgs=img_path)
    imgs = mm_inputs.pop('img')
    img_metas = mm_inputs.pop('img_metas')

    if isinstance(imgs, list):
        imgs = imgs[0]

    img_list = [img[None, :].to(device) for img in imgs]

    origin_forward = model.forward
    if (model_type == 'det'):
        model.forward = partial(
            model.simple_test, img_metas=img_metas, rescale=True)
    else:
        model.forward = partial(
            model.forward,
            img_metas=img_metas,
            return_loss=False,
            rescale=True)

    # pytorch has some bug in pytorch1.3, we have to fix it
    # by replacing these existing op
    register_extra_symbolics(opset_version)
    dynamic_axes = None
    if dynamic_export and model_type == 'det':
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'output': {
                0: 'batch',
                2: 'height',
                3: 'width'
            }
        }
    elif dynamic_export and model_type == 'recog':
        dynamic_axes = {
            'input': {
                0: 'batch',
                3: 'width'
            },
            'output': {
                0: 'batch',
                1: 'seq_len',
                2: 'num_classes'
            }
        }
    with torch.no_grad():
        torch.onnx.export(
            model, (img_list[0], ),
            output_file,
            input_names=['input'],
            output_names=['output'],
            export_params=True,
            keep_initializers_as_inputs=False,
            verbose=verbose,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes)
    print(f'Successfully exported ONNX model: {output_file}')
    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        scale_factor = (0.5, 0.5) if model_type == 'det' else (1, 0.5)
        if dynamic_export:
            # scale image for dynamic shape test
            img_list = [
                nn.functional.interpolate(_, scale_factor=scale_factor)
                for _ in img_list
            ]
            if model_type == 'det':
                img_metas[0][0][
                    'scale_factor'] = img_metas[0][0]['scale_factor'] * (
                        scale_factor * 2)

        # check the numerical value
        # get pytorch output
        with torch.no_grad():
            model.forward = origin_forward
            pytorch_out = model.simple_test(
                img_list[0], img_metas[0], rescale=True)

        # get onnx output
        if model_type == 'det':
            onnx_model = ONNXRuntimeDetector(output_file, model.cfg, device_id)
        else:
            onnx_model = ONNXRuntimeRecognizer(output_file, model.cfg,
                                               device_id)
        onnx_out = onnx_model.simple_test(
            img_list[0], img_metas[0], rescale=True)

        # compare results
        same_diff = 'same'
        if model_type == 'recog':
            for onnx_result, pytorch_result in zip(onnx_out, pytorch_out):
                if onnx_result['text'] != pytorch_result[
                        'text'] or not np.allclose(
                            np.array(onnx_result['score']),
                            np.array(pytorch_result['score']),
                            rtol=1e-4,
                            atol=1e-4):
                    same_diff = 'different'
                    break
        else:
            for onnx_result, pytorch_result in zip(
                    onnx_out[0]['boundary_result'],
                    pytorch_out[0]['boundary_result']):
                if not np.allclose(
                        np.array(onnx_result),
                        np.array(pytorch_result),
                        rtol=1e-4,
                        atol=1e-4):
                    same_diff = 'different'
                    break
        print(f'The outputs are {same_diff} between PyTorch and ONNX')

        if show:
            onnx_img = onnx_model.show_result(
                img_path, onnx_out[0], out_file='onnx.jpg', show=False)
            pytorch_img = model.show_result(
                img_path, pytorch_out[0], out_file='pytorch.jpg', show=False)
            if onnx_img is None:
                onnx_img = cv2.imread(img_path)
            if pytorch_img is None:
                pytorch_img = cv2.imread(img_path)

            cv2.imshow('PyTorch', pytorch_img)
            cv2.imshow('ONNXRuntime', onnx_img)
            cv2.waitKey()
    return


def main():
    parser = ArgumentParser(
        description='Convert MMOCR models from pytorch to ONNX')
    parser.add_argument('model_config', type=str, help='Config file.')
    parser.add_argument(
        'model_ckpt', type=str, help='Checkpoint file (local or url).')
    parser.add_argument(
        'model_type',
        type=str,
        help='Detection or recognition model to deploy.',
        choices=['recog', 'det'])
    parser.add_argument('image_path', type=str, help='Input Image file.')
    parser.add_argument(
        '--output-file',
        type=str,
        help='Output file name of the onnx model.',
        default='tmp.onnx')
    parser.add_argument(
        '--device-id', default=0, help='Device used for inference.')
    parser.add_argument(
        '--opset-version',
        type=int,
        help='ONNX opset version, default to 11.',
        default=11)
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Whether to verify the outputs of onnx and pytorch are same.',
        default=False)
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether to print the computation graph.',
        default=False)
    parser.add_argument(
        '--show',
        action='store_true',
        help='Whether to visualize final output.',
        default=False)
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether to dynamically export onnx model.',
        default=False)
    args = parser.parse_args()

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)

    device = torch.device(type='cuda', index=args.device_id)

    # build model
    model = init_detector(args.model_config, args.model_ckpt, device=device)
    if hasattr(model, 'module'):
        model = model.module
    if model.cfg.data.test.get('pipeline', None) is None:
        if is_2dlist(model.cfg.data.test.datasets):
            model.cfg.data.test.pipeline = \
                model.cfg.data.test.datasets[0][0].pipeline
        else:
            model.cfg.data.test.pipeline = \
                model.cfg.data.test['datasets'][0].pipeline
    if is_2dlist(model.cfg.data.test.pipeline):
        model.cfg.data.test.pipeline = model.cfg.data.test.pipeline[0]

    pytorch2onnx(
        model,
        model_type=args.model_type,
        output_file=args.output_file,
        img_path=args.image_path,
        opset_version=args.opset_version,
        verify=args.verify,
        verbose=args.verbose,
        show=args.show,
        device_id=args.device_id,
        dynamic_export=args.dynamic_export)


if __name__ == '__main__':
    main()
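Matching the argparse definition above, and assuming the script lived at tools/deployment/pytorch2onnx.py, a typical invocation (file names are illustrative):

    python tools/deployment/pytorch2onnx.py crnn_config.py crnn.pth recog demo.jpg --output-file crnn.onnx --dynamic-export --verify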
@@ -1,63 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser

import numpy as np
import requests

from mmocr.apis import init_detector, model_inference


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('model_name', help='The model name in the server')
    parser.add_argument(
        '--inference-addr',
        default='127.0.0.1:8080',
        help='Address and port of the inference server')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.5, help='bbox score threshold')
    args = parser.parse_args()
    return args


def main(args):
    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # test a single image
    model_results = model_inference(model, args.img)
    model.show_result(
        args.img,
        model_results,
        win_name='model_results',
        show=True,
        score_thr=args.score_thr)
    url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
    with open(args.img, 'rb') as image:
        response = requests.post(url, image)
    serve_results = response.json()
    model.show_result(
        args.img,
        serve_results,
        show=True,
        win_name='serve_results',
        score_thr=args.score_thr)
    assert serve_results.keys() == model_results.keys()
    for key in serve_results.keys():
        for model_result, serve_result in zip(model_results[key],
                                              serve_results[key]):
            if isinstance(model_result[0], (int, float)):
                assert np.allclose(model_result, serve_result)
            elif isinstance(model_result[0], str):
                assert model_result == serve_result
            else:
                raise TypeError


if __name__ == '__main__':
    args = parse_args()
    main(args)
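Matching the argparse definition above, and assuming the script lived at tools/deployment/test_torchserve.py, a typical invocation (file names and the served model name are illustrative):

    python tools/deployment/test_torchserve.py demo.jpg crnn_config.py crnn.pth crnn --inference-addr 127.0.0.1:8080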