mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* Torchscript support (#159)
* support torchscript
* add nms
* add torchscript configs and update deploy process and dump-info
* typescript -> torchscript
* add torchscript custom extension support
* add ts custom ops again
* support mmseg unet
* [WIP] add optimizer for torchscript (#119)
* add passes
* add python api
* Torchscript optimizer python api (#121)
* add passes
* add python api
* use python api instead of executable
* Merge Master, update optimizer (#151)
* [Feature] add yolox ncnn (#29)
* add yolox ncnn
* add ncnn android performance of yolox
* add ut
* fix lint
* fix None bugs for ncnn
* test codecov
* test codecov
* add device
* fix yapf
* remove if-else for img shape
* use channelshuffle optimize
* change benchmark after channelshuffle
* fix yapf
* fix yapf
* fuse continuous reshape
* fix static shape deploy
* fix code
* drop pad
* only static shape
* fix static
* fix docstring
* Added mask overlay to output image, changed fprintf info messages to … (#55)
* Added mask overlay to output image, changed fprintf info messages to stdout
* Improved box filtering (filter area/score), make sure roi coordinates stay within bounds
* clang-format
* Support UNet in mmseg (#77)
* Repeatdataset in train has no CLASSES & PALETTE
* update result for unet
* update docstring for mmdet
* remove ppl for unet in docs
* fix ort wrap about input type (#81)
* Fix memleak (#86)
* delete []
* fix build error when enble MMDEPLOY_ACTIVE_LEVEL
* fix lint
* [Doc] Nano benchmark and tutorial (#71)
* add cls benchmark
* add nano zh-cn benchmark and en tutorial
* add device row
* add doc path to index.rst
* fix typo
* [Fix] fix missing deploy_core (#80)
* fix missing deploy_core
* mv flag to demo
* target link
* [Docs] Fix links in Chinese doc (#84)
* Fix docs in Chinese link
* Fix links
* Delete symbolic link and add links to html
* delete files
* Fix link
* [Feature] Add docker files (#67)
* add gpu and cpu dockerfile
* fix lint
* fix cpu docker and remove redundant
* use pip instead
* add build arg and readme
* fix grammar
* update readme
* add chinese doc for dockerfile and add docker build to build.md
* grammar
* refine dockerfiles
* add FAQs
* update Dpplcv_DIR for SDK building
* remove mmcls
* add sdk demos
* fix typo and lint
* update FAQs
* [Fix]fix check_env (#101)
* fix check_env
* update
* Replace convert_syncbatchnorm in mmseg (#93)
* replace convert_syncbatchnorm with revert_sync_batchnorm from mmcv
* change logger
* [Doc] Update FAQ for TensorRT (#96)
* update FAQ
* comment
* [Docs]: Update doc for openvino installation (#102)
* fix docs
* fix docs
* fix docs
* fix mmcv version
* fix docs
* rm blank line
* simplify non batch nms (#99)
* [Enhacement] Allow test.py to save evaluation results (#108)
* Add log file
* Delete debug code
* Rename logger
* resolve comments
* [Enhancement] Support mmocr v0.4+ (#115)
* support mmocr v0.4+
* 0.4.0 -> 0.4.1
* fix onnxruntime wrapper for gpu inference (#123)
* fix ncnn wrapper for ort-gpu
* resolve comment
* fix lint
* Fix typo (#132)
* lock mmcls version (#131)
* [Enhancement] upgrade isort in pre-commit config (#141)
* [Enhancement] upgrade isort in pre-commit config by refering to mmflow pr #87
* fix lint
* remove .isort.cfg and put its known_third_party to setup.cfg
* Fix ci for mmocr (#144)
* fix mmocr unittests
* remove useless
* lock mmdet maximum version to 2.20
* pip install -U numpy
* Fix capture_output (#125)
Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
Co-authored-by: Johannes L <tehkillerbee@users.noreply.github.com>
Co-authored-by: RunningLeon <mnsheng@yeah.net>
Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com>
Co-authored-by: Yifan Zhou <singlezombie@163.com>
Co-authored-by: 杨培文 (Yang Peiwen) <915505626@qq.com>
Co-authored-by: Semyon Bevzyuk <semen.bevzuk@gmail.com>
* configs for all tasks
* use torchvision roi align
* remote unnecessary code
* fix ut
* fix ut
* export
* det dynamic
* det dynamic
* add ut
* fix ut
* add ut and docs
* fix ut
* skip torchscript ut if no ops available
* add torchscript option to build.md
* update benchmark and resolve comments
* resolve conflicts
* rename configs
* fix mrcnn cuda test
* remove useless
* add version requirements to docs and comments to codes
* enable empty image exporting for torchscript and accelerate ORT inference for MRCNN
* rebase
* update example for torchscript.md
* update FAQs for torchscript.md
* resolve comments
* only use torchvision roi_align for torchscript
* fix ut
* use torchvision roi align when pool model is avg
* resolve comments
Co-authored-by: grimoire <streetyao@live.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
Co-authored-by: Johannes L <tehkillerbee@users.noreply.github.com>
Co-authored-by: RunningLeon <mnsheng@yeah.net>
Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: Yifan Zhou <singlezombie@163.com>
Co-authored-by: 杨培文 (Yang Peiwen) <915505626@qq.com>
Co-authored-by: Semyon Bevzyuk <semen.bevzuk@gmail.com>
* remove roi_align plugin for ORT (#258)
* remove roi_align plugin
* remove ut
* skip single_roi_extractor UT for ORT in CI
* move align to symbolic and update docs
* recover UT
* resolve comments
* add mmcls example
* add mmcls/mmdet/mmseg and their corresponding tests
* add test data
* simplify test data
* add requirement in optional.txt
* fix setup problem when adding mmrazor requirement
* use get_codebase_config
* change mmrazor requirement
Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com>
Co-authored-by: grimoire <streetyao@live.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
Co-authored-by: Johannes L <tehkillerbee@users.noreply.github.com>
Co-authored-by: RunningLeon <mnsheng@yeah.net>
Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: Yifan Zhou <singlezombie@163.com>
Co-authored-by: 杨培文 (Yang Peiwen) <915505626@qq.com>
Co-authored-by: Semyon Bevzyuk <semen.bevzuk@gmail.com>
146 lines
4.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Any

import mmcv
import numpy as np
import pytest
import torch

import mmdeploy.backend.onnxruntime as ort_apis
from mmdeploy.apis import build_task_processor
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Codebase, load_config
from mmdeploy.utils.test import DummyModel, SwitchBackendWrapper

import_codebase(Codebase.MMCLS)

model_cfg_path = 'tests/test_codebase/test_mmcls/data/model.py'
model_cfg = load_config(model_cfg_path)[0]
deploy_cfg = mmcv.Config(
    dict(
        backend_config=dict(type='onnxruntime'),
        codebase_config=dict(type='mmcls', task='Classification'),
        onnx_config=dict(
            type='onnx',
            export_params=True,
            keep_initializers_as_inputs=False,
            opset_version=11,
            input_shape=None,
            input_names=['input'],
            output_names=['output'])))

# Only the unique path is kept; the unreferenced temporary file object is
# closed (and the file deleted) once it is garbage-collected.
onnx_file = NamedTemporaryFile(suffix='.onnx').name
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
img_shape = (64, 64)
num_classes = 1000
# Random float image of shape (64, 64, 3) shared by the tests below.
img = np.random.rand(*img_shape, 3)


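The module-level `deploy_cfg` above is shared by every test in this file, which is presumably why `test_init_pytorch_model` deep-copies it before setting `from_mmrazor`. The sketch below illustrates that isolation with plain dicts standing in for `mmcv.Config`; `BASE_DEPLOY_CFG` and `make_variant` are hypothetical names used only for illustration.

```python
import copy

# Hypothetical stand-in for the shared module-level deploy config;
# plain dicts are used here instead of mmcv.Config.
BASE_DEPLOY_CFG = {
    'backend_config': {'type': 'onnxruntime'},
    'codebase_config': {'type': 'mmcls', 'task': 'Classification'},
}


def make_variant(**codebase_overrides):
    """Return a per-test copy of the base config with overridden keys."""
    cfg = copy.deepcopy(BASE_DEPLOY_CFG)  # nested dicts are copied as well
    cfg['codebase_config'].update(codebase_overrides)
    return cfg


variant = make_variant(from_mmrazor=True)
```

With `copy.copy` instead, `cfg['codebase_config']` would be the very same dict object as the base one, so the override would leak into every test that runs afterwards.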
@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0])
def test_init_pytorch_model(from_mmrazor: Any):
    from mmcls.models.classifiers.base import BaseClassifier
    if from_mmrazor is False:
        _task_processor = task_processor
    else:
        _model_cfg_path = 'tests/test_codebase/test_mmcls/data/' \
            'mmrazor_model.py'
        _model_cfg = load_config(_model_cfg_path)[0]
        _model_cfg.algorithm.architecture.model.type = 'mmcls.ImageClassifier'
        _model_cfg.algorithm.architecture.model.backbone = dict(
            type='SearchableShuffleNetV2', widen_factor=1.0)
        _deploy_cfg = copy.deepcopy(deploy_cfg)
        _deploy_cfg.codebase_config['from_mmrazor'] = from_mmrazor
        _task_processor = build_task_processor(_model_cfg, _deploy_cfg, 'cpu')

    if not isinstance(from_mmrazor, bool):
        with pytest.raises(
                TypeError,
                match='`from_mmrazor` attribute must be '
                'boolean type! '
                f'but got: {from_mmrazor}'):
            _ = _task_processor.from_mmrazor
        return
    assert from_mmrazor == _task_processor.from_mmrazor

    model = _task_processor.init_pytorch_model(None)
    assert isinstance(model, BaseClassifier)


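The parametrization feeds `'123'` and `0` to check that only real booleans are accepted. `0 == False` holds in Python, but `isinstance(0, bool)` is false, so a type check is needed to reject it (note also that `from_mmrazor is False` uses identity, so `0` takes the mmrazor branch). The following is a hypothetical sketch of a property that would satisfy the message the test matches; it is not mmdeploy's implementation, `ProcessorSketch` is an invented name, and defaulting a missing key to `False` is an assumption.

```python
class ProcessorSketch:
    """Hypothetical processor exposing a validated boolean config flag."""

    def __init__(self, codebase_config: dict):
        self.codebase_config = codebase_config

    @property
    def from_mmrazor(self) -> bool:
        # Assumption: a missing key means "not from mmrazor".
        value = self.codebase_config.get('from_mmrazor', False)
        # Checked by type, so 0, 1 and '123' are all rejected.
        if not isinstance(value, bool):
            raise TypeError('`from_mmrazor` attribute must be '
                            'boolean type! '
                            f'but got: {value}')
        return value
```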
@pytest.fixture
def backend_model():
    from mmdeploy.backend.onnxruntime import ORTWrapper
    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
    wrapper = SwitchBackendWrapper(ORTWrapper)
    wrapper.set(outputs={
        'output': torch.rand(1, num_classes),
    })

    yield task_processor.init_backend_model([''])

    wrapper.recover()


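The fixture swaps the ONNX Runtime wrapper for one that returns canned outputs, yields the backend model, and restores the original afterwards. `SwitchBackendWrapper` is mmdeploy's own test helper; the sketch below is only a hypothetical, framework-free illustration of the same swap-then-restore idea as a context manager, with invented names (`FakeBackend`, `switch_outputs`) and the assumption that `forward` is defined directly on the patched class.

```python
from contextlib import contextmanager


class FakeBackend:
    """Hypothetical backend wrapper whose real inference is unavailable."""

    def forward(self, inputs):
        raise RuntimeError('real backend not available')


@contextmanager
def switch_outputs(cls, outputs):
    """Temporarily make cls.forward return `outputs`, then restore it."""
    original = cls.__dict__['forward']  # assumes forward is defined on cls

    def fake_forward(self, inputs):
        return outputs

    cls.forward = fake_forward
    try:
        yield cls
    finally:
        cls.forward = original  # runs even if the body raised
```

In a pytest yield fixture, the code after `yield` plays the role of the `finally` clause: pytest runs it as teardown once the test finishes, whether the test passed or failed.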
def test_init_backend_model(backend_model):
    assert isinstance(backend_model, torch.nn.Module)


def test_create_input():
    inputs = task_processor.create_input(img, input_shape=img_shape)
    assert isinstance(inputs, tuple) and len(inputs) == 2


def test_run_inference(backend_model):
    input_dict, _ = task_processor.create_input(img, input_shape=img_shape)
    results = task_processor.run_inference(backend_model, input_dict)
    assert results is not None


def test_visualize(backend_model):
    input_dict, _ = task_processor.create_input(img, input_shape=img_shape)
    results = task_processor.run_inference(backend_model, input_dict)
    with TemporaryDirectory() as tmp_dir:
        # Join with a separator so the image lands inside the temporary
        # directory; plain `+` would produce a path next to it instead.
        filename = os.path.join(tmp_dir, 'tmp.jpg')
        task_processor.visualize(backend_model, img, results[0], filename, '')
        assert os.path.exists(filename)


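String concatenation does not insert a path separator, so `tmp_dir + 'tmp.jpg'` names a file next to the temporary directory rather than inside it, and such a file would survive the directory's cleanup. A small stdlib check of the difference (the example paths in the comments are illustrative only):

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    concatenated = tmp_dir + 'tmp.jpg'         # e.g. /tmp/tmpab12cdtmp.jpg
    joined = os.path.join(tmp_dir, 'tmp.jpg')  # e.g. /tmp/tmpab12cd/tmp.jpg
    # Only the joined path has the temporary directory as its parent.
    concat_inside = os.path.dirname(concatenated) == tmp_dir
    joined_inside = os.path.dirname(joined) == tmp_dir

print(concat_inside, joined_inside)  # False True
```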
def test_get_tensor_from_input():
    input_data = {'img': torch.ones(3, 4, 5)}
    inputs = task_processor.get_tensor_from_input(input_data)
    assert torch.equal(inputs, torch.ones(3, 4, 5))


def test_get_partition_cfg():
    # Partitioning may be unsupported for classification, so a
    # NotImplementedError is tolerated here.
    try:
        _ = task_processor.get_partition_cfg(partition_type='')
    except NotImplementedError:
        pass


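The `try`/`except NotImplementedError: pass` form also passes when no exception is raised, so it only checks that the call does not fail in some other way. If the intent were to require the exception, `pytest.raises(NotImplementedError)` would be the idiomatic choice. The stdlib sketch below shows that stricter behaviour with a hypothetical `expect_raises` helper and a hypothetical stub standing in for an unsupported `get_partition_cfg` call.

```python
from contextlib import contextmanager


@contextmanager
def expect_raises(exc_type):
    """Fail unless the managed block raises exc_type (like pytest.raises)."""
    try:
        yield
    except exc_type:
        return  # expected exception: swallow it and pass
    raise AssertionError(f'{exc_type.__name__} was not raised')


def unsupported_partition_cfg(partition_type):
    """Hypothetical stub for a task without partition support."""
    raise NotImplementedError(f'partition {partition_type!r} not supported')
```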
def test_build_dataset_and_dataloader():
    from torch.utils.data import DataLoader, Dataset
    dataset = task_processor.build_dataset(
        dataset_cfg=model_cfg, dataset_type='test')
    assert isinstance(dataset, Dataset), 'Failed to build dataset'
    dataloader = task_processor.build_dataloader(dataset, 1, 1)
    assert isinstance(dataloader, DataLoader), 'Failed to build dataloader'


def test_single_gpu_test_and_evaluate():
    from mmcv.parallel import MMDataParallel
    dataset = task_processor.build_dataset(
        dataset_cfg=model_cfg, dataset_type='test')
    dataloader = task_processor.build_dataloader(dataset, 1, 1)

    # Prepare dummy model
    model = DummyModel(outputs=[torch.rand([1, 1000])])
    model = MMDataParallel(model, device_ids=[0])
    assert model is not None
    # Run test
    outputs = task_processor.single_gpu_test(model, dataloader)
    assert outputs is not None
    task_processor.evaluate_outputs(model_cfg, outputs, dataset)
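`single_gpu_test` is exercised here with a `DummyModel` that returns a fixed `[1, 1000]` tensor. As a rough, framework-free illustration of what such a helper typically does, the sketch below iterates over the loader, calls the model once per batch, and flattens the per-sample results into one list. `collect_outputs`, `stub_model`, and the `{'img': [...]}` batch layout are hypothetical; a PyTorch-based helper would typically also put the model in eval mode and run each batch under `torch.no_grad()`.

```python
from typing import Any, Callable, Dict, Iterable, List, Sequence


def collect_outputs(model: Callable[[Dict[str, Any]], Sequence[Any]],
                    loader: Iterable[Dict[str, Any]]) -> List[Any]:
    """Run `model` on every batch and flatten the per-sample results."""
    results: List[Any] = []
    for batch in loader:
        batch_results = model(batch)  # one entry per sample in the batch
        results.extend(batch_results)
    return results


def stub_model(batch: Dict[str, Any]) -> List[int]:
    """Hypothetical model: returns the sum of each sample's values."""
    return [sum(sample) for sample in batch['img']]


loader = [{'img': [[1, 2], [3]]}, {'img': [[4, 5]]}]
print(collect_outputs(stub_model, loader))  # [3, 3, 9]
```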