89 lines
2.6 KiB
Python
89 lines
2.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
import tempfile
|
|
from typing import Dict
|
|
|
|
import mmcv
|
|
import onnx
|
|
import pytest
|
|
import torch
|
|
|
|
from mmdeploy.codebase import import_codebase
|
|
from mmdeploy.core import RewriterContext
|
|
from mmdeploy.utils import Backend, Codebase, get_onnx_config
|
|
|
|
# Every test in this module relies on the MMEditing codebase. If it cannot be
# imported, skip the whole module rather than running tests that need it.
try:
    import_codebase(Codebase.MMEDIT)
except ImportError:
    pytest.skip(
        f'{Codebase.MMEDIT} is not installed.',
        allow_module_level=True,
    )
|
|
|
|
# Random input tensor of shape (1, 3, 4, 4) shared by the tests below.
img = torch.rand(1, 3, 4, 4)

# Only the generated ``.onnx`` path is kept; the temporary file object itself
# is not retained.
model_file = tempfile.NamedTemporaryFile(suffix='.onnx').name

# Deployment config: mmedit super-resolution, TensorRT backend with a single
# fixed 1x3x4x4 input shape, exported through ONNX opset 11.
deploy_cfg = mmcv.Config({
    'codebase_config': {
        'type': 'mmedit',
        'task': 'SuperResolution',
    },
    'backend_config': {
        'type': 'tensorrt',
        'common_config': {
            'fp16_mode': False,
            'max_workspace_size': 1 << 10,
        },
        'model_inputs': [
            {
                'input_shapes': {
                    'input': {
                        'min_shape': [1, 3, 4, 4],
                        'opt_shape': [1, 3, 4, 4],
                        'max_shape': [1, 3, 4, 4],
                    },
                },
            },
        ],
    },
    'ir_config': {
        'type': 'onnx',
        'export_params': True,
        'keep_initializers_as_inputs': False,
        'opset_version': 11,
        'save_file': model_file,
        'input_shape': None,
        'input_names': ['input'],
        'output_names': ['output'],
    },
})
|
|
|
|
|
|
def test_srcnn():
    """Export SRCNN to ONNX inside the TensorRT rewriter context.

    The exported result is expected to differ from the plain PyTorch model
    because of the rewrite, so no numerical comparison is made. The test only
    verifies that an ONNX file is written, that it can be loaded, and that it
    passes ``onnx.checker.check_model``.
    """
    from mmedit.models.backbones.sr_backbones import SRCNN

    pytorch_model = SRCNN()
    model_inputs = {'x': img}

    onnx_cfg = get_onnx_config(deploy_cfg)
    input_names = [name for name in model_inputs if name != 'ctx']

    dynamic_axes = onnx_cfg.get('dynamic_axes', None)
    if dynamic_axes is not None and not isinstance(dynamic_axes, Dict):
        # ``torch.onnx.export`` expects a mapping from input name to axes; a
        # bare ``zip`` object is a one-shot iterator, not a dict.
        dynamic_axes = dict(zip(input_names, dynamic_axes))

    # Export into a temporary directory so the generated file is always
    # removed afterwards, instead of reusing the name of an already-discarded
    # NamedTemporaryFile and leaving the export behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_file_path = osp.join(tmp_dir, 'srcnn.onnx')

        with RewriterContext(
                cfg=deploy_cfg,
                backend=Backend.TENSORRT.value), torch.no_grad():
            torch.onnx.export(
                pytorch_model,
                tuple(model_inputs.values()),
                onnx_file_path,
                export_params=True,
                input_names=input_names,
                output_names=None,
                opset_version=11,
                dynamic_axes=dynamic_axes,
                keep_initializers_as_inputs=False)

        # The result should be different due to the rewrite.
        # So we only check if the file exists
        assert osp.exists(onnx_file_path)

        model = onnx.load(onnx_file_path)
        assert model is not None
        # Let a ValidationError propagate: the test still fails, but with the
        # checker's own message instead of a bare ``assert False``.
        onnx.checker.check_model(model)
|