mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* mmcv.Config -> mmengine Config * support mmedit part * add rewriter for BaseEditModels * fix visualizer * mmedit visualization * remove unused code * fix realesrgan * fix trt * support MultiTestLoop; rewriter fix mmediting bugs; fix ut * fix uts * fix mmedit sdk * fix regression test(part) * fix torchscript * part of fix regression test * fix checkenv.py * fix test.py for mmedit2.0 * support for mmedit * fix regression_test * fix check copyright ci * fix isort * fix docformatter * fix yapf * fix tests * fix sdk after 1040 * add a file for ut * fix docformatter * fix export info * fix super_resolution * fix test.py * stage configs * remove unused code * remove rewriter of multitestloop * fix yapf
121 lines
3.6 KiB
Python
121 lines
3.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
import tempfile
|
|
from typing import Dict
|
|
|
|
import mmengine
|
|
import onnx
|
|
import torch
|
|
from mmedit.models.editors.srcnn import SRCNNNet
|
|
|
|
from mmdeploy.codebase import import_codebase
|
|
from mmdeploy.core import RewriterContext
|
|
from mmdeploy.utils import Backend, Codebase, get_onnx_config
|
|
|
|
import_codebase(Codebase.MMEDIT)

# Random tensor of shape 1x3x4x4 shared by the tests in this module.
img = torch.rand(1, 3, 4, 4)
# Only the path string is kept; the temporary file object itself is
# discarded immediately after its name is read.
model_file = tempfile.NamedTemporaryFile(suffix='.onnx').name

# Deployment config: mmedit SuperResolution task, ONNX IR (opset 11) and a
# TensorRT backend with one fixed 1x3x4x4 profile for the 'input' tensor.
deploy_cfg = mmengine.Config({
    'codebase_config': {
        'type': 'mmedit',
        'task': 'SuperResolution',
    },
    'backend_config': {
        'type': 'tensorrt',
        'common_config': {
            'fp16_mode': False,
            'max_workspace_size': 1 << 10,
        },
        'model_inputs': [{
            'input_shapes': {
                'input': {
                    'min_shape': [1, 3, 4, 4],
                    'opt_shape': [1, 3, 4, 4],
                    'max_shape': [1, 3, 4, 4],
                },
            },
        }],
    },
    'ir_config': {
        'type': 'onnx',
        'export_params': True,
        'keep_initializers_as_inputs': False,
        'opset_version': 11,
        'save_file': model_file,
        'input_shape': None,
        'input_names': ['input'],
        'output_names': ['output'],
    },
})
|
|
|
|
|
|
def test_base_edit_model_forward():
    """Check that a ``BaseEditModel`` subclass returns its inputs unchanged.

    A minimal subclass whose ``forward`` echoes ``inputs`` is built around a
    small SRCNN generator. It is called once in ``predict`` mode and once
    inside an empty-config ``RewriterContext``. Both calls must return the
    module-level ``img`` tensor.
    """
    from typing import List, Optional

    from mmedit.models.base_models.base_edit_model import BaseEditModel
    from mmedit.structures import EditDataSample

    from mmdeploy.codebase.mmedit import models  # noqa

    class DummyBaseEditModel(BaseEditModel):

        def __init__(self, generator, pixel_loss):
            super().__init__(generator, pixel_loss)

        def forward(self,
                    inputs: torch.Tensor,
                    data_samples: Optional[List[EditDataSample]] = None,
                    mode: str = 'tensor',
                    **kwargs):
            # Echo the inputs so the test can check they pass through.
            return inputs

    generator = dict(
        type='SRCNNNet',
        channels=(3, 64, 32, 3),
        kernel_sizes=(9, 1, 5),
        upscale_factor=4)
    pixel_loss = dict(type='L1Loss', loss_weight=1.0, reduction='mean')
    model = DummyBaseEditModel(generator, pixel_loss).eval()

    # Feed the real image tensor. Passing the builtin ``input`` function
    # would make the assertions below trivially true.
    model_output = model(img, None, mode='predict')

    with RewriterContext({}):
        backend_output = model(img)

    # Tensor ``==`` is element-wise and cannot be used directly in
    # ``assert``; ``torch.equal`` compares shape and values.
    assert torch.equal(model_output, img)
    assert torch.equal(backend_output, img)
|
|
|
|
|
|
def test_srcnn():
    """Export ``SRCNNNet`` to ONNX under the TensorRT rewriter context.

    The exported graph differs from the eager model because of the rewrite,
    so outputs are not compared. The test only checks that the ONNX file is
    produced, loads, and passes ``onnx.checker.check_model``.
    """
    pytorch_model = SRCNNNet()
    model_inputs = {'x': img}

    onnx_cfg = get_onnx_config(deploy_cfg)
    input_names = [k for k in model_inputs if k != 'ctx']

    dynamic_axes = onnx_cfg.get('dynamic_axes', None)
    # A non-dict ``dynamic_axes`` is paired positionally with the input
    # names. Materialize the pairs as a dict: a bare ``zip`` iterator is not
    # an accepted ``dynamic_axes`` value for ``torch.onnx.export``.
    if dynamic_axes is not None and not isinstance(dynamic_axes, Dict):
        dynamic_axes = dict(zip(input_names, dynamic_axes))

    # Export into a private temporary directory so the path is not reused
    # from a deleted temp file and the output is removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_file_path = osp.join(tmp_dir, 'srcnn.onnx')
        with RewriterContext(
                cfg=deploy_cfg,
                backend=Backend.TENSORRT.value), torch.no_grad():
            torch.onnx.export(
                pytorch_model,
                tuple(model_inputs.values()),
                onnx_file_path,
                export_params=True,
                input_names=input_names,
                output_names=None,
                opset_version=11,
                dynamic_axes=dynamic_axes,
                keep_initializers_as_inputs=False)

        # The result should be different due to the rewrite,
        # so we only check that the file exists and is a valid model.
        assert osp.exists(onnx_file_path)

        model = onnx.load(onnx_file_path)
        assert model is not None
        # Raises onnx.checker.ValidationError (failing the test with the
        # checker's own message) if the exported graph is invalid.
        onnx.checker.check_model(model)
|