mmdeploy/tests/test_codebase/test_mmedit/test_mmedit_models.py

89 lines
2.6 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from typing import Dict
import mmcv
import onnx
import pytest
import torch
from mmdeploy.codebase import import_codebase
from mmdeploy.core import RewriterContext
from mmdeploy.utils import Backend, Codebase, get_onnx_config
# Load mmdeploy's MMEDIT codebase support before any test runs. If that
# raises ImportError (the skip message assumes mmedit is not installed),
# skip every test in this module instead of erroring at collection time.
try:
    import_codebase(Codebase.MMEDIT)
except ImportError:
    _skip_reason = f'{Codebase.MMEDIT} is not installed.'
    pytest.skip(_skip_reason, allow_module_level=True)
# Random input tensor of shape (1, 3, 4, 4) fed to the model under test.
img = torch.rand(1, 3, 4, 4)
# Only the generated temporary file *name* is kept (the NamedTemporaryFile
# object itself is discarded); it serves as the ONNX save path below.
model_file = tempfile.NamedTemporaryFile(suffix='.onnx').name

# Deploy config: mmedit SuperResolution task on the TensorRT backend with a
# single static input profile (min == opt == max == 1x3x4x4), exported via
# ONNX opset 11 with one input named 'input' and one output named 'output'.
deploy_cfg = mmcv.Config({
    'codebase_config': {
        'type': 'mmedit',
        'task': 'SuperResolution',
    },
    'backend_config': {
        'type': 'tensorrt',
        'common_config': {
            'fp16_mode': False,
            'max_workspace_size': 1 << 10,
        },
        'model_inputs': [{
            'input_shapes': {
                'input': {
                    'min_shape': [1, 3, 4, 4],
                    'opt_shape': [1, 3, 4, 4],
                    'max_shape': [1, 3, 4, 4],
                },
            },
        }],
    },
    'ir_config': {
        'type': 'onnx',
        'export_params': True,
        'keep_initializers_as_inputs': False,
        'opset_version': 11,
        'save_file': model_file,
        'input_shape': None,
        'input_names': ['input'],
        'output_names': ['output'],
    },
})
def test_srcnn():
    """Export mmedit's SRCNN backbone to ONNX inside mmdeploy's rewriter context.

    The export runs under ``RewriterContext`` configured for the TensorRT
    backend. The test checks that an ONNX file is written, that it loads,
    and that it passes ``onnx.checker.check_model``.
    """
    from mmedit.models.backbones.sr_backbones import SRCNN
    pytorch_model = SRCNN()
    model_inputs = {'x': img}

    onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name
    onnx_cfg = get_onnx_config(deploy_cfg)
    # Inputs keyed 'ctx' are not given an ONNX input name.
    input_names = [name for name in model_inputs if name != 'ctx']
    dynamic_axes = onnx_cfg.get('dynamic_axes', None)
    if dynamic_axes is not None and not isinstance(dynamic_axes, dict):
        # A sequence of per-input axis specs: pair each with its input name.
        # torch.onnx.export expects a mapping here, so materialize a dict
        # rather than passing a one-shot zip iterator.
        dynamic_axes = dict(zip(input_names, dynamic_axes))

    with RewriterContext(
            cfg=deploy_cfg, backend=Backend.TENSORRT.value), torch.no_grad():
        torch.onnx.export(
            pytorch_model,
            tuple(model_inputs.values()),
            onnx_file_path,
            export_params=True,
            input_names=input_names,
            output_names=None,
            opset_version=11,
            dynamic_axes=dynamic_axes,
            keep_initializers_as_inputs=False)

    # The rewritten graph is expected to differ from the plain PyTorch export,
    # so outputs are not compared; only the produced file is validated.
    assert osp.exists(onnx_file_path)
    model = onnx.load(onnx_file_path)
    assert model is not None
    # check_model raises onnx.checker.ValidationError on an invalid model.
    # Letting it propagate makes pytest report the checker's message instead
    # of a bare `assert False`.
    onnx.checker.check_model(model)