Fix coreml (#1658)

* fix coreml topk

* update

* fix lint
q.yao 2023-01-19 11:42:18 +08:00 committed by GitHub
parent bce276ef24
commit 513b1c3cfb
8 changed files with 140 additions and 47 deletions


@@ -0,0 +1,11 @@
+_base_ = ['./classification_coreml_dynamic-224x224-224x224.py']
+
+ir_config = dict(input_shape=(384, 384))
+backend_config = dict(model_inputs=[
+    dict(
+        input_shapes=dict(
+            input=dict(
+                min_shape=[1, 3, 384, 384],
+                max_shape=[1, 3, 384, 384],
+                default_shape=[1, 3, 384, 384])))
+])


@@ -0,0 +1,11 @@
+_base_ = ['../_base_/base_torchscript.py', '../../_base_/backends/coreml.py']
+
+ir_config = dict(input_shape=(608, 608))
+backend_config = dict(model_inputs=[
+    dict(
+        input_shapes=dict(
+            input=dict(
+                min_shape=[1, 3, 608, 608],
+                max_shape=[1, 3, 608, 608],
+                default_shape=[1, 3, 608, 608])))
+])


@@ -0,0 +1,13 @@
+_base_ = [
+    '../../_base_/torchscript_config.py', '../../_base_/backends/coreml.py'
+]
+
+codebase_config = dict(type='mmocr', task='TextRecognition')
+backend_config = dict(model_inputs=[
+    dict(
+        input_shapes=dict(
+            input=dict(
+                min_shape=[1, 3, 32, 32],
+                max_shape=[1, 3, 32, 640],
+                default_shape=[1, 3, 32, 64])))
+])
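
The three new configs above all follow the same pattern: the min/max/default triples under `model_inputs` bound the Core ML input shape, and they are exactly what the refactored `to_backend` (below) feeds to the converter. As a quick sanity check, not part of this commit, such a config can be loaded and inspected with the same mmdeploy utilities the backend manager imports; the file path here is illustrative:

from mmdeploy.utils import get_model_inputs, load_config

# Hypothetical filename for the mmocr config above; substitute the real one.
deploy_cfg = load_config('text-recognition_coreml_dynamic-32x32-32x640.py')[0]

# One model_inputs entry per IR file; each maps an input name to the
# min/max/default shapes that bound the Core ML flexible input.
input_shapes = get_model_inputs(deploy_cfg)[0]['input_shapes']
assert input_shapes['input']['max_shape'] == [1, 3, 32, 640]

# 'neuralnetwork' or 'mlprogram', inherited from the coreml base config.
print(deploy_cfg.backend_config.convert_to)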


@@ -43,7 +43,7 @@ class CoreMLManager(BaseBackendManager):
             bool: True if backend package is installed.
         """
         import importlib
-        return importlib.util.find_spec('coreml') is not None
+        return importlib.util.find_spec('coremltools') is not None

     @classmethod
     def get_version(cls) -> str:
@@ -53,7 +53,7 @@ class CoreMLManager(BaseBackendManager):
         else:
            import pkg_resources
            try:
-               return pkg_resources.get_distribution('coreml').version
+               return pkg_resources.get_distribution('coremltools').version
            except Exception:
                return 'None'
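
For context, both fixes above are just the package name: the distribution installed by `pip install coremltools` is importable as `coremltools`, and there is no `coreml` module, so the old availability check and version lookup always failed. A one-line check (results shown for a typical environment):

import importlib.util

print(importlib.util.find_spec('coreml'))       # None: module does not exist
print(importlib.util.find_spec('coremltools'))  # a ModuleSpec if installed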
@@ -78,14 +78,46 @@
         Returns:
             Seqeuence[str]: Backend files.
         """
-        from .torchscript2coreml import from_torchscript
+        from mmdeploy.utils import (get_common_config, get_ir_config,
+                                    get_model_inputs, load_config)
+
+        from .torchscript2coreml import from_torchscript, get_model_suffix

         coreml_files = []
         for model_id, torchscript_path in enumerate(ir_files):
             torchscript_name = osp.splitext(osp.split(torchscript_path)[1])[0]
             output_file_prefix = osp.join(work_dir, torchscript_name)

-            from_torchscript(model_id, torchscript_path, output_file_prefix,
-                             deploy_cfg, coreml_files)
+            deploy_cfg = load_config(deploy_cfg)[0]
+
+            common_params = get_common_config(deploy_cfg)
+            model_params = get_model_inputs(deploy_cfg)[model_id]
+
+            final_params = common_params
+            final_params.update(model_params)
+
+            ir_config = get_ir_config(deploy_cfg)
+            input_names = ir_config.get('input_names', [])
+            output_names = ir_config.get('output_names', [])
+            input_shapes = final_params['input_shapes']
+            compute_precision = final_params.get('compute_precision',
+                                                 'FLOAT32')
+            convert_to = deploy_cfg.backend_config.convert_to
+
+            minimum_deployment_target = final_params.get(
+                'minimum_deployment_target', None)
+            skip_model_load = final_params.get('skip_model_load', False)
+
+            from_torchscript(
+                torchscript_path,
+                output_file_prefix,
+                input_names=input_names,
+                output_names=output_names,
+                input_shapes=input_shapes,
+                compute_precision=compute_precision,
+                convert_to=convert_to,
+                minimum_deployment_target=minimum_deployment_target,
+                skip_model_load=skip_model_load)
+
+            suffix = get_model_suffix(convert_to)
+            output_path = output_file_prefix + suffix
+            coreml_files.append(output_path)

         return coreml_files
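
With the config plumbing moved up into `to_backend`, `from_torchscript` can now be driven without a deployment config at all. A minimal sketch under the new signature, assuming a macOS environment with coremltools installed; the toy module and output prefix are illustrative:

import torch

from mmdeploy.backend.coreml.torchscript2coreml import (from_torchscript,
                                                        get_model_suffix)


class TinyNet(torch.nn.Module):

    def forward(self, x):
        return x.relu()


# Trace a toy model; a path to a saved TorchScript file also works.
ts_model = torch.jit.trace(TinyNet().eval(), torch.rand(1, 3, 224, 224))

convert_to = 'mlprogram'
from_torchscript(
    ts_model,
    '/tmp/tinynet',  # output file prefix; path is illustrative
    input_names=['input'],
    output_names=['output'],
    input_shapes=dict(
        input=dict(
            min_shape=[1, 3, 224, 224],
            max_shape=[1, 3, 224, 224],
            default_shape=[1, 3, 224, 224])),
    convert_to=convert_to)
print('/tmp/tinynet' + get_model_suffix(convert_to))  # e.g. an .mlpackage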


@@ -1,14 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Union
+from typing import Dict, Optional, Sequence, Union

 import coremltools as ct
-import mmcv
 import torch

-from mmdeploy.utils import (get_common_config, get_model_inputs,
-                            get_root_logger, load_config)
-from mmdeploy.utils.config_utils import get_ir_config
+from mmdeploy.utils import get_root_logger

 try:
     # user might need ops from torchvision
@@ -50,24 +46,23 @@ def create_shape(name: str, input_shapes: Dict) -> ct.Shape:
     return ct.TensorType(shape=shape, name=name)


-def from_torchscript(model_id: int,
-                     torchscript_model: Union[str,
+def from_torchscript(torchscript_model: Union[str,
                                               torch.jit.RecursiveScriptModule],
-                     output_file_prefix: str, deploy_cfg: Union[str,
-                                                                mmcv.Config],
-                     backend_files: List[str], **kwargs):
+                     output_file_prefix: str,
+                     input_names: Sequence[str],
+                     output_names: Sequence[str],
+                     input_shapes: Dict[str, Dict],
+                     compute_precision: str = 'FLOAT32',
+                     convert_to: str = 'neuralnetwork',
+                     minimum_deployment_target: Optional[str] = None,
+                     skip_model_load: bool = False):
     """Create a coreml engine from torchscript.

     Args:
-        model_id (int): Index of input model.
         torchscript_model (Union[str, torch.jit.RecursiveScriptModule]):
             The torchscript model to be converted.
         output_file_prefix (str): The output file prefix.
-        deploy_cfg (str | mmcv.Config): Deployment config.
-        backend_files (List[str]):
-            Backend files used by deployment for testing pipeline
     """
     try:
         from mmdeploy.backend.torchscript import get_ops_path
         torch.ops.load_library(get_ops_path())
@@ -80,40 +75,22 @@ def from_torchscript(model_id: int,
     if isinstance(torchscript_model, str):
         torchscript_model = torch.jit.load(torchscript_model)

-    deploy_cfg = load_config(deploy_cfg)[0]
-
-    common_params = get_common_config(deploy_cfg)
-    model_params = get_model_inputs(deploy_cfg)[model_id]
-
-    final_params = common_params
-    final_params.update(model_params)
-
-    ir_config = get_ir_config(deploy_cfg)
-    input_names = ir_config.get('input_names', [])
-    input_shapes = final_params['input_shapes']
-
     inputs = []
     for name in input_names:
         shape = create_shape(name, input_shapes[name])
         inputs.append(shape)

-    output_names = ir_config.get('output_names', [])
     outputs = []
     for name in output_names:
         outputs.append(ct.TensorType(name=name))

-    convert_to = deploy_cfg.backend_config.convert_to
     if convert_to == 'neuralnetwork':
         # Compute precision must be None for neuralnetwork conversion
         compute_precision = None
     else:
-        compute_precision = ct.precision[final_params.get(
-            'compute_precision', 'FLOAT32')]
-
-    minimum_deployment_target = final_params.get('minimum_deployment_target',
-                                                 None)
+        compute_precision = ct.precision[compute_precision]

     mlmodel = ct.convert(
         model=torchscript_model,
@@ -123,9 +100,8 @@ def from_torchscript(model_id: int,
         convert_to=convert_to,
         minimum_deployment_target=ct.target[minimum_deployment_target]
         if minimum_deployment_target else None,
-        skip_model_load=final_params.get('skip_model_load', False))
+        skip_model_load=skip_model_load)

     suffix = get_model_suffix(convert_to)
     output_path = output_file_prefix + suffix
-    backend_files.append(output_path)
     mlmodel.save(output_path)
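
The min/max/default triples reach Core ML through `create_shape`, whose body is not shown in this diff. The mapping it performs can be sketched with coremltools' flexible-shape API; this is an illustration of the idea, not the exact implementation:

import coremltools as ct

min_shape = [1, 3, 32, 32]
max_shape = [1, 3, 32, 640]
default_shape = [1, 3, 32, 64]

dims = []
for low, high, default in zip(min_shape, max_shape, default_shape):
    if low == high:
        dims.append(low)  # static axis
    else:
        # A flexible axis: RangeDim carries the bounds and the default.
        dims.append(
            ct.RangeDim(lower_bound=low, upper_bound=high, default=default))

# Comparable to what create_shape returns for one named input.
input_type = ct.TensorType(name='input', shape=ct.Shape(shape=dims))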


@@ -131,6 +131,7 @@ def base_dense_head__get_bbox(ctx,
         else:
             max_scores, _ = nms_pre_score[..., :-1].max(-1)
             _, topk_inds = max_scores.topk(pre_topk)
+
         bbox_pred, scores, score_factors = gather_topk(
             bbox_pred,
             scores,


@@ -68,3 +68,27 @@ def topk__tensorrt(ctx,
         k = TENSORRT_MAX_TOPK

     return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)
+
+
+@FUNCTION_REWRITER.register_rewriter(func_name='torch.topk', backend='coreml')
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='torch.Tensor.topk', backend='coreml')
+def topk__coreml(ctx,
+                 input: torch.Tensor,
+                 k: int,
+                 dim: Optional[int] = None,
+                 largest: bool = True,
+                 sorted: bool = True):
+    """Rewrite `topk` for coreml backend.
+
+    Cast k to tensor and make sure k is smaller than input.shape[dim].
+    """
+    if dim is None:
+        dim = int(input.ndim - 1)
+    size = input.shape[dim]
+    if not isinstance(k, torch.Tensor):
+        k = torch.tensor(k, device=input.device, dtype=torch.long)
+    # Always keep topk op for dynamic input
+    k = torch.where(k < size, k, size)
+    return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)
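
A small eager-mode sketch of why the clamp is needed: `torch.topk` raises when k exceeds the axis size, and keeping k as a tensor (rather than folding the minimum into a Python int) preserves the topk op in the traced graph, which is what the "Always keep topk op for dynamic input" comment refers to:

import torch

x = torch.rand(1, 5)
k = torch.tensor(10, dtype=torch.long)  # larger than x.shape[-1]

# torch.topk(x, 10) would raise 'selected index k out of range' in eager
# mode; the rewriter clamps k to the axis size first.
size = x.shape[-1]
k = torch.where(k < size, k, torch.tensor(size))
values, indices = x.topk(int(k), dim=-1)  # int() only needed in eager mode
assert values.shape == (1, 5)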


@@ -67,7 +67,7 @@ def generate_torchscript_file():
         context_info=context_info)


-def onnx2backend(backend, onnx_file):
+def ir2backend(backend, onnx_file, ts_file):
     if backend == Backend.TENSORRT:
         from mmdeploy.backend.tensorrt import from_onnx
         backend_file = tempfile.NamedTemporaryFile(suffix='.engine').name
@@ -143,6 +143,34 @@ def onnx2backend(backend, onnx_file):
             onnx_file, lib_file, shape=shape, dtype=dtype, tuner=tuner_dict)
         assert osp.exists(lib_file)
         return lib_file
+    elif backend == Backend.TORCHSCRIPT:
+        return ts_file
+    elif backend == Backend.COREML:
+        output_names = ['output']
+        from mmdeploy.backend.coreml.torchscript2coreml import (
+            from_torchscript, get_model_suffix)
+        backend_dir = tempfile.TemporaryDirectory().name
+        work_dir = backend_dir
+        torchscript_name = osp.splitext(osp.split(ts_file)[1])[0]
+        output_file_prefix = osp.join(work_dir, torchscript_name)
+        convert_to = 'mlprogram'
+        from_torchscript(
+            ts_file,
+            output_file_prefix,
+            input_names=input_names,
+            output_names=output_names,
+            input_shapes=dict(
+                input=dict(
+                    min_shape=[1, 3, 8, 8],
+                    default_shape=[1, 3, 8, 8],
+                    max_shape=[1, 3, 8, 8])),
+            convert_to=convert_to)
+        suffix = get_model_suffix(convert_to)
+        return output_file_prefix + suffix
     else:
         raise NotImplementedError(
             f'Convert for {backend.value} has not been implemented.')


 def create_wrapper(backend, model_files):
@@ -186,10 +214,7 @@ ALL_BACKEND.remove(Backend.SDK)
 @pytest.mark.parametrize('backend', ALL_BACKEND)
 def test_wrapper(backend):
     check_backend(backend, True)
-    if backend == Backend.TORCHSCRIPT:
-        model_files = ts_file
-    else:
-        model_files = onnx2backend(backend, onnx_file)
+    model_files = ir2backend(backend, onnx_file, ts_file)
     assert model_files is not None
     wrapper = create_wrapper(backend, model_files)
     assert wrapper is not None