mmdeploy/tests/test_codebase/test_mmedit/test_mmedit_models.py

125 lines
3.7 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from typing import Dict, List, Optional
import mmengine
import onnx
import pytest
import torch
from mmdeploy.codebase import import_codebase
from mmdeploy.core import RewriterContext
from mmdeploy.utils import Backend, Codebase, get_onnx_config
# mmedit is an optional codebase: when it cannot be imported, skip every
# test in this module at collection time instead of erroring out.
try:
    import_codebase(Codebase.MMEDIT)
except ImportError:
    _skip_reason = f'{Codebase.MMEDIT} is not installed.'
    pytest.skip(_skip_reason, allow_module_level=True)
# Random image tensor of shape (1, 3, 4, 4) shared by the tests below.
img = torch.rand(1, 3, 4, 4)
# Only the generated temporary name is kept; the NamedTemporaryFile object
# itself is discarded immediately.
model_file = tempfile.NamedTemporaryFile(suffix='.onnx').name

# Deployment config: mmedit SuperResolution exported to ONNX (opset 11) for
# a TensorRT backend whose min/opt/max input shapes are all [1, 3, 4, 4].
deploy_cfg = mmengine.Config({
    'codebase_config': {
        'type': 'mmedit',
        'task': 'SuperResolution',
    },
    'backend_config': {
        'type': 'tensorrt',
        'common_config': {
            'fp16_mode': False,
            'max_workspace_size': 1 << 10,
        },
        'model_inputs': [{
            'input_shapes': {
                'input': {
                    'min_shape': [1, 3, 4, 4],
                    'opt_shape': [1, 3, 4, 4],
                    'max_shape': [1, 3, 4, 4],
                },
            },
        }],
    },
    'ir_config': {
        'type': 'onnx',
        'export_params': True,
        'keep_initializers_as_inputs': False,
        'opset_version': 11,
        'save_file': model_file,
        'input_shape': None,
        'input_names': ['input'],
        'output_names': ['output'],
    },
})
def test_base_edit_model_forward():
    """Check that a BaseEditModel subclass with an identity forward is
    transparent both eagerly and inside an empty ``RewriterContext``.

    ``DummyBaseEditModel.forward`` returns its ``inputs`` argument, so both
    calls must hand back the module-level ``img`` tensor unchanged.
    """
    from mmedit.models.base_models.base_edit_model import BaseEditModel
    from mmedit.structures import EditDataSample

    from mmdeploy.codebase.mmedit import models  # noqa

    class DummyBaseEditModel(BaseEditModel):

        def __init__(self, generator, pixel_loss):
            super().__init__(generator, pixel_loss)

        def forward(self,
                    inputs: torch.Tensor,
                    data_samples: Optional[List[EditDataSample]] = None,
                    mode: str = 'tensor',
                    **kwargs):
            # Identity forward: return the input untouched.
            return inputs

    generator = dict(
        type='SRCNNNet',
        channels=(3, 64, 32, 3),
        kernel_sizes=(9, 1, 5),
        upscale_factor=4)
    pixel_loss = dict(type='L1Loss', loss_weight=1.0, reduction='mean')
    model = DummyBaseEditModel(generator, pixel_loss).eval()

    # Feed the module-level ``img`` tensor. Passing the builtin ``input``
    # function here would make the assertions below trivially true.
    model_output = model(img, None, mode='predict')
    with RewriterContext({}):
        backend_output = model(img)

    # ``==`` on tensors is elementwise and cannot be asserted directly for
    # multi-element tensors, so compare with ``torch.equal``.
    assert torch.equal(model_output, img)
    assert torch.equal(backend_output, img)
def test_srcnn():
    """Export ``SRCNNNet`` to ONNX under the TensorRT rewriter context.

    The result is expected to differ from the eager model because of the
    rewrite, so the test only checks that an ONNX file is written and that
    it passes ``onnx.checker.check_model``.
    """
    from mmedit.models.editors.srcnn import SRCNNNet
    pytorch_model = SRCNNNet()
    model_inputs = {'x': img}

    onnx_cfg = get_onnx_config(deploy_cfg)
    input_names = [name for name in model_inputs if name != 'ctx']
    dynamic_axes = onnx_cfg.get('dynamic_axes', None)
    if dynamic_axes is not None and not isinstance(dynamic_axes, Dict):
        # torch.onnx.export expects a mapping, not a lazy zip iterator;
        # pair list-style axes with the input names positionally.
        dynamic_axes = dict(zip(input_names, dynamic_axes))

    # Export into a temporary directory so the .onnx file is removed
    # afterwards instead of being leaked in the system temp dir.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_file_path = osp.join(tmp_dir, 'srcnn.onnx')
        with RewriterContext(
                cfg=deploy_cfg,
                backend=Backend.TENSORRT.value), torch.no_grad():
            torch.onnx.export(
                pytorch_model,
                tuple(model_inputs.values()),
                onnx_file_path,
                export_params=True,
                input_names=input_names,
                output_names=None,
                opset_version=11,
                dynamic_axes=dynamic_axes,
                keep_initializers_as_inputs=False)

        # The result should be different due to the rewrite, so only check
        # that a well-formed ONNX model was produced.
        assert osp.exists(onnx_file_path)
        model = onnx.load(onnx_file_path)
        assert model is not None
        # Let onnx.checker.ValidationError propagate so pytest shows the
        # checker's message instead of a bare ``assert False``.
        onnx.checker.check_model(model)