mmdeploy/tests/test_codebase/test_mmedit/test_mmedit_models.py
hanrui1sensetime 5c87dd9565
[2.0] Support mmedit 2.0 (#1017)
* mmcv.Config -> mmengine Config

* support mmedit part

* add rewriter for BaseEditModels

* fix visualizer

* mmedit visualization

* remove unused code

* fix realesrgan

* fix trt

* support MultiTestLoop; rewriter fix mmediting bugs; fix ut

* fix uts

* fix mmedit sdk

* fix regression test(part)

* fix torchscript

* part of fix regression test

* fix checkenv.py

* fix test.py for mmedit2.0

* support for mmedit

* fix regression_test

* fix check copyright ci

* fix isort

* fix docformatter

* fix yapf

* fix tests

* fix sdk after 1040

* add a file for ut

* fix docformatter

* fix export info

* fix super_resolution

* fix test.py

* stage configs

* remove unused code

* remove rewriter of multitestloop

* fix yapf
2022-09-20 19:22:55 +08:00

121 lines
3.6 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from typing import Dict
import mmengine
import onnx
import torch
from mmedit.models.editors.srcnn import SRCNNNet
from mmdeploy.codebase import import_codebase
from mmdeploy.core import RewriterContext
from mmdeploy.utils import Backend, Codebase, get_onnx_config
# Load mmdeploy's mmedit codebase support before the tests below run.
import_codebase(Codebase.MMEDIT)

# Shared random input tensor; its shape matches the min/opt/max shapes
# declared in ``deploy_cfg`` below.
img = torch.rand(1, 3, 4, 4)

# Only the generated path is kept; the temporary file object itself is not
# retained.
model_file = tempfile.NamedTemporaryFile(suffix='.onnx').name

# Deployment config shared by the tests: mmedit super-resolution codebase,
# TensorRT backend with a fixed input shape, ONNX (opset 11) as the IR.
deploy_cfg = mmengine.Config({
    'codebase_config': {
        'type': 'mmedit',
        'task': 'SuperResolution',
    },
    'backend_config': {
        'type': 'tensorrt',
        'common_config': {
            'fp16_mode': False,
            'max_workspace_size': 1 << 10,
        },
        'model_inputs': [{
            'input_shapes': {
                'input': {
                    'min_shape': [1, 3, 4, 4],
                    'opt_shape': [1, 3, 4, 4],
                    'max_shape': [1, 3, 4, 4],
                },
            },
        }],
    },
    'ir_config': {
        'type': 'onnx',
        'export_params': True,
        'keep_initializers_as_inputs': False,
        'opset_version': 11,
        'save_file': model_file,
        'input_shape': None,
        'input_names': ['input'],
        'output_names': ['output'],
    },
})
def test_base_edit_model_forward():
    """Check that a ``BaseEditModel`` subclass can be called with a tensor.

    A dummy subclass whose ``forward`` returns its input unchanged is called
    once directly and once inside an empty ``RewriterContext``; in both cases
    the output must equal the input tensor.
    """
    from typing import List, Optional

    from mmedit.models.base_models.base_edit_model import BaseEditModel
    from mmedit.structures import EditDataSample

    from mmdeploy.codebase.mmedit import models  # noqa

    class DummyBaseEditModel(BaseEditModel):
        """Minimal model whose ``forward`` echoes ``inputs``."""

        def __init__(self, generator, pixel_loss):
            super().__init__(generator, pixel_loss)

        def forward(self,
                    inputs: torch.Tensor,
                    data_samples: Optional[List[EditDataSample]] = None,
                    mode: str = 'tensor',
                    **kwargs):
            return inputs

    generator = dict(
        type='SRCNNNet',
        channels=(3, 64, 32, 3),
        kernel_sizes=(9, 1, 5),
        upscale_factor=4)
    pixel_loss = dict(type='L1Loss', loss_weight=1.0, reduction='mean')
    model = DummyBaseEditModel(generator, pixel_loss).eval()

    # Feed the shared tensor rather than the builtin ``input`` function, so
    # the model is actually exercised with tensor data.
    model_output = model(img, None, mode='predict')

    with RewriterContext({}):
        backend_output = model(img)

    # ``==`` on a multi-element tensor yields a tensor that cannot be used
    # as an ``assert`` condition, so compare with ``torch.equal``.
    assert torch.equal(model_output, img)
    assert torch.equal(backend_output, img)
def test_srcnn():
    """Export ``SRCNNNet`` to ONNX inside the TensorRT rewriter context.

    The numerical result may differ from the plain PyTorch model because of
    the rewrite, so the test only checks that an ONNX file is produced, that
    it can be loaded, and that it passes ``onnx.checker.check_model``.
    """
    pytorch_model = SRCNNNet()
    model_inputs = {'x': img}
    onnx_cfg = get_onnx_config(deploy_cfg)
    input_names = [name for name in model_inputs if name != 'ctx']

    dynamic_axes = onnx_cfg.get('dynamic_axes', None)
    if dynamic_axes is not None and not isinstance(dynamic_axes, Dict):
        # ``torch.onnx.export`` expects a mapping from input name to axes,
        # not a one-shot ``zip`` iterator.
        dynamic_axes = dict(zip(input_names, dynamic_axes))

    # Export into a temporary directory so the path stays reserved for the
    # duration of the test and the file is removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_file_path = osp.join(tmp_dir, 'srcnn.onnx')

        with RewriterContext(
                cfg=deploy_cfg,
                backend=Backend.TENSORRT.value), torch.no_grad():
            torch.onnx.export(
                pytorch_model,
                tuple(model_inputs.values()),
                onnx_file_path,
                export_params=True,
                input_names=input_names,
                output_names=None,
                opset_version=11,
                dynamic_axes=dynamic_axes,
                keep_initializers_as_inputs=False)

        # The result should be different due to the rewrite.
        # So we only check if the file exists and holds a valid model.
        assert osp.exists(onnx_file_path)

        model = onnx.load(onnx_file_path)
        assert model is not None

        # Let a ValidationError propagate so the failure report carries the
        # checker's own diagnostic instead of a bare ``assert False``.
        onnx.checker.check_model(model)