mmdeploy/tests/test_apis/test_torch2torchscript.py
AllentDan 4fc8828af8 fix ocr UT
fix ci and lint

fix det

fix cuda ci

fix mmdet test

update object detection

fix ut

fix layer norm ut

update ut

lock mmedit version

fix mmocr mmcls ut

add conftest.py

fix ocr ut

fix mmedit ci

install mmedit from source

fix rknn model and prepare_onnx_paddings__tensorrt UT

docstring

fix coreml export

update mmocr config

small test

recovery assert

fix ci
2022-09-29 16:37:36 +08:00

54 lines
1.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import os.path as osp
import tempfile
import pytest
from mmengine import Config
from mmdeploy.apis import torch2torchscript
from mmdeploy.utils import IR, Backend
from mmdeploy.utils.test import get_random_name
# Path (with a '.pt' suffix) that the TorchScript export is written to.
# Only the generated name is kept; the NamedTemporaryFile object itself is
# not bound to any variable.
ts_file = tempfile.NamedTemporaryFile(suffix='.pt').name
# Random input/output tensor names, generated once at import time.
input_name, output_name = get_random_name(), get_random_name()
def get_deploy_cfg(input_name, output_name):
    """Build the deploy config used by the TorchScript export test.

    Args:
        input_name: Name placed (as the single entry) in
            ``ir_config.input_names``.
        output_name: Name placed (as the single entry) in
            ``ir_config.output_names``.

    Returns:
        Config: A config with ``ir_config``, ``codebase_config`` and
            ``backend_config`` sections targeting TorchScript and the
            mmedit ``SuperResolution`` task.
    """
    ir_settings = {
        'type': IR.TORCHSCRIPT.value,
        'input_names': [input_name],
        'output_names': [output_name],
        'input_shape': None,
    }
    codebase_settings = {'type': 'mmedit', 'task': 'SuperResolution'}
    backend_settings = {'type': Backend.TORCHSCRIPT.value}

    return Config({
        'ir_config': ir_settings,
        'codebase_config': codebase_settings,
        'backend_config': backend_settings,
    })
def get_model_cfg():
    """Load the mmedit model config used by the export test.

    The path is relative, so it is resolved against the current working
    directory of the test run.

    Returns:
        The config object produced by ``mmengine.Config.fromfile``.
    """
    import mmengine

    cfg_path = 'tests/test_codebase/test_mmedit/data/model.py'
    return mmengine.Config.fromfile(cfg_path)
@pytest.mark.parametrize('input_name', [input_name])
@pytest.mark.parametrize('output_name', [output_name])
@pytest.mark.skipif(
    not importlib.util.find_spec('mmedit'), reason='requires mmedit')
def test_torch2torchscript(input_name, output_name):
    """Export through ``torch2torchscript`` and check the output file exists.

    Skipped when the ``mmedit`` package cannot be found.

    Args:
        input_name: Input tensor name forwarded to ``get_deploy_cfg``.
        output_name: Output tensor name forwarded to ``get_deploy_cfg``.
    """
    import numpy as np

    deploy_cfg = get_deploy_cfg(input_name, output_name)
    # Random 8x8x3 array passed as the first positional argument.
    dummy_input = np.random.rand(8, 8, 3)
    model_cfg = get_model_cfg()

    # The second positional argument is intentionally an empty string;
    # the export target is the module-level ``ts_file`` path.
    torch2torchscript(
        dummy_input, '', ts_file, deploy_cfg,
        model_cfg=model_cfg, device='cpu')

    print(ts_file)
    assert osp.exists(ts_file)