1
0
mirror of https://github.com/open-mmlab/mmdeploy.git synced 2025-01-14 08:09:43 +08:00
mmdeploy/tests/test_apis/test_torch2torchscript.py
q.yao 91a0a9af0b
[Fix] Fix unit test in dev1.x ()
* fix erros in ut

* use nms_ratated ops instead of ext

* fix ut

* fix superresolution

* update super_resolution
2023-03-08 14:32:50 +08:00

53 lines
1.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import os.path as osp
import tempfile
import pytest
from mmengine import Config
from mmdeploy.apis import torch2torchscript
from mmdeploy.utils import IR, Backend
from mmdeploy.utils.test import get_random_name
ts_file = tempfile.NamedTemporaryFile(suffix='.pt').name
input_name = get_random_name()
output_name = get_random_name()
def get_deploy_cfg(input_name, output_name):
return Config(
dict(
ir_config=dict(
type=IR.TORCHSCRIPT.value,
input_names=[input_name],
output_names=[output_name],
input_shape=None),
codebase_config=dict(type='mmedit', task='SuperResolution'),
backend_config=dict(type=Backend.TORCHSCRIPT.value)))
def get_model_cfg():
import mmengine
file = 'tests/test_codebase/test_mmedit/data/model.py'
model_cfg = mmengine.Config.fromfile(file)
return model_cfg
@pytest.mark.parametrize('input_name', [input_name])
@pytest.mark.parametrize('output_name', [output_name])
@pytest.mark.skipif(
not importlib.util.find_spec('mmedit'), reason='requires mmedit')
def test_torch2torchscript(input_name, output_name):
import numpy as np
deploy_cfg = get_deploy_cfg(input_name, output_name)
torch2torchscript(
np.random.randint(0, 255, (8, 8, 3)),
'',
ts_file,
deploy_cfg,
model_cfg=get_model_cfg(),
device='cpu')
assert osp.exists(ts_file)