1
0
mirror of https://github.com/open-mmlab/mmdeploy.git synced 2025-01-14 08:09:43 +08:00
q.yao 3a785f1223
[Refactor] Refactor codebase ()
* [WIP] Refactor v2.0 ()

* Refactor backend wrapper

* Refactor mmdet.inference

* Fix

* merge

* refactor utils

* Use deployer and deploy_model to manage pipeline

* Resolve comments

* Add a real inference api function

* rename wrappers

* Set execute to private method

* Rename deployer deploy_model

* Refactor task

* remove type hint

* lint

* Resolve comments

* resolve comments

* lint

* docstring

* [Fix]: Fix bugs in details in refactor branch ()

* [WIP] Refactor v2.0 ()

* Refactor backend wrapper

* Refactor mmdet.inference

* Fix

* merge

* refactor utils

* Use deployer and deploy_model to manage pipeline

* Resolve comments

* Add a real inference api function

* rename wrappers

* Set execute to private method

* Rename deployer deploy_model

* Refactor task

* remove type hint

* lint

* Resolve comments

* resolve comments

* lint

* docstring

* Fix errors

* lint

* resolve comments

* fix bugs

* conflict

* lint and typo

* Resolve comment

* refactor mmseg ()

* support mmseg

* fix docstring

* fix docstring

* [Refactor]: Get the count of backend files ()

* Fix backend files

* resolve comments

* lint

* Fix ncnn

* [Refactor]: Refactor folders of mmdet ()

* Move folders

* lint

* test object detection model

* lint

* reset changes

* fix openvino

* resolve comments

* __init__.py

* Fix path

* [Refactor]: move mmseg ()

* [Refactor]: Refactor mmedit ()

* feature mmedit

* edit2.0

* edit

* refactor mmedit

* fix __init__.py

* fix __init__

* fix format

* fix comment

* fix comment

* Fix wrong func_name of ConvFCBBoxHead ()

* [Refactor]: Refactor mmdet unit test ()

* Move folders

* lint

* test object detection model

* lint

* WIP

* remove print

* finish unit test

* Fix tests

* resolve comments

* Add mask test

* lint

* resolve comments

* Refine cfg file

* Move files

* add files

* Fix path

* [Unittest]: Refine the unit tests in mmdet 

* [Refactor] refactor mmocr to mmdeploy/codebase ()

* refactor mmocr to mmdeploy/codebase

* fix docstring of show_result

* fix docstring of visualize

* refine docstring

* replace print with logging

* refine codes

* resolve comments

* resolve comments

* [Refactor]: mmseg  tests ()

* refactor mmseg tests

* rename test_codebase

* update

* add model.py

* fix

* [Refactor] Refactor mmcls and the package ()

* refactor mmcls

* fix yapf

* fix isort

* refactor-mmcls-package

* fix print to logging

* fix docstrings according to others comments

* fix comments

* fix comments

* fix allentdans comment in pr215

* remove mmocr init

* [Refactor] Refactor mmedit tests ()

* feature mmedit

* edit2.0

* edit

* refactor mmedit

* fix __init__.py

* fix __init__

* fix format

* fix comment

* fix comment

* buff

* edit test and code refactor

* refactor dir

* refactor tests/mmedit

* fix docstring

* add test coverage

* fix lint

* fix comment

* fix comment

* Update typehint ()

* update type hint

* update docstring

* update

* remove file

* fix ppl

* Refine get_predefined_partition_cfg

* fix tensorrt version > 8

* move parse_cuda_device_id to device.py

* Fix cascade

* onnx2ncnn docstring

Co-authored-by: Yifan Zhou <singlezombie@163.com>
Co-authored-by: RunningLeon <maningsheng@sensetime.com>
Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com>
Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com>
Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
2021-11-25 09:57:05 +08:00

138 lines
4.7 KiB
Python

import importlib
import os.path as osp
from tempfile import NamedTemporaryFile
import mmcv
import numpy as np
import pytest
import torch
import mmdeploy.backend.onnxruntime as ort_apis
from mmdeploy.utils import Backend
from mmdeploy.utils.test import SwitchBackendWrapper
# Number of classes; sizes the dummy class-name list and the random palette.
NUM_CLASS = 19
# Side length of the square dummy images / tensors fed to the model.
IMAGE_SIZE = 32
@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
class TestEnd2EndModel:
    """Tests of the mmseg ``End2EndModel`` on top of a stubbed ORT wrapper.

    The whole class is skipped when ``onnxruntime`` cannot be found. A single
    model instance is built in ``setup_class`` and shared by every test.
    """

    @classmethod
    def setup_class(cls):
        """Install the backend stub and build the shared ``End2EndModel``."""
        # force add backend wrapper regardless of plugins
        from mmdeploy.backend.onnxruntime import ORTWrapper
        ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

        # simplify backend inference
        cls.wrapper = SwitchBackendWrapper(ORTWrapper)
        cls.outputs = {
            'outputs': torch.rand(1, 1, IMAGE_SIZE, IMAGE_SIZE),
        }
        # NOTE(review): presumably the switched wrapper returns these canned
        # outputs instead of running a real session -- confirm against
        # SwitchBackendWrapper.
        cls.wrapper.set(outputs=cls.outputs)
        deploy_cfg = mmcv.Config(
            {'onnx_config': {
                'output_names': ['outputs']
            }})

        from mmdeploy.codebase.mmseg.deploy.segmentation_model \
            import End2EndModel
        class_names = [''] * NUM_CLASS
        palette = np.random.randint(0, 255, size=(NUM_CLASS, 3))
        cls.end2end_model = End2EndModel(
            Backend.ONNXRUNTIME, [''],
            device='cpu',
            class_names=class_names,
            palette=palette,
            deploy_cfg=deploy_cfg)

    @classmethod
    def teardown_class(cls):
        """Undo the backend wrapper switch made in ``setup_class``."""
        cls.wrapper.recover()

    @pytest.mark.parametrize(
        'ori_shape',
        [[IMAGE_SIZE, IMAGE_SIZE, 3], [2 * IMAGE_SIZE, 2 * IMAGE_SIZE, 3]])
    def test_forward(self, ori_shape):
        """``forward`` returns a result for same-size and 2x ``ori_shape``."""
        imgs = [torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE)]
        img_metas = [[{
            'ori_shape': ori_shape,
            'img_shape': [IMAGE_SIZE, IMAGE_SIZE, 3],
            'scale_factor': [1., 1., 1., 1.],
        }]]
        results = self.end2end_model.forward(imgs, img_metas)
        assert results is not None, 'failed to get output using '\
            'End2EndModel'

    def test_forward_test(self):
        """``forward_test`` yields ``np.ndarray`` results for a batch."""
        imgs = torch.rand(2, 3, IMAGE_SIZE, IMAGE_SIZE)
        results = self.end2end_model.forward_test(imgs)
        assert isinstance(results[0], np.ndarray)

    def test_show_result(self):
        """``show_result`` writes the drawn image to ``out_file``."""
        from tempfile import TemporaryDirectory

        input_img = np.zeros([IMAGE_SIZE, IMAGE_SIZE, 3])
        result = [torch.rand(IMAGE_SIZE, IMAGE_SIZE)]
        # Write into a fresh temporary directory: the target path is
        # guaranteed not to exist beforehand (so the assertion cannot pass
        # vacuously) and the image is removed again when the block exits.
        with TemporaryDirectory() as tmp_dir:
            img_path = osp.join(tmp_dir, 'result.jpg')
            self.end2end_model.show_result(
                input_img, result, '', show=False, out_file=img_path)
            assert osp.exists(img_path), 'Fails to create drawn image.'
@pytest.mark.parametrize('from_file', [True, False])
@pytest.mark.parametrize('data_type', ['train', 'val', 'test'])
def test_get_classes_palette_from_config(from_file, data_type):
    """Check classes and palette are resolved from a dataset config.

    Args:
        from_file (bool): If True, the config is dumped to a ``.py`` file and
            its path is passed; otherwise the ``mmcv.Config`` object itself
            is passed.
        data_type (str): Key under ``data`` that holds the dataset config.
    """
    from tempfile import TemporaryDirectory

    from mmseg.datasets import DATASETS

    from mmdeploy.codebase.mmseg.deploy.segmentation_model \
        import get_classes_palette_from_config
    dataset_type = 'CityscapesDataset'
    data_cfg = mmcv.Config({
        'data': {
            data_type:
            dict(
                type=dataset_type,
                data_root='',
                img_dir='',
                ann_dir='',
                pipeline=None)
        }
    })

    # The dumped config lives in a temporary directory so that it is removed
    # after the test instead of being left behind in the system temp folder.
    with TemporaryDirectory() as tmp_dir:
        if from_file:
            config_path = osp.join(tmp_dir, 'config.py')
            with open(config_path, 'w') as file:
                file.write(data_cfg.pretty_text)
            data_cfg = config_path

        classes, palette = get_classes_palette_from_config(data_cfg)

    module = DATASETS.module_dict[dataset_type]
    assert classes == module.CLASSES, \
        f'fail to get CLASSES of dataset: {dataset_type}'
    assert palette == module.PALETTE, \
        f'fail to get PALETTE of dataset: {dataset_type}'
@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_build_segmentation_model():
    """``build_segmentation_model`` returns an ``End2EndModel`` for ORT."""
    model_cfg = mmcv.Config({'data': {'test': {'type': 'CityscapesDataset'}}})
    deploy_cfg = mmcv.Config({
        'backend_config': {
            'type': 'onnxruntime'
        },
        'onnx_config': {
            'output_names': ['outputs']
        },
        'codebase_config': {
            'type': 'mmseg'
        },
    })

    # Expose ORTWrapper on the backend package regardless of plugins.
    from mmdeploy.backend.onnxruntime import ORTWrapper
    ort_apis.__dict__['ORTWrapper'] = ORTWrapper

    # simplify backend inference
    with SwitchBackendWrapper(ORTWrapper) as backend_stub:
        backend_stub.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
        from mmdeploy.codebase.mmseg.deploy.segmentation_model import \
            build_segmentation_model, End2EndModel
        built_model = build_segmentation_model(
            [''], model_cfg, deploy_cfg, 'cpu')
        assert isinstance(built_model, End2EndModel)