parent
bce276ef24
commit
513b1c3cfb
|
@ -0,0 +1,11 @@
|
|||
_base_ = ['./classification_coreml_dynamic-224x224-224x224.py']
|
||||
|
||||
ir_config = dict(input_shape=(384, 384))
|
||||
backend_config = dict(model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 384, 384],
|
||||
max_shape=[1, 3, 384, 384],
|
||||
default_shape=[1, 3, 384, 384])))
|
||||
])
|
|
@ -0,0 +1,11 @@
|
|||
_base_ = ['../_base_/base_torchscript.py', '../../_base_/backends/coreml.py']
|
||||
|
||||
ir_config = dict(input_shape=(608, 608))
|
||||
backend_config = dict(model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 608, 608],
|
||||
max_shape=[1, 3, 608, 608],
|
||||
default_shape=[1, 3, 608, 608])))
|
||||
])
|
|
@ -0,0 +1,13 @@
|
|||
_base_ = [
|
||||
'../../_base_/torchscript_config.py', '../../_base_/backends/coreml.py'
|
||||
]
|
||||
|
||||
codebase_config = dict(type='mmocr', task='TextRecognition')
|
||||
backend_config = dict(model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 32, 32],
|
||||
max_shape=[1, 3, 32, 640],
|
||||
default_shape=[1, 3, 32, 64])))
|
||||
])
|
|
@ -43,7 +43,7 @@ class CoreMLManager(BaseBackendManager):
|
|||
bool: True if backend package is installed.
|
||||
"""
|
||||
import importlib
|
||||
return importlib.util.find_spec('coreml') is not None
|
||||
return importlib.util.find_spec('coremltools') is not None
|
||||
|
||||
@classmethod
|
||||
def get_version(cls) -> str:
|
||||
|
@ -53,7 +53,7 @@ class CoreMLManager(BaseBackendManager):
|
|||
else:
|
||||
import pkg_resources
|
||||
try:
|
||||
return pkg_resources.get_distribution('coreml').version
|
||||
return pkg_resources.get_distribution('coremltools').version
|
||||
except Exception:
|
||||
return 'None'
|
||||
|
||||
|
@ -78,14 +78,46 @@ class CoreMLManager(BaseBackendManager):
|
|||
Returns:
|
||||
Seqeuence[str]: Backend files.
|
||||
"""
|
||||
from .torchscript2coreml import from_torchscript
|
||||
from mmdeploy.utils import (get_common_config, get_ir_config,
|
||||
get_model_inputs, load_config)
|
||||
from .torchscript2coreml import from_torchscript, get_model_suffix
|
||||
|
||||
coreml_files = []
|
||||
for model_id, torchscript_path in enumerate(ir_files):
|
||||
torchscript_name = osp.splitext(osp.split(torchscript_path)[1])[0]
|
||||
output_file_prefix = osp.join(work_dir, torchscript_name)
|
||||
|
||||
from_torchscript(model_id, torchscript_path, output_file_prefix,
|
||||
deploy_cfg, coreml_files)
|
||||
deploy_cfg = load_config(deploy_cfg)[0]
|
||||
|
||||
common_params = get_common_config(deploy_cfg)
|
||||
model_params = get_model_inputs(deploy_cfg)[model_id]
|
||||
|
||||
final_params = common_params
|
||||
final_params.update(model_params)
|
||||
|
||||
ir_config = get_ir_config(deploy_cfg)
|
||||
input_names = ir_config.get('input_names', [])
|
||||
output_names = ir_config.get('output_names', [])
|
||||
input_shapes = final_params['input_shapes']
|
||||
compute_precision = final_params.get('compute_precision',
|
||||
'FLOAT32')
|
||||
convert_to = deploy_cfg.backend_config.convert_to
|
||||
|
||||
minimum_deployment_target = final_params.get(
|
||||
'minimum_deployment_target', None)
|
||||
skip_model_load = final_params.get('skip_model_load', False)
|
||||
from_torchscript(
|
||||
torchscript_path,
|
||||
output_file_prefix,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
input_shapes=input_shapes,
|
||||
compute_precision=compute_precision,
|
||||
convert_to=convert_to,
|
||||
minimum_deployment_target=minimum_deployment_target,
|
||||
skip_model_load=skip_model_load)
|
||||
|
||||
suffix = get_model_suffix(convert_to)
|
||||
output_path = output_file_prefix + suffix
|
||||
coreml_files.append(output_path)
|
||||
return coreml_files
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
||||
import coremltools as ct
|
||||
import mmcv
|
||||
import torch
|
||||
|
||||
from mmdeploy.utils import (get_common_config, get_model_inputs,
|
||||
get_root_logger, load_config)
|
||||
from mmdeploy.utils.config_utils import get_ir_config
|
||||
from mmdeploy.utils import get_root_logger
|
||||
|
||||
try:
|
||||
# user might need ops from torchvision
|
||||
|
@ -50,24 +46,23 @@ def create_shape(name: str, input_shapes: Dict) -> ct.Shape:
|
|||
return ct.TensorType(shape=shape, name=name)
|
||||
|
||||
|
||||
def from_torchscript(model_id: int,
|
||||
torchscript_model: Union[str,
|
||||
def from_torchscript(torchscript_model: Union[str,
|
||||
torch.jit.RecursiveScriptModule],
|
||||
output_file_prefix: str, deploy_cfg: Union[str,
|
||||
mmcv.Config],
|
||||
backend_files: List[str], **kwargs):
|
||||
output_file_prefix: str,
|
||||
input_names: Sequence[str],
|
||||
output_names: Sequence[str],
|
||||
input_shapes: Dict[str, Dict],
|
||||
compute_precision: str = 'FLOAT32',
|
||||
convert_to: str = 'neuralnetwork',
|
||||
minimum_deployment_target: Optional[str] = None,
|
||||
skip_model_load: bool = False):
|
||||
"""Create a coreml engine from torchscript.
|
||||
|
||||
Args:
|
||||
model_id (int): Index of input model.
|
||||
torchscript_model (Union[str, torch.jit.RecursiveScriptModule]):
|
||||
The torchscript model to be converted.
|
||||
output_file_prefix (str): The output file prefix.
|
||||
deploy_cfg (str | mmcv.Config): Deployment config.
|
||||
backend_files (List[str]):
|
||||
Backend files used by deployment for testing pipeline
|
||||
"""
|
||||
|
||||
try:
|
||||
from mmdeploy.backend.torchscript import get_ops_path
|
||||
torch.ops.load_library(get_ops_path())
|
||||
|
@ -80,40 +75,22 @@ def from_torchscript(model_id: int,
|
|||
if isinstance(torchscript_model, str):
|
||||
torchscript_model = torch.jit.load(torchscript_model)
|
||||
|
||||
deploy_cfg = load_config(deploy_cfg)[0]
|
||||
|
||||
common_params = get_common_config(deploy_cfg)
|
||||
model_params = get_model_inputs(deploy_cfg)[model_id]
|
||||
|
||||
final_params = common_params
|
||||
final_params.update(model_params)
|
||||
|
||||
ir_config = get_ir_config(deploy_cfg)
|
||||
|
||||
input_names = ir_config.get('input_names', [])
|
||||
input_shapes = final_params['input_shapes']
|
||||
inputs = []
|
||||
|
||||
for name in input_names:
|
||||
shape = create_shape(name, input_shapes[name])
|
||||
inputs.append(shape)
|
||||
|
||||
output_names = ir_config.get('output_names', [])
|
||||
outputs = []
|
||||
|
||||
for name in output_names:
|
||||
outputs.append(ct.TensorType(name=name))
|
||||
|
||||
convert_to = deploy_cfg.backend_config.convert_to
|
||||
if convert_to == 'neuralnetwork':
|
||||
# Compute precision must be None for neuralnetwork conversion
|
||||
compute_precision = None
|
||||
else:
|
||||
compute_precision = ct.precision[final_params.get(
|
||||
'compute_precision', 'FLOAT32')]
|
||||
|
||||
minimum_deployment_target = final_params.get('minimum_deployment_target',
|
||||
None)
|
||||
compute_precision = ct.precision[compute_precision]
|
||||
|
||||
mlmodel = ct.convert(
|
||||
model=torchscript_model,
|
||||
|
@ -123,9 +100,8 @@ def from_torchscript(model_id: int,
|
|||
convert_to=convert_to,
|
||||
minimum_deployment_target=ct.target[minimum_deployment_target]
|
||||
if minimum_deployment_target else None,
|
||||
skip_model_load=final_params.get('skip_model_load', False))
|
||||
skip_model_load=skip_model_load)
|
||||
|
||||
suffix = get_model_suffix(convert_to)
|
||||
output_path = output_file_prefix + suffix
|
||||
backend_files.append(output_path)
|
||||
mlmodel.save(output_path)
|
||||
|
|
|
@ -131,6 +131,7 @@ def base_dense_head__get_bbox(ctx,
|
|||
else:
|
||||
max_scores, _ = nms_pre_score[..., :-1].max(-1)
|
||||
_, topk_inds = max_scores.topk(pre_topk)
|
||||
|
||||
bbox_pred, scores, score_factors = gather_topk(
|
||||
bbox_pred,
|
||||
scores,
|
||||
|
|
|
@ -68,3 +68,27 @@ def topk__tensorrt(ctx,
|
|||
k = TENSORRT_MAX_TOPK
|
||||
|
||||
return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(func_name='torch.topk', backend='coreml')
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.topk', backend='coreml')
|
||||
def topk__coreml(ctx,
|
||||
input: torch.Tensor,
|
||||
k: int,
|
||||
dim: Optional[int] = None,
|
||||
largest: bool = True,
|
||||
sorted: bool = True):
|
||||
"""Rewrite `topk` for coreml backend.
|
||||
|
||||
Cast k to tensor and make sure k is smaller than input.shape[dim].
|
||||
"""
|
||||
|
||||
if dim is None:
|
||||
dim = int(input.ndim - 1)
|
||||
size = input.shape[dim]
|
||||
if not isinstance(k, torch.Tensor):
|
||||
k = torch.tensor(k, device=input.device, dtype=torch.long)
|
||||
# Always keep topk op for dynamic input
|
||||
k = torch.where(k < size, k, size)
|
||||
return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)
|
||||
|
|
|
@ -67,7 +67,7 @@ def generate_torchscript_file():
|
|||
context_info=context_info)
|
||||
|
||||
|
||||
def onnx2backend(backend, onnx_file):
|
||||
def ir2backend(backend, onnx_file, ts_file):
|
||||
if backend == Backend.TENSORRT:
|
||||
from mmdeploy.backend.tensorrt import from_onnx
|
||||
backend_file = tempfile.NamedTemporaryFile(suffix='.engine').name
|
||||
|
@ -143,6 +143,34 @@ def onnx2backend(backend, onnx_file):
|
|||
onnx_file, lib_file, shape=shape, dtype=dtype, tuner=tuner_dict)
|
||||
assert osp.exists(lib_file)
|
||||
return lib_file
|
||||
elif backend == Backend.TORCHSCRIPT:
|
||||
return ts_file
|
||||
elif backend == Backend.COREML:
|
||||
output_names = ['output']
|
||||
from mmdeploy.backend.coreml.torchscript2coreml import (
|
||||
from_torchscript, get_model_suffix)
|
||||
backend_dir = tempfile.TemporaryDirectory().name
|
||||
work_dir = backend_dir
|
||||
torchscript_name = osp.splitext(osp.split(ts_file)[1])[0]
|
||||
output_file_prefix = osp.join(work_dir, torchscript_name)
|
||||
convert_to = 'mlprogram'
|
||||
from_torchscript(
|
||||
ts_file,
|
||||
output_file_prefix,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 8, 8],
|
||||
default_shape=[1, 3, 8, 8],
|
||||
max_shape=[1, 3, 8, 8])),
|
||||
convert_to=convert_to)
|
||||
|
||||
suffix = get_model_suffix(convert_to)
|
||||
return output_file_prefix + suffix
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'Convert for {backend.value} has not been implemented.')
|
||||
|
||||
|
||||
def create_wrapper(backend, model_files):
|
||||
|
@ -186,10 +214,7 @@ ALL_BACKEND.remove(Backend.SDK)
|
|||
@pytest.mark.parametrize('backend', ALL_BACKEND)
|
||||
def test_wrapper(backend):
|
||||
check_backend(backend, True)
|
||||
if backend == Backend.TORCHSCRIPT:
|
||||
model_files = ts_file
|
||||
else:
|
||||
model_files = onnx2backend(backend, onnx_file)
|
||||
model_files = ir2backend(backend, onnx_file, ts_file)
|
||||
assert model_files is not None
|
||||
wrapper = create_wrapper(backend, model_files)
|
||||
assert wrapper is not None
|
||||
|
|
Loading…
Reference in New Issue