mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* [WIP] Refactor v2.0 (#163) * Refactor backend wrapper * Refactor mmdet.inference * Fix * merge * refactor utils * Use deployer and deploy_model to manage pipeline * Resolve comments * Add a real inference api function * rename wrappers * Set execute to private method * Rename deployer deploy_model * Refactor task * remove type hint * lint * Resolve comments * resolve comments * lint * docstring * [Fix]: Fix bugs in details in refactor branch (#192) * [WIP] Refactor v2.0 (#163) * Refactor backend wrapper * Refactor mmdet.inference * Fix * merge * refactor utils * Use deployer and deploy_model to manage pipeline * Resolve comments * Add a real inference api function * rename wrappers * Set execute to private method * Rename deployer deploy_model * Refactor task * remove type hint * lint * Resolve comments * resolve comments * lint * docstring * Fix errors * lint * resolve comments * fix bugs * conflict * lint and typo * Resolve comment * refactor mmseg (#201) * support mmseg * fix docstring * fix docstring * [Refactor]: Get the count of backend files (#202) * Fix backend files * resolve comments * lint * Fix ncnn * [Refactor]: Refactor folders of mmdet (#200) * Move folders * lint * test object detection model * lint * reset changes * fix openvino * resolve comments * __init__.py * Fix path * [Refactor]: move mmseg (#206) * [Refactor]: Refactor mmedit (#205) * feature mmedit * edit2.0 * edit * refactor mmedit * fix __init__.py * fix __init__ * fix formai * fix comment * fix comment * Fix wrong func_name of ConvFCBBoxHead (#209) * [Refactor]: Refactor mmdet unit test (#207) * Move folders * lint * test object detection model * lint * WIP * remove print * finish unit test * Fix tests * resolve comments * Add mask test * lint * resolve comments * Refine cfg file * Move files * add files * Fix path * [Unittest]: Refine the unit tests in mmdet #214 * [Refactor] refactor mmocr to mmdeploy/codebase (#213) * refactor mmocr to mmdeploy/codebase * fix docstring of 
show_result * fix docstring of visualize * refine docstring * replace print with logging * refine code * resolve comments * resolve comments * [Refactor]: mmseg tests (#210) * refactor mmseg tests * rename test_codebase * update * add model.py * fix * [Refactor] Refactor mmcls and the package (#217) * refactor mmcls * fix yapf * fix isort * refactor-mmcls-package * fix print to logging * fix docstrings according to others' comments * fix comments * fix comments * fix AllentDan's comment in PR #215 * remove mmocr init * [Refactor] Refactor mmedit tests (#212) * feature mmedit * edit2.0 * edit * refactor mmedit * fix __init__.py * fix __init__ * fix formai * fix comment * fix comment * buff * edit test and code refactor * refactor dir * refactor tests/mmedit * fix docstring * add test coverage * fix lint * fix comment * fix comment * Update typehint (#216) * update type hint * update docstring * update * remove file * fix ppl * Refine get_predefined_partition_cfg * fix tensorrt version > 8 * move parse_cuda_device_id to device.py * Fix cascade * onnx2ncnn docstring Co-authored-by: Yifan Zhou <singlezombie@163.com> Co-authored-by: RunningLeon <maningsheng@sensetime.com> Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com> Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com> Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
import os.path as osp
|
|
import tempfile
|
|
|
|
import mmcv
|
|
import onnx
|
|
import torch
|
|
from mmedit.models.backbones.sr_backbones import SRCNN
|
|
|
|
from mmdeploy.core import RewriterContext
|
|
from mmdeploy.utils import Backend, get_onnx_config
|
|
|
|
# Random input image tensor of shape (1, 3, 4, 4) used as the model input.
img = torch.rand(1, 3, 4, 4)
# Only the generated name is kept; the NamedTemporaryFile object itself is
# not held on to.
model_file = tempfile.NamedTemporaryFile(suffix='.onnx').name

# Deployment config: mmedit super-resolution task, TensorRT backend with a
# single static input profile (min == opt == max), and ONNX export settings.
deploy_cfg = mmcv.Config({
    'codebase_config': {
        'type': 'mmedit',
        'task': 'SuperResolution',
    },
    'backend_config': {
        'type': 'tensorrt',
        'common_config': {
            'fp16_mode': False,
            'max_workspace_size': 1 << 10,
        },
        'model_inputs': [{
            'input_shapes': {
                'input': {
                    'min_shape': [1, 3, 4, 4],
                    'opt_shape': [1, 3, 4, 4],
                    'max_shape': [1, 3, 4, 4],
                },
            },
        }],
    },
    'onnx_config': {
        'type': 'onnx',
        'export_params': True,
        'keep_initializers_as_inputs': False,
        'opset_version': 11,
        'save_file': model_file,
        'input_shape': None,
        'input_names': ['input'],
        'output_names': ['output'],
    },
})
|
|
|
|
|
|
def test_srcnn():
    """Export an mmedit ``SRCNN`` to ONNX under the TensorRT rewriter
    context and check that the exported model is valid ONNX.

    The rewrites make the exported graph differ from a plain PyTorch
    export, so outputs are not compared; the test only checks that the
    file was written and passes ``onnx.checker.check_model``.
    """
    pytorch_model = SRCNN()
    model_inputs = {'x': img}

    pytorch2onnx_cfg = get_onnx_config(deploy_cfg)
    # Every input key except the rewriter's 'ctx' becomes an ONNX input name.
    input_names = [k for k in model_inputs if k != 'ctx']

    # Export into a temporary directory so the .onnx file is removed when
    # the test finishes instead of being left behind on disk.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_file_path = osp.join(tmp_dir, 'srcnn.onnx')

        with RewriterContext(
                cfg=deploy_cfg,
                backend=Backend.TENSORRT.value), torch.no_grad():
            torch.onnx.export(
                pytorch_model,
                tuple(model_inputs.values()),
                onnx_file_path,
                export_params=True,
                input_names=input_names,
                output_names=None,
                opset_version=11,
                dynamic_axes=pytorch2onnx_cfg.get('dynamic_axes', None),
                keep_initializers_as_inputs=False)

        # The result should be different due to the rewrite,
        # so we only check that the file exists and is valid ONNX.
        assert osp.exists(onnx_file_path)

        model = onnx.load(onnx_file_path)
        assert model is not None
        # Let onnx.checker.ValidationError propagate so the test failure
        # carries the checker's message (a bare `assert False` hid it and
        # is stripped under `python -O`).
        onnx.checker.check_model(model)
|