mmdeploy/tests/test_apis/test_torch2torchscript.py
huayuan4396 5e9d27b8d6
mmedit -> mmagic (#2061)
* mmedit -> mmagic --initial

* fix codebase/cmakelist

* add tests/test_codebase/test_mmagic/data/

* fix lint

* fix rename

* fix EditDataPreprocessor

* fix EditTestLoop to TestLoop for mmagic

* fix EditValLoop to ValLoop for mmagic

* fix EditEvaluator to Evaluator for mmagic

* modify rgtest/mmagic.yml

* fix to MultiEvaluator

* fix mmagic model.py

* fix reg_test

* fix lint

* pass rgtest

* fix ci quantize.yml

* fix ci

* update docs

* fix lint

* fix lint

* fix lint

* fix sr end2endmodel device

* change destruct device back to cpu

* modify output device

* rename function name

* update docstring
2023-05-19 15:00:45 +08:00

53 lines
1.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import os.path as osp
import tempfile
import pytest
from mmengine import Config
from mmdeploy.apis import torch2torchscript
from mmdeploy.utils import IR, Backend
from mmdeploy.utils.test import get_random_name
ts_file = tempfile.NamedTemporaryFile(suffix='.pt').name
input_name = get_random_name()
output_name = get_random_name()
def get_deploy_cfg(input_name, output_name):
return Config(
dict(
ir_config=dict(
type=IR.TORCHSCRIPT.value,
input_names=[input_name],
output_names=[output_name],
input_shape=None),
codebase_config=dict(type='mmagic', task='SuperResolution'),
backend_config=dict(type=Backend.TORCHSCRIPT.value)))
def get_model_cfg():
import mmengine
file = 'tests/test_codebase/test_mmagic/data/model.py'
model_cfg = mmengine.Config.fromfile(file)
return model_cfg
@pytest.mark.parametrize('input_name', [input_name])
@pytest.mark.parametrize('output_name', [output_name])
@pytest.mark.skipif(
not importlib.util.find_spec('mmagic'), reason='requires mmagic')
def test_torch2torchscript(input_name, output_name):
import numpy as np
deploy_cfg = get_deploy_cfg(input_name, output_name)
torch2torchscript(
np.random.randint(0, 255, (8, 8, 3)),
'',
ts_file,
deploy_cfg,
model_cfg=get_model_cfg(),
device='cpu')
assert osp.exists(ts_file)