Remove useless

pull/1178/head
wangxinyu 2022-07-11 10:03:01 +00:00 committed by gaotongxiao
parent d4dbad56ee
commit de78a8839f
18 changed files with 7 additions and 1658 deletions

View File

@@ -10,16 +10,6 @@ mmocr/models/textrecog/recognizers/encode_decode_recognizer.py
mmocr/datasets/pipelines/transforms.py
mmocr/datasets/pipelines/dbnet_transforms.py
# will be deleted
mmocr/models/textdet/heads/head_mixin.py
# They will be removed once all det models have been refactored
mmocr/models/common/detectors/single_stage.py
mmocr/models/textdet/detectors/text_detector_mixin.py
# It will be covered by tests of any det model implemented in the future
mmocr/models/textdet/detectors/single_stage_text_detector.py
# It will be removed after all utils are moved to mmocr.utils
mmocr/core/evaluation/utils.py
mmocr/models/textdet/postprocessors/utils.py

View File

@@ -1,8 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .deploy_utils import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
TensorRTDetector, TensorRTRecognizer)
__all__ = [
'ONNXRuntimeRecognizer', 'ONNXRuntimeDetector', 'TensorRTDetector',
'TensorRTRecognizer'
]

View File

@@ -1,328 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Any, Iterable
import numpy as np
import torch
from mmocr.models.textdet.detectors.single_stage_text_detector import \
SingleStageTextDetector
from mmocr.models.textdet.detectors.text_detector_mixin import \
TextDetectorMixin
from mmocr.models.textrecog.recognizers.encode_decode_recognizer import \
EncodeDecodeRecognizer
from mmocr.registry import MODELS
def inference_with_session(sess, io_binding, input_name, output_names,
input_tensor):
device_type = input_tensor.device.type
device_id = input_tensor.device.index
device_id = 0 if device_id is None else device_id
io_binding.bind_input(
name=input_name,
device_type=device_type,
device_id=device_id,
element_type=np.float32,
shape=input_tensor.shape,
buffer_ptr=input_tensor.data_ptr())
for name in output_names:
io_binding.bind_output(name)
sess.run_with_iobinding(io_binding)
pred = io_binding.copy_outputs_to_cpu()
return pred
@MODELS.register_module()
class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
"""The class for evaluating onnx file of detection."""
def __init__(self,
onnx_file: str,
cfg: Any,
device_id: int,
show_score: bool = False):
if 'type' in cfg.model:
cfg.model.pop('type')
SingleStageTextDetector.__init__(self, **(cfg.model))
TextDetectorMixin.__init__(self, show_score)
import onnxruntime as ort
# get the custom op path
ort_custom_op_path = ''
try:
from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
except (ImportError, ModuleNotFoundError):
warnings.warn('If the input model has custom ops from mmcv, \
you may have to build mmcv with ONNXRuntime from source.')
session_options = ort.SessionOptions()
# register custom op for onnxruntime
if osp.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
sess = ort.InferenceSession(onnx_file, session_options)
providers = ['CPUExecutionProvider']
options = [{}]
is_cuda_available = ort.get_device() == 'GPU'
if is_cuda_available:
providers.insert(0, 'CUDAExecutionProvider')
options.insert(0, {'device_id': device_id})
sess.set_providers(providers, options)
self.sess = sess
self.device_id = device_id
self.io_binding = sess.io_binding()
self.output_names = [_.name for _ in sess.get_outputs()]
for name in self.output_names:
self.io_binding.bind_output(name)
self.cfg = cfg
def forward_train(self, img, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def aug_test(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
def simple_test(self,
img: torch.Tensor,
img_metas: Iterable,
rescale: bool = False):
onnx_pred = inference_with_session(self.sess, self.io_binding, 'input',
self.output_names, img)
onnx_pred = torch.from_numpy(onnx_pred[0])
if len(img_metas) > 1:
boundaries = [
self.bbox_head.get_boundary(*(onnx_pred[i].unsqueeze(0)),
[img_metas[i]], rescale)
for i in range(len(img_metas))
]
else:
boundaries = [
self.bbox_head.get_boundary(*onnx_pred, img_metas, rescale)
]
return boundaries
@MODELS.register_module()
class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
"""The class for evaluating onnx file of recognition."""
def __init__(self,
onnx_file: str,
cfg: Any,
device_id: int,
show_score: bool = False):
if 'type' in cfg.model:
cfg.model.pop('type')
EncodeDecodeRecognizer.__init__(self, **(cfg.model))
import onnxruntime as ort
# get the custom op path
ort_custom_op_path = ''
try:
from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
except (ImportError, ModuleNotFoundError):
warnings.warn('If the input model has custom ops from mmcv, \
you may have to build mmcv with ONNXRuntime from source.')
session_options = ort.SessionOptions()
# register custom op for onnxruntime
if osp.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
sess = ort.InferenceSession(onnx_file, session_options)
providers = ['CPUExecutionProvider']
options = [{}]
is_cuda_available = ort.get_device() == 'GPU'
if is_cuda_available:
providers.insert(0, 'CUDAExecutionProvider')
options.insert(0, {'device_id': device_id})
sess.set_providers(providers, options)
self.sess = sess
self.device_id = device_id
self.io_binding = sess.io_binding()
self.output_names = [_.name for _ in sess.get_outputs()]
for name in self.output_names:
self.io_binding.bind_output(name)
self.cfg = cfg
def forward_train(self, img, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def aug_test(self, imgs, img_metas, **kwargs):
if isinstance(imgs, list):
for idx, each_img in enumerate(imgs):
if each_img.dim() == 3:
imgs[idx] = each_img.unsqueeze(0)
imgs = imgs[0] # avoid aug_test
img_metas = img_metas[0]
else:
if len(img_metas) == 1 and isinstance(img_metas[0], list):
img_metas = img_metas[0]
return self.simple_test(imgs, img_metas=img_metas)
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
def simple_test(self,
img: torch.Tensor,
img_metas: Iterable,
rescale: bool = False):
"""Test function.
Args:
imgs (torch.Tensor): Image input tensor.
img_metas (list[dict]): List of image information.
Returns:
list[str]: Text label result of each image.
"""
onnx_pred = inference_with_session(self.sess, self.io_binding, 'input',
self.output_names, img)
onnx_pred = torch.from_numpy(onnx_pred[0])
label_indexes, label_scores = self.label_convertor.tensor2idx(
onnx_pred, img_metas)
label_strings = self.label_convertor.idx2str(label_indexes)
# flatten batch results
results = []
for string, score in zip(label_strings, label_scores):
results.append(dict(text=string, score=score))
return results
@MODELS.register_module()
class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
"""The class for evaluating TensorRT file of detection."""
def __init__(self,
trt_file: str,
cfg: Any,
device_id: int,
show_score: bool = False):
if 'type' in cfg.model:
cfg.model.pop('type')
SingleStageTextDetector.__init__(self, **(cfg.model))
TextDetectorMixin.__init__(self, show_score)
from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
try:
load_tensorrt_plugin()
except (ImportError, ModuleNotFoundError):
warnings.warn('If the input model has custom ops from mmcv, \
you may have to build mmcv with TensorRT from source.')
model = TRTWrapper(
trt_file, input_names=['input'], output_names=['output'])
self.model = model
self.device_id = device_id
self.cfg = cfg
def forward_train(self, img, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def aug_test(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
def simple_test(self,
img: torch.Tensor,
img_metas: Iterable,
rescale: bool = False):
with torch.cuda.device(self.device_id), torch.no_grad():
trt_pred = self.model({'input': img})['output']
if len(img_metas) > 1:
boundaries = [
self.bbox_head.get_boundary(*(trt_pred[i].unsqueeze(0)),
[img_metas[i]], rescale)
for i in range(len(img_metas))
]
else:
boundaries = [
self.bbox_head.get_boundary(*trt_pred, img_metas, rescale)
]
return boundaries
@MODELS.register_module()
class TensorRTRecognizer(EncodeDecodeRecognizer):
"""The class for evaluating TensorRT file of recognition."""
def __init__(self,
trt_file: str,
cfg: Any,
device_id: int,
show_score: bool = False):
if 'type' in cfg.model:
cfg.model.pop('type')
EncodeDecodeRecognizer.__init__(self, **(cfg.model))
from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
try:
load_tensorrt_plugin()
except (ImportError, ModuleNotFoundError):
warnings.warn('If the input model has custom ops from mmcv, \
you may have to build mmcv with TensorRT from source.')
model = TRTWrapper(
trt_file, input_names=['input'], output_names=['output'])
self.model = model
self.device_id = device_id
self.cfg = cfg
def forward_train(self, img, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
def aug_test(self, imgs, img_metas, **kwargs):
if isinstance(imgs, list):
for idx, each_img in enumerate(imgs):
if each_img.dim() == 3:
imgs[idx] = each_img.unsqueeze(0)
imgs = imgs[0] # avoid aug_test
img_metas = img_metas[0]
else:
if len(img_metas) == 1 and isinstance(img_metas[0], list):
img_metas = img_metas[0]
return self.simple_test(imgs, img_metas=img_metas)
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
def simple_test(self,
img: torch.Tensor,
img_metas: Iterable,
rescale: bool = False):
"""Test function.
Args:
imgs (torch.Tensor): Image input tensor.
img_metas (list[dict]): List of image information.
Returns:
list[str]: Text label result of each image.
"""
with torch.cuda.device(self.device_id), torch.no_grad():
trt_pred = self.model({'input': img})['output']
label_indexes, label_scores = self.label_convertor.tensor2idx(
trt_pred, img_metas)
label_strings = self.label_convertor.idx2str(label_indexes)
# flatten batch results
results = []
for string, score in zip(label_strings, label_scores):
results.append(dict(text=string, score=score))
return results
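For reference, here is a minimal usage sketch of the ONNX wrapper removed above. It is not part of this commit; the config path, ONNX file name, input size and meta fields are placeholder assumptions.

# Hypothetical usage sketch (paths, shapes and meta keys are assumptions).
import mmcv
import torch

cfg = mmcv.Config.fromfile('dbnet_config.py')  # assumed detector config
detector = ONNXRuntimeDetector('dbnet.onnx', cfg, device_id=0)
img = torch.rand(1, 3, 640, 640)  # dummy input batch
img_metas = [dict(filename='demo.jpg', scale_factor=1.0)]
boundaries = detector.simple_test(img, img_metas, rescale=False)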

View File

@@ -1,4 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .single_stage import SingleStageDetector
__all__ = ['SingleStageDetector']

View File

@@ -1,39 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmdet.models.detectors import \
SingleStageDetector as MMDET_SingleStageDetector
from mmocr.registry import MODELS
@MODELS.register_module()
class SingleStageDetector(MMDET_SingleStageDetector):
"""Base class for single-stage detectors.
Single-stage detectors directly and densely predict bounding boxes on the
output features of the backbone+neck.
"""
# TODO: Remove this class once SDMGR has been refactored
def __init__(self,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(MMDET_SingleStageDetector, self).__init__(init_cfg=init_cfg)
if pretrained:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
backbone.pretrained = pretrained
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = MODELS.build(neck)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = MODELS.build(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg

View File

@@ -2,14 +2,12 @@
from .dbnet import DBNet
from .drrg import DRRG
from .fcenet import FCENet
from .ocr_mask_rcnn import OCRMaskRCNN
from .panet import PANet
from .psenet import PSENet
from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin
from .textsnake import TextSnake
__all__ = [
'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
'PANet', 'PSENet', 'TextSnake', 'FCENet', 'DRRG'
'SingleStageTextDetector', 'DBNet', 'PANet', 'PSENet', 'TextSnake',
'FCENet', 'DRRG'
]

View File

@@ -1,69 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.detectors import MaskRCNN
from mmocr.core import seg2boundary
from mmocr.registry import MODELS
from .text_detector_mixin import TextDetectorMixin
@MODELS.register_module()
class OCRMaskRCNN(TextDetectorMixin, MaskRCNN):
"""Mask RCNN tailored for OCR."""
def __init__(self,
backbone,
rpn_head,
roi_head,
train_cfg,
test_cfg,
neck=None,
pretrained=None,
text_repr_type='quad',
show_score=False,
init_cfg=None):
TextDetectorMixin.__init__(self, show_score)
MaskRCNN.__init__(
self,
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg)
assert text_repr_type in ['quad', 'poly']
self.text_repr_type = text_repr_type
def get_boundary(self, results):
"""Convert segmentation into text boundaries.
Args:
results (tuple): The result tuple. The first element is the
bbox results with scores and the second is the segmentation.
Returns:
dict: A result dict containing 'boundary_result'.
"""
assert isinstance(results, tuple)
instance_num = len(results[1][0])
boundaries = []
for i in range(instance_num):
seg = results[1][0][i]
score = results[0][0][i][-1]
boundary = seg2boundary(seg, self.text_repr_type, score)
if boundary is not None:
boundaries.append(boundary)
results = dict(boundary_result=boundaries)
return results
def simple_test(self, img, img_metas, proposals=None, rescale=False):
results = super().simple_test(img, img_metas, proposals, rescale)
boundaries = self.get_boundary(results[0])
boundaries = boundaries if isinstance(boundaries,
list) else [boundaries]
return boundaries

View File

@@ -1,82 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import mmcv
from mmocr.core import imshow_pred_boundary
# TODO: delete this
class TextDetectorMixin:
"""Base class for text detector, only to show results.
Args:
show_score (bool): Whether to show text instance score.
"""
def __init__(self, show_score):
self.show_score = show_score
def show_result(self,
img,
result,
score_thr=0.5,
bbox_color='green',
text_color='green',
thickness=1,
font_scale=0.5,
win_name='',
show=False,
wait_time=0,
out_file=None):
"""Draw `result` over `img`.
Args:
img (str or Tensor): The image to be displayed.
result (dict): The results to draw over `img`.
score_thr (float, optional): Minimum score of bboxes to be shown.
Default: 0.5.
bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
text_color (str or tuple or :obj:`Color`): Color of texts.
thickness (int): Thickness of lines.
font_scale (float): Font scales of texts.
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The filename to write the image.
Default: None.
"""
img = mmcv.imread(img)
img = img.copy()
boundaries = None
labels = None
if 'boundary_result' in result.keys():
boundaries = result['boundary_result']
labels = [0] * len(boundaries)
# if out_file specified, do not show image in window
if out_file is not None:
show = False
# draw bounding boxes
if boundaries is not None:
imshow_pred_boundary(
img,
boundaries,
labels,
score_thr=score_thr,
boundary_color=bbox_color,
text_color=text_color,
thickness=thickness,
font_scale=font_scale,
win_name=win_name,
show=show,
wait_time=wait_time,
out_file=out_file,
show_score=self.show_score)
if not (show or out_file):
warnings.warn('show==False and out_file is not specified, '
'result image will be returned')
return img
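As a quick illustration of the show_result contract above, a hedged sketch; the detector instance, image path and boundary values are made up, not taken from this commit.

# Hypothetical sketch: `detector` is any model that mixes in TextDetectorMixin.
boundary = [0., 0., 100., 0., 100., 50., 0., 50., 0.95]  # 4 corner points + score
result = dict(boundary_result=[boundary])
vis = detector.show_result('demo.jpg', result, score_thr=0.5, out_file='vis.jpg')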

View File

@@ -3,12 +3,11 @@ from .base_textdet_head import BaseTextDetHead
from .db_head import DBHead
from .drrg_head import DRRGHead
from .fce_head import FCEHead
from .head_mixin import HeadMixin
from .pan_head import PANHead
from .pse_head import PSEHead
from .textsnake_head import TextSnakeHead
__all__ = [
'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'TextSnakeHead', 'DRRGHead',
'HeadMixin', 'BaseTextDetHead'
'BaseTextDetHead'
]

View File

@@ -1,92 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmocr.registry import MODELS
from mmocr.utils import check_argument
# TODO: del this
@MODELS.register_module()
class HeadMixin:
"""Base head class for text detection, including loss calcalation and
postprocess.
Args:
loss (dict): Config to build loss.
postprocessor (dict): Config to build postprocessor.
"""
def __init__(self, loss, postprocessor):
assert isinstance(loss, dict)
assert isinstance(postprocessor, dict)
self.loss_module = MODELS.build(loss)
self.postprocessor = MODELS.build(postprocessor)
def resize_boundary(self, boundaries, scale_factor):
"""Rescale boundaries via scale_factor.
Args:
boundaries (list[list[float]]): The boundary list. Each boundary
has :math:`2k+1` elements with :math:`k>=4`.
scale_factor (ndarray): The scale factor of size :math:`(4,)`.
Returns:
list[list[float]]: The scaled boundaries.
"""
assert check_argument.is_2dlist(boundaries)
assert isinstance(scale_factor, np.ndarray)
assert scale_factor.shape[0] == 4
for b in boundaries:
sz = len(b)
check_argument.valid_boundary(b, True)
b[:sz -
1] = (np.array(b[:sz - 1]) *
(np.tile(scale_factor[:2], int(
(sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
return boundaries
def get_boundary(self, score_maps, img_metas, rescale):
"""Compute text boundaries via post processing.
Args:
score_maps (Tensor): The text score map.
img_metas (dict): The image meta info.
rescale (bool): Rescale boundaries to the original image resolution
if true, and keep the score_maps resolution if false.
Returns:
dict: A dict where boundary results are stored in
``boundary_result``.
"""
assert check_argument.is_type_list(img_metas, dict)
assert isinstance(rescale, bool)
score_maps = score_maps.squeeze()
boundaries = self.postprocessor(score_maps)
if rescale:
boundaries = self.resize_boundary(
boundaries,
1.0 / self.downsample_ratio / img_metas[0]['scale_factor'])
results = dict(
boundary_result=boundaries, filename=img_metas[0]['filename'])
return results
def loss(self, pred_maps, **kwargs):
"""Compute the loss for scene text detection.
Args:
pred_maps (Tensor): The input score maps of shape
:math:`(NxCxHxW)`.
Returns:
dict: The dict for losses.
"""
losses = self.loss_module(pred_maps, self.downsample_ratio, **kwargs)
return losses
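The rescaling in resize_boundary above tiles the per-axis scale over the polygon coordinates and leaves the trailing confidence score untouched; a standalone NumPy sketch of the same arithmetic, with made-up numbers:

import numpy as np

boundary = [0., 0., 4., 0., 4., 2., 0., 2., 0.9]  # 4 points + confidence score
scale_factor = np.array([0.5, 0.25, 0.5, 0.25])  # (w, h, w, h) scaling
sz = len(boundary)
scaled = (np.array(boundary[:sz - 1]) *
          np.tile(scale_factor[:2], (sz - 1) // 2)).tolist() + boundary[-1:]
# scaled == [0.0, 0.0, 2.0, 0.0, 2.0, 0.5, 0.0, 0.5, 0.9]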

View File

@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_postprocessor import BasePostprocessor, BaseTextDetPostProcessor
from .base_postprocessor import BaseTextDetPostProcessor
from .db_postprocessor import DBPostprocessor
from .drrg_postprocessor import DRRGPostprocessor
from .fce_postprocessor import FCEPostprocessor
@@ -8,7 +8,7 @@ from .pse_postprocessor import PSEPostprocessor
from .textsnake_postprocessor import TextSnakePostprocessor
__all__ = [
'BasePostprocessor', 'PSEPostprocessor', 'PANPostprocessor',
'DBPostprocessor', 'DRRGPostprocessor', 'FCEPostprocessor',
'TextSnakePostprocessor', 'BaseTextDetPostProcessor'
'PSEPostprocessor', 'PANPostprocessor', 'DBPostprocessor',
'DRRGPostprocessor', 'FCEPostprocessor', 'TextSnakePostprocessor',
'BaseTextDetPostProcessor'
]

View File

@@ -8,26 +8,6 @@ from mmocr.core import TextDetDataSample
from mmocr.utils import boundary_iou, is_type_list, rescale_polygons
class BasePostprocessor:
"""Deprecated.
TODO: remove this class when all det postprocessors are
refactored
"""
def __init__(self, text_repr_type='poly'):
assert text_repr_type in ['poly', 'quad'
], f'Invalid text repr type {text_repr_type}'
self.text_repr_type = text_repr_type
def is_valid_instance(self, area, confidence, area_thresh,
confidence_thresh):
"""If the area is a valid instance."""
return bool(area >= area_thresh and confidence > confidence_thresh)
class BaseTextDetPostProcessor:
"""Base postprocessor for text detection models.

View File

@@ -1,110 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import warnings
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import get_dist_info
from mmdet.apis import single_gpu_test
from mmocr.apis.inference import disable_text_recog_aug_test
from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
TensorRTDetector, TensorRTRecognizer)
from mmocr.datasets import build_dataloader
from mmocr.registry import DATASETS
def parse_args():
parser = argparse.ArgumentParser(
description='MMOCR test (and eval) an ONNX or TensorRT model.')
parser.add_argument('model_config', type=str, help='Config file.')
parser.add_argument(
'model_file', type=str, help='Input file name for evaluation.')
parser.add_argument(
'model_type',
type=str,
help='Detection or recognition model to deploy.',
choices=['recog', 'det'])
parser.add_argument(
'backend',
type=str,
help='Which backend to test, TensorRT or ONNXRuntime.',
choices=['TensorRT', 'ONNXRuntime'])
parser.add_argument(
'--eval',
type=str,
nargs='+',
help='The evaluation metrics, which depend on the dataset, e.g., '
'"bbox", "seg", "proposal" for COCO, and "mAP", "recall" for '
'PASCAL VOC.')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference.')
args = parser.parse_args()
return args
def main():
args = parse_args()
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This tool will be deprecated in the future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
if args.device == 'cpu':
args.device = None
cfg = Config.fromfile(args.model_config)
# build the model
if args.model_type == 'det':
if args.backend == 'TensorRT':
model = TensorRTDetector(args.model_file, cfg, 0)
else:
model = ONNXRuntimeDetector(args.model_file, cfg, 0)
else:
if args.backend == 'TensorRT':
model = TensorRTRecognizer(args.model_file, cfg, 0)
else:
model = ONNXRuntimeRecognizer(args.model_file, cfg, 0)
# build the dataloader
samples_per_gpu = 1
cfg = disable_text_recog_aug_test(cfg)
dataset = DATASETS.build(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=False,
shuffle=False)
model = MMDataParallel(model, device_ids=[0])
outputs = single_gpu_test(model, data_loader)
rank, _ = get_dist_info()
if rank == 0:
kwargs = {}
if args.eval:
eval_kwargs = cfg.get('evaluation', {}).copy()
# hard-code way to remove EvalHook args
for key in [
'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
'rule'
]:
eval_kwargs.pop(key, None)
eval_kwargs.update(dict(metric=args.eval, **kwargs))
print(dataset.evaluate(outputs, **eval_kwargs))
if __name__ == '__main__':
main()

View File

@@ -1,110 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
import mmcv
try:
from model_archiver.model_packaging import package_model
from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
package_model = None
def mmocr2torchserve(
config_file: str,
checkpoint_file: str,
output_folder: str,
model_name: str,
model_version: str = '1.0',
force: bool = False,
):
"""Converts MMOCR model (config + checkpoint) to TorchServe `.mar`.
Args:
config_file:
In MMOCR config format.
The contents vary for each task repository.
checkpoint_file:
In MMOCR checkpoint format.
The contents vary for each task repository.
output_folder:
Folder where `{model_name}.mar` will be created.
The file created will be in TorchServe archive format.
model_name:
If not None, used for naming the `{model_name}.mar` file
that will be created under `output_folder`.
If None, `{Path(checkpoint_file).stem}` will be used.
model_version:
Model's version.
force:
If True, if there is an existing `{model_name}.mar`
file under `output_folder` it will be overwritten.
"""
mmcv.mkdir_or_exist(output_folder)
config = mmcv.Config.fromfile(config_file)
with TemporaryDirectory() as tmpdir:
config.dump(f'{tmpdir}/config.py')
args = Namespace(
**{
'model_file': f'{tmpdir}/config.py',
'serialized_file': checkpoint_file,
'handler': f'{Path(__file__).parent}/mmocr_handler.py',
'model_name': model_name or Path(checkpoint_file).stem,
'version': model_version,
'export_path': output_folder,
'force': force,
'requirements_file': None,
'extra_files': None,
'runtime': 'python',
'archive_format': 'default'
})
manifest = ModelExportUtils.generate_manifest_json(args)
package_model(args, manifest)
def parse_args():
parser = ArgumentParser(
description='Convert MMOCR models to TorchServe `.mar` format.')
parser.add_argument('config', type=str, help='config file path')
parser.add_argument('checkpoint', type=str, help='checkpoint file path')
parser.add_argument(
'--output-folder',
type=str,
required=True,
help='Folder where `{model_name}.mar` will be created.')
parser.add_argument(
'--model-name',
type=str,
default=None,
help='If not None, used for naming the `{model_name}.mar` '
'file that will be created under `output_folder`. '
'If None, `{Path(checkpoint_file).stem}` will be used.')
parser.add_argument(
'--model-version',
type=str,
default='1.0',
help='Number used for versioning.')
parser.add_argument(
'-f',
'--force',
action='store_true',
help='overwrite the existing `{model_name}.mar`')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
if package_model is None:
raise ImportError('`torch-model-archiver` is required. '
'Try: pip install torch-model-archiver')
mmocr2torchserve(args.config, args.checkpoint, args.output_folder,
args.model_name, args.model_version, args.force)
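A hedged sketch of calling the converter above directly from Python instead of the CLI; the config, checkpoint and output paths and the model name are placeholders.

# Hypothetical invocation; all paths are placeholders.
mmocr2torchserve(
    config_file='dbnet_config.py',
    checkpoint_file='dbnet.pth',
    output_folder='torchserve_models',
    model_name='dbnet',
    force=True)
# Expected to produce torchserve_models/dbnet.mar for `torchserve --start`.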

View File

@@ -1,51 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import base64
import os
import mmcv
import torch
from ts.torch_handler.base_handler import BaseHandler
from mmocr.apis import init_detector, model_inference
from mmocr.datasets.pipelines import * # NOQA
class MMOCRHandler(BaseHandler):
threshold = 0.5
def initialize(self, context):
properties = context.system_properties
self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = torch.device(self.map_location + ':' +
str(properties.get('gpu_id')) if torch.cuda.
is_available() else self.map_location)
self.manifest = context.manifest
model_dir = properties.get('model_dir')
serialized_file = self.manifest['model']['serializedFile']
checkpoint = os.path.join(model_dir, serialized_file)
self.config_file = os.path.join(model_dir, 'config.py')
self.model = init_detector(self.config_file, checkpoint, self.device)
self.initialized = True
def preprocess(self, data):
images = []
for row in data:
image = row.get('data') or row.get('body')
if isinstance(image, str):
image = base64.b64decode(image)
image = mmcv.imfrombytes(image)
images.append(image)
return images
def inference(self, data, *args, **kwargs):
results = model_inference(self.model, data)
return results
def postprocess(self, data):
# Format output following the example OCRHandler format
return data

View File

@@ -1,294 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import warnings
from typing import Iterable
import cv2
import mmcv
import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.tensorrt import is_tensorrt_plugin_loaded, onnx2trt, save_trt_engine
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
TensorRTDetector, TensorRTRecognizer)
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
from mmocr.utils import is_2dlist
def get_GiB(x: int):
"""return x GiB."""
return x * (1 << 30)
def _prepare_input_img(imgs, test_pipeline: Iterable[dict]):
"""Inference image(s) with the detector.
Args:
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
Either image files or loaded images.
test_pipeline (Iterable[dict]): Test pipline of configuration.
Returns:
result (dict): Predicted results.
"""
if isinstance(imgs, (list, tuple)):
if not isinstance(imgs[0], (np.ndarray, str)):
raise AssertionError('imgs must be strings or numpy arrays')
elif isinstance(imgs, (np.ndarray, str)):
imgs = [imgs]
else:
raise AssertionError('imgs must be strings or numpy arrays')
test_pipeline = replace_ImageToTensor(test_pipeline)
test_pipeline = Compose(test_pipeline)
data = []
for img in imgs:
# prepare data
# add information into dict
datum = dict(img_info=dict(filename=img), img_prefix=None)
# build the data pipeline
datum = test_pipeline(datum)
# get tensor from list to stack for batch mode (text detection)
data.append(datum)
if isinstance(data[0]['img'], list) and len(data) > 1:
raise Exception('aug test does not support '
f'inference with batch size '
f'{len(data)}')
data = collate(data, samples_per_gpu=len(imgs))
# process img_metas
if isinstance(data['img_metas'], list):
data['img_metas'] = [
img_metas.data[0] for img_metas in data['img_metas']
]
else:
data['img_metas'] = data['img_metas'].data
if isinstance(data['img'], list):
data['img'] = [img.data for img in data['img']]
if isinstance(data['img'][0], list):
data['img'] = [img[0] for img in data['img']]
else:
data['img'] = data['img'].data
return data
def onnx2tensorrt(onnx_file: str,
model_type: str,
trt_file: str,
config: dict,
input_config: dict,
fp16: bool = False,
verify: bool = False,
show: bool = False,
workspace_size: int = 1,
verbose: bool = False):
import tensorrt as trt
min_shape = input_config['min_shape']
max_shape = input_config['max_shape']
# create trt engine and wrapper
opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
max_workspace_size = get_GiB(workspace_size)
trt_engine = onnx2trt(
onnx_file,
opt_shape_dict,
log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
fp16_mode=fp16,
max_workspace_size=max_workspace_size)
save_dir, _ = osp.split(trt_file)
if save_dir:
os.makedirs(save_dir, exist_ok=True)
save_trt_engine(trt_engine, trt_file)
print(f'Successfully created TensorRT engine: {trt_file}')
if verify:
mm_inputs = _prepare_input_img(input_config['input_path'],
config.data.test.pipeline)
imgs = mm_inputs.pop('img')
img_metas = mm_inputs.pop('img_metas')
if isinstance(imgs, list):
imgs = imgs[0]
img_list = [img[None, :] for img in imgs]
# Get results from ONNXRuntime
if model_type == 'det':
onnx_model = ONNXRuntimeDetector(onnx_file, config, 0)
else:
onnx_model = ONNXRuntimeRecognizer(onnx_file, config, 0)
onnx_out = onnx_model.simple_test(
img_list[0], img_metas[0], rescale=True)
# Get results from TensorRT
if model_type == 'det':
trt_model = TensorRTDetector(trt_file, config, 0)
else:
trt_model = TensorRTRecognizer(trt_file, config, 0)
img_list[0] = img_list[0].to(torch.device('cuda:0'))
trt_out = trt_model.simple_test(
img_list[0], img_metas[0], rescale=True)
# compare results
same_diff = 'same'
if model_type == 'recog':
for onnx_result, trt_result in zip(onnx_out, trt_out):
if onnx_result['text'] != trt_result['text'] or \
not np.allclose(
np.array(onnx_result['score']),
np.array(trt_result['score']),
rtol=1e-4,
atol=1e-4):
same_diff = 'different'
break
else:
for onnx_result, trt_result in zip(onnx_out[0]['boundary_result'],
trt_out[0]['boundary_result']):
if not np.allclose(
np.array(onnx_result),
np.array(trt_result),
rtol=1e-4,
atol=1e-4):
same_diff = 'different'
break
print(f'The outputs are {same_diff} between TensorRT and ONNX')
if show:
onnx_img = onnx_model.show_result(
input_config['input_path'],
onnx_out[0],
out_file='onnx.jpg',
show=False)
trt_img = trt_model.show_result(
input_config['input_path'],
trt_out[0],
out_file='tensorrt.jpg',
show=False)
if onnx_img is None:
onnx_img = cv2.imread(input_config['input_path'])
if trt_img is None:
trt_img = cv2.imread(input_config['input_path'])
cv2.imshow('TensorRT', trt_img)
cv2.imshow('ONNXRuntime', onnx_img)
cv2.waitKey()
return
def parse_args():
parser = argparse.ArgumentParser(
description='Convert MMOCR models from ONNX to TensorRT')
parser.add_argument('model_config', help='Config file of the model')
parser.add_argument(
'model_type',
type=str,
help='Detection or recognition model to deploy.',
choices=['recog', 'det'])
parser.add_argument('image_path', type=str, help='Image for test')
parser.add_argument('onnx_file', help='Path to the input ONNX model')
parser.add_argument(
'--trt-file',
type=str,
help='Path to the output TensorRT engine',
default='tmp.trt')
parser.add_argument(
'--max-shape',
type=int,
nargs=4,
default=[1, 3, 400, 600],
help='Maximum shape of model input.')
parser.add_argument(
'--min-shape',
type=int,
nargs=4,
default=[1, 3, 400, 600],
help='Minimum shape of model input.')
parser.add_argument(
'--workspace-size',
type=int,
default=1,
help='Max workspace size in GiB.')
parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
parser.add_argument(
'--verify',
action='store_true',
help='Whether to verify the outputs of ONNXRuntime and TensorRT.',
default=True)
parser.add_argument(
'--show',
action='store_true',
help='Whether to visualize the outputs of ONNXRuntime and TensorRT.',
default=True)
parser.add_argument(
'--verbose',
action='store_true',
help='Whether to enable verbose logging while creating the \
TensorRT engine.')
args = parser.parse_args()
return args
if __name__ == '__main__':
assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
args = parse_args()
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This tool will be deprecated in the future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
# check arguments
assert osp.exists(args.model_config), 'Config {} not found.'.format(
args.model_config)
assert osp.exists(args.onnx_file), \
f'ONNX model {args.onnx_file} not found.'
assert args.workspace_size >= 0, 'Workspace size less than 0.'
for max_value, min_value in zip(args.max_shape, args.min_shape):
assert max_value >= min_value, \
'max_shape should be larger than min shape'
input_config = {
'min_shape': args.min_shape,
'max_shape': args.max_shape,
'input_path': args.image_path
}
cfg = mmcv.Config.fromfile(args.model_config)
if cfg.data.test.get('pipeline', None) is None:
if is_2dlist(cfg.data.test.datasets):
cfg.data.test.pipeline = \
cfg.data.test.datasets[0][0].pipeline
else:
cfg.data.test.pipeline = \
cfg.data.test['datasets'][0].pipeline
if is_2dlist(cfg.data.test.pipeline):
cfg.data.test.pipeline = cfg.data.test.pipeline[0]
onnx2tensorrt(
args.onnx_file,
args.model_type,
args.trt_file,
cfg,
input_config,
fp16=args.fp16,
verify=args.verify,
show=args.show,
workspace_size=args.workspace_size,
verbose=args.verbose)
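For reference, a hedged sketch of driving onnx2tensorrt above programmatically; the file names, shapes and test image are placeholders, and a CUDA device plus an mmcv build with TensorRT ops are assumed.

# Hypothetical invocation; paths and shapes are placeholders.
import mmcv

cfg = mmcv.Config.fromfile('dbnet_config.py')
input_config = dict(
    min_shape=[1, 3, 400, 600],
    max_shape=[1, 3, 400, 600],
    input_path='demo.jpg')
onnx2tensorrt('dbnet.onnx', 'det', 'dbnet.trt', cfg, input_config,
              fp16=False, verify=True, workspace_size=1)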

View File

@@ -1,368 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from argparse import ArgumentParser
from functools import partial
import cv2
import numpy as np
import torch
from mmcv.onnx import register_extra_symbolics
from mmcv.parallel import collate
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from torch import nn
from mmocr.apis import init_detector
from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
from mmocr.utils import is_2dlist
def _convert_batchnorm(module):
module_output = module
if isinstance(module, torch.nn.SyncBatchNorm):
module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats)
if module.affine:
module_output.weight.data = module.weight.data.clone().detach()
module_output.bias.data = module.bias.data.clone().detach()
# keep requires_grad unchanged
module_output.weight.requires_grad = module.weight.requires_grad
module_output.bias.requires_grad = module.bias.requires_grad
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, _convert_batchnorm(child))
del module
return module_output
def _prepare_data(cfg, imgs):
"""Inference image(s) with the detector.
Args:
model (nn.Module): The loaded detector.
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
Either image files or loaded images.
Returns:
result (dict): Predicted results.
"""
if isinstance(imgs, (list, tuple)):
if not isinstance(imgs[0], (np.ndarray, str)):
raise AssertionError('imgs must be strings or numpy arrays')
elif isinstance(imgs, (np.ndarray, str)):
imgs = [imgs]
else:
raise AssertionError('imgs must be strings or numpy arrays')
is_ndarray = isinstance(imgs[0], np.ndarray)
if is_ndarray:
cfg = cfg.copy()
# set loading pipeline type
cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
test_pipeline = Compose(cfg.data.test.pipeline)
data = []
for img in imgs:
# prepare data
if is_ndarray:
# directly add img
datum = dict(img=img)
else:
# add information into dict
datum = dict(img_info=dict(filename=img), img_prefix=None)
# build the data pipeline
datum = test_pipeline(datum)
# get tensor from list to stack for batch mode (text detection)
data.append(datum)
if isinstance(data[0]['img'], list) and len(data) > 1:
raise Exception('aug test does not support '
f'inference with batch size '
f'{len(data)}')
data = collate(data, samples_per_gpu=len(imgs))
# process img_metas
if isinstance(data['img_metas'], list):
data['img_metas'] = [
img_metas.data[0] for img_metas in data['img_metas']
]
else:
data['img_metas'] = data['img_metas'].data
if isinstance(data['img'], list):
data['img'] = [img.data for img in data['img']]
if isinstance(data['img'][0], list):
data['img'] = [img[0] for img in data['img']]
else:
data['img'] = data['img'].data
return data
def pytorch2onnx(model: nn.Module,
model_type: str,
img_path: str,
verbose: bool = False,
show: bool = False,
opset_version: int = 11,
output_file: str = 'tmp.onnx',
verify: bool = False,
dynamic_export: bool = False,
device_id: int = 0):
"""Export PyTorch model to ONNX model and verify the outputs are same
between PyTorch and ONNX.
Args:
model (nn.Module): PyTorch model we want to export.
model_type (str): Model type, detection or recognition model.
img_path (str): Path of the input image used to run the model.
opset_version (int): The ONNX opset version. Default: 11.
verbose (bool): Whether to print the computation graph. Default: False.
show (bool): Whether to visualize final results. Default: False.
output_file (str): The path where the output ONNX model is stored.
Default: `tmp.onnx`.
verify (bool): Whether to compare the outputs between PyTorch and ONNX.
Default: False.
dynamic_export (bool): Whether to apply dynamic export.
Default: False.
device_id (int): Device id to place the model and data on.
Default: 0.
"""
device = torch.device(type='cuda', index=device_id)
model.to(device).eval()
_convert_batchnorm(model)
# prepare inputs
mm_inputs = _prepare_data(cfg=model.cfg, imgs=img_path)
imgs = mm_inputs.pop('img')
img_metas = mm_inputs.pop('img_metas')
if isinstance(imgs, list):
imgs = imgs[0]
img_list = [img[None, :].to(device) for img in imgs]
origin_forward = model.forward
if (model_type == 'det'):
model.forward = partial(
model.simple_test, img_metas=img_metas, rescale=True)
else:
model.forward = partial(
model.forward,
img_metas=img_metas,
return_loss=False,
rescale=True)
# PyTorch 1.3 has bugs when exporting some ops to ONNX; work around
# them by registering mmcv's extra symbolic functions
register_extra_symbolics(opset_version)
dynamic_axes = None
if dynamic_export and model_type == 'det':
dynamic_axes = {
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'output': {
0: 'batch',
2: 'height',
3: 'width'
}
}
elif dynamic_export and model_type == 'recog':
dynamic_axes = {
'input': {
0: 'batch',
3: 'width'
},
'output': {
0: 'batch',
1: 'seq_len',
2: 'num_classes'
}
}
with torch.no_grad():
torch.onnx.export(
model, (img_list[0], ),
output_file,
input_names=['input'],
output_names=['output'],
export_params=True,
keep_initializers_as_inputs=False,
verbose=verbose,
opset_version=opset_version,
dynamic_axes=dynamic_axes)
print(f'Successfully exported ONNX model: {output_file}')
if verify:
# check by onnx
import onnx
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)
scale_factor = (0.5, 0.5) if model_type == 'det' else (1, 0.5)
if dynamic_export:
# scale image for dynamic shape test
img_list = [
nn.functional.interpolate(_, scale_factor=scale_factor)
for _ in img_list
]
if model_type == 'det':
img_metas[0][0][
'scale_factor'] = img_metas[0][0]['scale_factor'] * (
scale_factor * 2)
# check the numerical value
# get pytorch output
with torch.no_grad():
model.forward = origin_forward
pytorch_out = model.simple_test(
img_list[0], img_metas[0], rescale=True)
# get onnx output
if model_type == 'det':
onnx_model = ONNXRuntimeDetector(output_file, model.cfg, device_id)
else:
onnx_model = ONNXRuntimeRecognizer(output_file, model.cfg,
device_id)
onnx_out = onnx_model.simple_test(
img_list[0], img_metas[0], rescale=True)
# compare results
same_diff = 'same'
if model_type == 'recog':
for onnx_result, pytorch_result in zip(onnx_out, pytorch_out):
if onnx_result['text'] != pytorch_result[
'text'] or not np.allclose(
np.array(onnx_result['score']),
np.array(pytorch_result['score']),
rtol=1e-4,
atol=1e-4):
same_diff = 'different'
break
else:
for onnx_result, pytorch_result in zip(
onnx_out[0]['boundary_result'],
pytorch_out[0]['boundary_result']):
if not np.allclose(
np.array(onnx_result),
np.array(pytorch_result),
rtol=1e-4,
atol=1e-4):
same_diff = 'different'
break
print(f'The outputs are {same_diff} between PyTorch and ONNX')
if show:
onnx_img = onnx_model.show_result(
img_path, onnx_out[0], out_file='onnx.jpg', show=False)
pytorch_img = model.show_result(
img_path, pytorch_out[0], out_file='pytorch.jpg', show=False)
if onnx_img is None:
onnx_img = cv2.imread(img_path)
if pytorch_img is None:
pytorch_img = cv2.imread(img_path)
cv2.imshow('PyTorch', pytorch_img)
cv2.imshow('ONNXRuntime', onnx_img)
cv2.waitKey()
return
def main():
parser = ArgumentParser(
description='Convert MMOCR models from PyTorch to ONNX')
parser.add_argument('model_config', type=str, help='Config file.')
parser.add_argument(
'model_ckpt', type=str, help='Checkpoint file (local or URL).')
parser.add_argument(
'model_type',
type=str,
help='Detection or recognition model to deploy.',
choices=['recog', 'det'])
parser.add_argument('image_path', type=str, help='Input Image file.')
parser.add_argument(
'--output-file',
type=str,
help='Output file name of the onnx model.',
default='tmp.onnx')
parser.add_argument(
'--device-id', default=0, help='Device used for inference.')
parser.add_argument(
'--opset-version',
type=int,
help='ONNX opset version, default to 11.',
default=11)
parser.add_argument(
'--verify',
action='store_true',
help='Whether to verify that the outputs of ONNX and PyTorch are the same.',
default=False)
parser.add_argument(
'--verbose',
action='store_true',
help='Whether to print the computation graph.',
default=False)
parser.add_argument(
'--show',
action='store_true',
help='Whether to visualize the final output.',
default=False)
parser.add_argument(
'--dynamic-export',
action='store_true',
help='Whether to export the ONNX model with dynamic axes.',
default=False)
args = parser.parse_args()
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This tool will be deprecated in the future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
device = torch.device(type='cuda', index=args.device_id)
# build model
model = init_detector(args.model_config, args.model_ckpt, device=device)
if hasattr(model, 'module'):
model = model.module
if model.cfg.data.test.get('pipeline', None) is None:
if is_2dlist(model.cfg.data.test.datasets):
model.cfg.data.test.pipeline = \
model.cfg.data.test.datasets[0][0].pipeline
else:
model.cfg.data.test.pipeline = \
model.cfg.data.test['datasets'][0].pipeline
if is_2dlist(model.cfg.data.test.pipeline):
model.cfg.data.test.pipeline = model.cfg.data.test.pipeline[0]
pytorch2onnx(
model,
model_type=args.model_type,
output_file=args.output_file,
img_path=args.image_path,
opset_version=args.opset_version,
verify=args.verify,
verbose=args.verbose,
show=args.show,
device_id=args.device_id,
dynamic_export=args.dynamic_export)
if __name__ == '__main__':
main()
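Similarly, a hedged sketch of calling the export entry point above directly; the config, checkpoint and image paths are placeholders, and the test pipeline is assumed to be flattened beforehand as main() does.

# Hypothetical invocation; paths are placeholders.
from mmocr.apis import init_detector

model = init_detector('crnn_config.py', 'crnn.pth', device='cuda:0')
pytorch2onnx(
    model,
    model_type='recog',
    img_path='demo_text_recog.jpg',
    output_file='crnn.onnx',
    opset_version=11,
    verify=True)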

View File

@@ -1,63 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
import numpy as np
import requests
from mmocr.apis import init_detector, model_inference
def parse_args():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('model_name', help='The model name in the server')
parser.add_argument(
'--inference-addr',
default='127.0.0.1:8080',
help='Address and port of the inference server')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--score-thr', type=float, default=0.5, help='bbox score threshold')
args = parser.parse_args()
return args
def main(args):
# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)
# test a single image
model_results = model_inference(model, args.img)
model.show_result(
args.img,
model_results,
win_name='model_results',
show=True,
score_thr=args.score_thr)
url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
with open(args.img, 'rb') as image:
response = requests.post(url, image)
serve_results = response.json()
model.show_result(
args.img,
serve_results,
show=True,
win_name='serve_results',
score_thr=args.score_thr)
assert serve_results.keys() == model_results.keys()
for key in serve_results.keys():
for model_result, serve_result in zip(model_results[key],
serve_results[key]):
if isinstance(model_result[0], (int, float)):
assert np.allclose(model_result, serve_result)
elif isinstance(model_result[0], str):
assert model_result == serve_result
else:
raise TypeError
if __name__ == '__main__':
args = parse_args()
main(args)