mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Unittest]: MMEdit unittest (#124)
* add mmedit test * Solve unittest bug * lint * Remove debug code * add data * Refine code * Fix ci * Fix ci * follow changes in mmdet test * try to remove pytest skip * remove redundant code * rename * Fix type hint * Fix lint * Refine SwitchBackendWrapper. Fix type hint * update docstring
This commit is contained in:
parent
07cb78bb7c
commit
10c4ef4203
@ -57,7 +57,7 @@ def init_pytorch_model(codebase: Codebase,
|
|||||||
def create_input(codebase: Codebase,
|
def create_input(codebase: Codebase,
|
||||||
task: Task,
|
task: Task,
|
||||||
model_cfg: Union[str, mmcv.Config],
|
model_cfg: Union[str, mmcv.Config],
|
||||||
imgs: Any,
|
imgs: Union[str, np.ndarray],
|
||||||
input_shape: Sequence[int] = None,
|
input_shape: Sequence[int] = None,
|
||||||
device: str = 'cuda:0',
|
device: str = 'cuda:0',
|
||||||
**kwargs):
|
**kwargs):
|
||||||
@ -68,8 +68,8 @@ def create_input(codebase: Codebase,
|
|||||||
task (Task): Specifying task type.
|
task (Task): Specifying task type.
|
||||||
model_cfg (str | mmcv.Config): Model config file or loaded Config
|
model_cfg (str | mmcv.Config): Model config file or loaded Config
|
||||||
object.
|
object.
|
||||||
imgs (Any): Input image(s), accpeted data type are `str`,
|
imgs (str | np.ndarray): Input image(s), accpeted data types are `str`,
|
||||||
`np.ndarray`, `torch.Tensor`.
|
`np.ndarray`.
|
||||||
input_shape (list[int]): Input shape of image in (width, height)
|
input_shape (list[int]): Input shape of image in (width, height)
|
||||||
format, defaults to `None`.
|
format, defaults to `None`.
|
||||||
device (str): A string specifying device type, defaults to 'cuda:0'.
|
device (str): A string specifying device type, defaults to 'cuda:0'.
|
||||||
|
@ -11,32 +11,65 @@ from torch.utils.data.dataset import Dataset
|
|||||||
from mmdeploy.utils import Task, load_config
|
from mmdeploy.utils import Task, load_config
|
||||||
|
|
||||||
|
|
||||||
def _preprocess_cfg(config: Union[str, mmcv.Config]):
|
def _preprocess_cfg(config: Union[str, mmcv.Config], task: Task,
|
||||||
|
load_from_file: bool, is_static_cfg: bool,
|
||||||
|
input_shape: Sequence[int]):
|
||||||
"""Remove unnecessary information in config.
|
"""Remove unnecessary information in config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_cfg (str | mmcv.Config): The input model config.
|
model_cfg (str | mmcv.Config): The input model config.
|
||||||
|
task (Task): Specifying editing task type.
|
||||||
|
load_from_file (bool): Whether the input is a filename of a numpy
|
||||||
|
matrix. If this variable is True, extra preprocessing is required.
|
||||||
|
is_static_cfg (bool): Whether the config specifys a static export.
|
||||||
|
If this variable if True, the input image will be resize to a fix
|
||||||
|
resolution.
|
||||||
|
input_shape (Sequence[int]): A list of two integer in (width, height)
|
||||||
|
format specifying input shape. Defaults to `None`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: Differentiate the editing tasks (e.g. restorers and mattors
|
# TODO: Differentiate the editing tasks (e.g. restorers and mattors
|
||||||
# preprocess the data in differenet ways)
|
# preprocess the data in differenet ways)
|
||||||
|
|
||||||
keys_to_remove = ['gt', 'gt_path']
|
if task == Task.SUPER_RESOLUTION:
|
||||||
|
keys_to_remove = ['gt', 'gt_path']
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown task type: {task.value}')
|
||||||
|
|
||||||
|
# MMEdit doesn't support LoadImageFromWebcam.
|
||||||
|
# Remove "LoadImageFromFile" and related metakeys.
|
||||||
|
if not load_from_file:
|
||||||
|
config.test_pipeline.pop(0)
|
||||||
|
if task == Task.SUPER_RESOLUTION:
|
||||||
|
keys_to_remove.append('lq_path')
|
||||||
|
|
||||||
|
# Fix the input shape by 'Resize'
|
||||||
|
if is_static_cfg:
|
||||||
|
if task == Task.SUPER_RESOLUTION:
|
||||||
|
resize = {
|
||||||
|
'type': 'Resize',
|
||||||
|
'scale': (input_shape[0], input_shape[1]),
|
||||||
|
'keys': ['lq']
|
||||||
|
}
|
||||||
|
config.test_pipeline.insert(1, resize)
|
||||||
|
|
||||||
for key in keys_to_remove:
|
for key in keys_to_remove:
|
||||||
for pipeline in list(config.test_pipeline):
|
for pipeline in list(config.test_pipeline):
|
||||||
if 'key' in pipeline and key == pipeline['key']:
|
if 'key' in pipeline and key == pipeline['key']:
|
||||||
config.test_pipeline.remove(pipeline)
|
config.test_pipeline.remove(pipeline)
|
||||||
if 'keys' in pipeline and key in pipeline['keys']:
|
if 'keys' in pipeline:
|
||||||
pipeline['keys'].remove(key)
|
while key in pipeline['keys']:
|
||||||
|
pipeline['keys'].remove(key)
|
||||||
if len(pipeline['keys']) == 0:
|
if len(pipeline['keys']) == 0:
|
||||||
config.test_pipeline.remove(pipeline)
|
config.test_pipeline.remove(pipeline)
|
||||||
if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
|
if 'meta_keys' in pipeline:
|
||||||
pipeline['meta_keys'].remove(key)
|
while key in pipeline['meta_keys']:
|
||||||
|
pipeline['meta_keys'].remove(key)
|
||||||
|
|
||||||
|
|
||||||
def create_input(task: Task,
|
def create_input(task: Task,
|
||||||
model_cfg: Union[str, mmcv.Config],
|
model_cfg: Union[str, mmcv.Config],
|
||||||
imgs: Union[str, mmcv.Config],
|
imgs: Union[str, np.ndarray],
|
||||||
input_shape: Optional[Sequence[int]] = None,
|
input_shape: Optional[Sequence[int]] = None,
|
||||||
device: Optional[str] = 'cuda:0'):
|
device: Optional[str] = 'cuda:0'):
|
||||||
"""Create input for editing processor.
|
"""Create input for editing processor.
|
||||||
@ -61,38 +94,30 @@ def create_input(task: Task,
|
|||||||
raise AssertionError('imgs must be strings or numpy arrays')
|
raise AssertionError('imgs must be strings or numpy arrays')
|
||||||
|
|
||||||
cfg = load_config(model_cfg)[0].copy()
|
cfg = load_config(model_cfg)[0].copy()
|
||||||
_preprocess_cfg(cfg)
|
|
||||||
|
|
||||||
if isinstance(imgs[0], np.ndarray):
|
_preprocess_cfg(
|
||||||
cfg = cfg.copy()
|
cfg,
|
||||||
# set loading pipeline type
|
task=task,
|
||||||
cfg.test_pipeline[0].type = 'LoadImageFromWebcam'
|
load_from_file=isinstance(imgs[0], str),
|
||||||
|
is_static_cfg=input_shape is not None,
|
||||||
# for static exporting
|
input_shape=input_shape)
|
||||||
if input_shape is not None:
|
|
||||||
if task == Task.SUPER_RESOLUTION:
|
|
||||||
resize = {
|
|
||||||
'type': 'Resize',
|
|
||||||
'scale': (input_shape[0], input_shape[1]),
|
|
||||||
'keys': ['lq']
|
|
||||||
}
|
|
||||||
cfg.test_pipeline.insert(1, resize)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f'Unknown task type: {task.value}')
|
|
||||||
|
|
||||||
test_pipeline = Compose(cfg.test_pipeline)
|
test_pipeline = Compose(cfg.test_pipeline)
|
||||||
|
|
||||||
data_arr = []
|
data_arr = []
|
||||||
for img in imgs:
|
for img in imgs:
|
||||||
# TODO: This is only for restore. Add condiction statement
|
# TODO: This is only for restore. Add condiction statement.
|
||||||
data = dict(lq_path=img)
|
if isinstance(img, np.ndarray):
|
||||||
|
data = dict(lq=img)
|
||||||
|
else:
|
||||||
|
data = dict(lq_path=img)
|
||||||
|
|
||||||
data = test_pipeline(data)
|
data = test_pipeline(data)
|
||||||
data_arr.append(data)
|
data_arr.append(data)
|
||||||
|
|
||||||
data = collate(data_arr, samples_per_gpu=len(imgs))
|
data = collate(data_arr, samples_per_gpu=len(imgs))
|
||||||
|
|
||||||
# TODO: This is only for restore. Add condiction statement
|
# TODO: This is only for restore. Add condiction statement.
|
||||||
data['img'] = data['lq']
|
data['img'] = data['lq']
|
||||||
|
|
||||||
if device != 'cpu':
|
if device != 'cpu':
|
||||||
|
@ -59,13 +59,14 @@ class WrapModel(nn.Module):
|
|||||||
|
|
||||||
class SwitchBackendWrapper:
|
class SwitchBackendWrapper:
|
||||||
"""A switcher for backend wrapper for unit tests.
|
"""A switcher for backend wrapper for unit tests.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> from mmdeploy.utils.test import SwitchBackendWrapper
|
>>> from mmdeploy.utils.test import SwitchBackendWrapper
|
||||||
>>> from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
|
>>> from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
|
||||||
>>> SwitchBackendWrapper.set(ORTWrapper, outputs=outputs)
|
>>> with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||||
|
>>> wrapper.set(ORTWrapper, outputs=outputs)
|
||||||
|
>>> ...
|
||||||
|
>>> # ORTWrapper will recover when exiting context
|
||||||
>>> ...
|
>>> ...
|
||||||
>>> SwitchBackendWrapper.recover(ORTWrapper)
|
|
||||||
"""
|
"""
|
||||||
init = None
|
init = None
|
||||||
forward = None
|
forward = None
|
||||||
@ -83,26 +84,35 @@ class SwitchBackendWrapper:
|
|||||||
def __call__(self, *args, **kwds):
|
def __call__(self, *args, **kwds):
|
||||||
return self.forward(*args, **kwds)
|
return self.forward(*args, **kwds)
|
||||||
|
|
||||||
@staticmethod
|
def __init__(self, recover_class):
|
||||||
def set(obj, **kwargs):
|
self._recover_class = recover_class
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, type, value, trace):
|
||||||
|
self.recover()
|
||||||
|
|
||||||
|
def set(self, **kwargs):
|
||||||
"""Replace attributes in backend wrappers with dummy items."""
|
"""Replace attributes in backend wrappers with dummy items."""
|
||||||
SwitchBackendWrapper.init = obj.__init__
|
obj = self._recover_class
|
||||||
SwitchBackendWrapper.forward = obj.forward
|
self.init = obj.__init__
|
||||||
SwitchBackendWrapper.call = obj.__call__
|
self.forward = obj.forward
|
||||||
|
self.call = obj.__call__
|
||||||
obj.__init__ = SwitchBackendWrapper.BackendWrapper.__init__
|
obj.__init__ = SwitchBackendWrapper.BackendWrapper.__init__
|
||||||
obj.forward = SwitchBackendWrapper.BackendWrapper.forward
|
obj.forward = SwitchBackendWrapper.BackendWrapper.forward
|
||||||
obj.__call__ = SwitchBackendWrapper.BackendWrapper.__call__
|
obj.__call__ = SwitchBackendWrapper.BackendWrapper.__call__
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
setattr(obj, k, v)
|
setattr(obj, k, v)
|
||||||
|
|
||||||
@staticmethod
|
def recover(self):
|
||||||
def recover(obj):
|
assert self.init is not None and \
|
||||||
assert SwitchBackendWrapper.init is not None and \
|
self.forward is not None,\
|
||||||
SwitchBackendWrapper.forward is not None,\
|
|
||||||
'recover method must be called after exchange'
|
'recover method must be called after exchange'
|
||||||
obj.__init__ = SwitchBackendWrapper.init
|
obj = self._recover_class
|
||||||
obj.forward = SwitchBackendWrapper.forward
|
obj.__init__ = self.init
|
||||||
obj.__call__ = SwitchBackendWrapper.call
|
obj.forward = self.forward
|
||||||
|
obj.__call__ = self.call
|
||||||
|
|
||||||
|
|
||||||
def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]],
|
def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]],
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os.path as osp
|
import os.path as osp
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from multiprocessing import Process
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import mmcv
|
import mmcv
|
||||||
@ -7,7 +8,6 @@ import mmcv
|
|||||||
from mmdeploy.apis import create_calib_table
|
from mmdeploy.apis import create_calib_table
|
||||||
|
|
||||||
calib_file = tempfile.NamedTemporaryFile(suffix='.h5').name
|
calib_file = tempfile.NamedTemporaryFile(suffix='.h5').name
|
||||||
data_prefix = 'tests/data/tiger'
|
|
||||||
ann_file = 'tests/data/annotation.json'
|
ann_file = 'tests/data/annotation.json'
|
||||||
|
|
||||||
|
|
||||||
@ -71,7 +71,7 @@ def get_model_cfg():
|
|||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='MultiScaleFlipAug',
|
type='MultiScaleFlipAug',
|
||||||
img_scale=(1333, 800),
|
img_scale=(1, 1),
|
||||||
flip=False,
|
flip=False,
|
||||||
transforms=[
|
transforms=[
|
||||||
dict(type='Resize', keep_ratio=True),
|
dict(type='Resize', keep_ratio=True),
|
||||||
@ -169,7 +169,7 @@ def get_model_cfg():
|
|||||||
return model_cfg
|
return model_cfg
|
||||||
|
|
||||||
|
|
||||||
def test_create_calib_end2end():
|
def run_test_create_calib_end2end():
|
||||||
model_cfg = get_model_cfg()
|
model_cfg = get_model_cfg()
|
||||||
deploy_cfg = get_end2end_deploy_cfg()
|
deploy_cfg = get_end2end_deploy_cfg()
|
||||||
create_calib_table(
|
create_calib_table(
|
||||||
@ -189,7 +189,19 @@ def test_create_calib_end2end():
|
|||||||
assert calibrator['calib_data']['end2end']['input']['0'] is not None
|
assert calibrator['calib_data']['end2end']['input']['0'] is not None
|
||||||
|
|
||||||
|
|
||||||
def test_create_calib_parittion():
|
# Because Faster-RCNN needs too much memory on GPU, we need to run tests in a
|
||||||
|
# new process.
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_calib_end2end():
|
||||||
|
p = Process(target=run_test_create_calib_end2end)
|
||||||
|
try:
|
||||||
|
p.start()
|
||||||
|
finally:
|
||||||
|
p.join()
|
||||||
|
|
||||||
|
|
||||||
|
def run_test_create_calib_parittion():
|
||||||
model_cfg = get_model_cfg()
|
model_cfg = get_model_cfg()
|
||||||
deploy_cfg = get_partition_deploy_cfg()
|
deploy_cfg = get_partition_deploy_cfg()
|
||||||
create_calib_table(
|
create_calib_table(
|
||||||
@ -211,3 +223,11 @@ def test_create_calib_parittion():
|
|||||||
assert calib_data[partition_name] is not None
|
assert calib_data[partition_name] is not None
|
||||||
assert calib_data[partition_name][input_names[i]] is not None
|
assert calib_data[partition_name][input_names[i]] is not None
|
||||||
assert calib_data[partition_name][input_names[i]]['0'] is not None
|
assert calib_data[partition_name][input_names[i]]['0'] is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_calib_parittion():
|
||||||
|
p = Process(target=run_test_create_calib_parittion)
|
||||||
|
try:
|
||||||
|
p.start()
|
||||||
|
finally:
|
||||||
|
p.join()
|
||||||
|
@ -26,20 +26,21 @@ def test_TensorRTDetector():
|
|||||||
'dets': torch.rand(1, 100, 5).cuda(),
|
'dets': torch.rand(1, 100, 5).cuda(),
|
||||||
'labels': torch.rand(1, 100).cuda()
|
'labels': torch.rand(1, 100).cuda()
|
||||||
}
|
}
|
||||||
SwitchBackendWrapper.set(TRTWrapper, outputs=outputs)
|
with SwitchBackendWrapper(TRTWrapper) as wrapper:
|
||||||
|
wrapper.set(outputs=outputs)
|
||||||
|
|
||||||
from mmdeploy.mmdet.apis.inference import TensorRTDetector
|
from mmdeploy.mmdet.apis.inference import TensorRTDetector
|
||||||
trt_detector = TensorRTDetector('', ['' for i in range(80)], 0)
|
trt_detector = TensorRTDetector('', ['' for i in range(80)], 0)
|
||||||
imgs = [torch.rand(1, 3, 64, 64).cuda()]
|
imgs = [torch.rand(1, 3, 64, 64).cuda()]
|
||||||
img_metas = [[{
|
img_metas = [[{
|
||||||
'ori_shape': [64, 64, 3],
|
'ori_shape': [64, 64, 3],
|
||||||
'img_shape': [64, 64, 3],
|
'img_shape': [64, 64, 3],
|
||||||
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
||||||
}]]
|
}]]
|
||||||
|
|
||||||
results = trt_detector.forward(imgs, img_metas)
|
results = trt_detector.forward(imgs, img_metas)
|
||||||
assert results is not None, 'failed to get output using TensorRTDetector'
|
assert results is not None, ('failed to get output using '
|
||||||
SwitchBackendWrapper.recover(TRTWrapper)
|
'TensorRTDetector')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
@ -52,21 +53,21 @@ def test_ONNXRuntimeDetector():
|
|||||||
|
|
||||||
# simplify backend inference
|
# simplify backend inference
|
||||||
outputs = (torch.rand(1, 100, 5), torch.rand(1, 100))
|
outputs = (torch.rand(1, 100, 5), torch.rand(1, 100))
|
||||||
SwitchBackendWrapper.set(ORTWrapper, outputs=outputs)
|
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||||
|
wrapper.set(outputs=outputs)
|
||||||
|
|
||||||
from mmdeploy.mmdet.apis.inference import ONNXRuntimeDetector
|
from mmdeploy.mmdet.apis.inference import ONNXRuntimeDetector
|
||||||
ort_detector = ONNXRuntimeDetector('', ['' for i in range(80)], 0)
|
ort_detector = ONNXRuntimeDetector('', ['' for i in range(80)], 0)
|
||||||
imgs = [torch.rand(1, 3, 64, 64)]
|
imgs = [torch.rand(1, 3, 64, 64)]
|
||||||
img_metas = [[{
|
img_metas = [[{
|
||||||
'ori_shape': [64, 64, 3],
|
'ori_shape': [64, 64, 3],
|
||||||
'img_shape': [64, 64, 3],
|
'img_shape': [64, 64, 3],
|
||||||
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
||||||
}]]
|
}]]
|
||||||
|
|
||||||
results = ort_detector.forward(imgs, img_metas)
|
results = ort_detector.forward(imgs, img_metas)
|
||||||
assert results is not None, 'failed to get output using '\
|
assert results is not None, 'failed to get output using '\
|
||||||
'ONNXRuntimeDetector'
|
'ONNXRuntimeDetector'
|
||||||
SwitchBackendWrapper.recover(ORTWrapper)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||||
@ -80,20 +81,20 @@ def test_PPLDetector():
|
|||||||
|
|
||||||
# simplify backend inference
|
# simplify backend inference
|
||||||
outputs = (torch.rand(1, 100, 5), torch.rand(1, 100))
|
outputs = (torch.rand(1, 100, 5), torch.rand(1, 100))
|
||||||
SwitchBackendWrapper.set(PPLWrapper, outputs=outputs)
|
with SwitchBackendWrapper(PPLWrapper) as wrapper:
|
||||||
|
wrapper.set(outputs=outputs)
|
||||||
|
|
||||||
from mmdeploy.mmdet.apis.inference import PPLDetector
|
from mmdeploy.mmdet.apis.inference import PPLDetector
|
||||||
ppl_detector = PPLDetector('', ['' for i in range(80)], 0)
|
ppl_detector = PPLDetector('', ['' for i in range(80)], 0)
|
||||||
imgs = [torch.rand(1, 3, 64, 64)]
|
imgs = [torch.rand(1, 3, 64, 64)]
|
||||||
img_metas = [[{
|
img_metas = [[{
|
||||||
'ori_shape': [64, 64, 3],
|
'ori_shape': [64, 64, 3],
|
||||||
'img_shape': [64, 64, 3],
|
'img_shape': [64, 64, 3],
|
||||||
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
||||||
}]]
|
}]]
|
||||||
|
|
||||||
results = ppl_detector.forward(imgs, img_metas)
|
results = ppl_detector.forward(imgs, img_metas)
|
||||||
assert results is not None, 'failed to get output using PPLDetector'
|
assert results is not None, 'failed to get output using PPLDetector'
|
||||||
SwitchBackendWrapper.recover(PPLWrapper)
|
|
||||||
|
|
||||||
|
|
||||||
def get_test_cfg_and_post_processing():
|
def get_test_cfg_and_post_processing():
|
||||||
@ -155,28 +156,26 @@ def test_NCNNPSSDetector():
|
|||||||
'scores': torch.rand(1, 120, 80),
|
'scores': torch.rand(1, 120, 80),
|
||||||
'boxes': torch.rand(1, 120, 4)
|
'boxes': torch.rand(1, 120, 4)
|
||||||
}
|
}
|
||||||
SwitchBackendWrapper.set(
|
with SwitchBackendWrapper(NCNNWrapper) as wrapper:
|
||||||
NCNNWrapper,
|
wrapper.set(
|
||||||
outputs=outputs,
|
outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg)
|
||||||
model_cfg=model_cfg,
|
|
||||||
deploy_cfg=deploy_cfg)
|
|
||||||
|
|
||||||
from mmdeploy.mmdet.apis.inference import NCNNPSSDetector
|
from mmdeploy.mmdet.apis.inference import NCNNPSSDetector
|
||||||
|
|
||||||
ncnn_pss_detector = NCNNPSSDetector(['', ''], ['' for i in range(80)],
|
ncnn_pss_detector = NCNNPSSDetector(['', ''], ['' for i in range(80)],
|
||||||
model_cfg=model_cfg,
|
model_cfg=model_cfg,
|
||||||
deploy_cfg=deploy_cfg,
|
deploy_cfg=deploy_cfg,
|
||||||
device_id=0)
|
device_id=0)
|
||||||
imgs = [torch.rand(1, 3, 32, 32)]
|
imgs = [torch.rand(1, 3, 32, 32)]
|
||||||
img_metas = [[{
|
img_metas = [[{
|
||||||
'ori_shape': [32, 32, 3],
|
'ori_shape': [32, 32, 3],
|
||||||
'img_shape': [32, 32, 3],
|
'img_shape': [32, 32, 3],
|
||||||
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
||||||
}]]
|
}]]
|
||||||
|
|
||||||
results = ncnn_pss_detector.forward(imgs, img_metas)
|
results = ncnn_pss_detector.forward(imgs, img_metas)
|
||||||
assert results is not None, 'failed to get output using NCNNPSSDetector'
|
assert results is not None, ('failed to get output using '
|
||||||
SwitchBackendWrapper.recover(NCNNWrapper)
|
'NCNNPSSDetector')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
@ -197,30 +196,27 @@ def test_ONNXRuntimePSSDetector():
|
|||||||
np.random.rand(1, 120, 80).astype(np.float32),
|
np.random.rand(1, 120, 80).astype(np.float32),
|
||||||
np.random.rand(1, 120, 4).astype(np.float32)
|
np.random.rand(1, 120, 4).astype(np.float32)
|
||||||
]
|
]
|
||||||
SwitchBackendWrapper.set(
|
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||||
ORTWrapper,
|
wrapper.set(
|
||||||
outputs=outputs,
|
outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg)
|
||||||
model_cfg=model_cfg,
|
|
||||||
deploy_cfg=deploy_cfg)
|
|
||||||
|
|
||||||
from mmdeploy.mmdet.apis.inference import ONNXRuntimePSSDetector
|
from mmdeploy.mmdet.apis.inference import ONNXRuntimePSSDetector
|
||||||
|
|
||||||
ort_pss_detector = ONNXRuntimePSSDetector(
|
ort_pss_detector = ONNXRuntimePSSDetector(
|
||||||
'', ['' for i in range(80)],
|
'', ['' for i in range(80)],
|
||||||
model_cfg=model_cfg,
|
model_cfg=model_cfg,
|
||||||
deploy_cfg=deploy_cfg,
|
deploy_cfg=deploy_cfg,
|
||||||
device_id=0)
|
device_id=0)
|
||||||
imgs = [torch.rand(1, 3, 32, 32)]
|
imgs = [torch.rand(1, 3, 32, 32)]
|
||||||
img_metas = [[{
|
img_metas = [[{
|
||||||
'ori_shape': [32, 32, 3],
|
'ori_shape': [32, 32, 3],
|
||||||
'img_shape': [32, 32, 3],
|
'img_shape': [32, 32, 3],
|
||||||
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
||||||
}]]
|
}]]
|
||||||
|
|
||||||
results = ort_pss_detector.forward(imgs, img_metas)
|
results = ort_pss_detector.forward(imgs, img_metas)
|
||||||
assert results is not None, 'failed to get output using '
|
assert results is not None, 'failed to get output using '
|
||||||
'ONNXRuntimePSSDetector'
|
'ONNXRuntimePSSDetector'
|
||||||
SwitchBackendWrapper.recover(ORTWrapper)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||||
@ -242,30 +238,27 @@ def test_TensorRTPSSDetector():
|
|||||||
'scores': torch.rand(1, 120, 80).cuda(),
|
'scores': torch.rand(1, 120, 80).cuda(),
|
||||||
'boxes': torch.rand(1, 120, 4).cuda()
|
'boxes': torch.rand(1, 120, 4).cuda()
|
||||||
}
|
}
|
||||||
SwitchBackendWrapper.set(
|
with SwitchBackendWrapper(TRTWrapper) as wrapper:
|
||||||
TRTWrapper,
|
wrapper.set(
|
||||||
outputs=outputs,
|
outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg)
|
||||||
model_cfg=model_cfg,
|
|
||||||
deploy_cfg=deploy_cfg)
|
|
||||||
|
|
||||||
from mmdeploy.mmdet.apis.inference import TensorRTPSSDetector
|
from mmdeploy.mmdet.apis.inference import TensorRTPSSDetector
|
||||||
|
|
||||||
trt_pss_detector = TensorRTPSSDetector(
|
trt_pss_detector = TensorRTPSSDetector(
|
||||||
'', ['' for i in range(80)],
|
'', ['' for i in range(80)],
|
||||||
model_cfg=model_cfg,
|
model_cfg=model_cfg,
|
||||||
deploy_cfg=deploy_cfg,
|
deploy_cfg=deploy_cfg,
|
||||||
device_id=0)
|
device_id=0)
|
||||||
imgs = [torch.rand(1, 3, 32, 32).cuda()]
|
imgs = [torch.rand(1, 3, 32, 32).cuda()]
|
||||||
img_metas = [[{
|
img_metas = [[{
|
||||||
'ori_shape': [32, 32, 3],
|
'ori_shape': [32, 32, 3],
|
||||||
'img_shape': [32, 32, 3],
|
'img_shape': [32, 32, 3],
|
||||||
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
||||||
}]]
|
}]]
|
||||||
|
|
||||||
results = trt_pss_detector.forward(imgs, img_metas)
|
results = trt_pss_detector.forward(imgs, img_metas)
|
||||||
assert results is not None, 'failed to get output using '
|
assert results is not None, 'failed to get output using '
|
||||||
'TensorRTPSSDetector'
|
'TensorRTPSSDetector'
|
||||||
SwitchBackendWrapper.recover(TRTWrapper)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_model_deploy_cfgs():
|
def prepare_model_deploy_cfgs():
|
||||||
@ -377,41 +370,41 @@ def test_TensorRTPTSDetector():
|
|||||||
'cls_score': torch.rand(1, 12, 80).cuda(),
|
'cls_score': torch.rand(1, 12, 80).cuda(),
|
||||||
'bbox_pred': torch.rand(1, 12, 4).cuda()
|
'bbox_pred': torch.rand(1, 12, 4).cuda()
|
||||||
}
|
}
|
||||||
SwitchBackendWrapper.set(TRTWrapper, outputs=outputs)
|
with SwitchBackendWrapper(TRTWrapper) as wrapper:
|
||||||
TRTWrapper.model_cfg = model_cfg
|
wrapper.set(
|
||||||
TRTWrapper.deploy_cfg = deploy_cfg
|
outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg)
|
||||||
|
|
||||||
# replace original function in PartitionTwoStageDetector
|
# replace original function in PartitionTwoStageDetector
|
||||||
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
|
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
|
||||||
PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
|
PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
|
||||||
PartitionTwoStageDetector.partition0_postprocess = \
|
PartitionTwoStageDetector.partition0_postprocess = \
|
||||||
DummyPTSDetector.partition0_postprocess
|
DummyPTSDetector.partition0_postprocess
|
||||||
PartitionTwoStageDetector.partition1_postprocess = \
|
PartitionTwoStageDetector.partition1_postprocess = \
|
||||||
DummyPTSDetector.partition1_postprocess
|
DummyPTSDetector.partition1_postprocess
|
||||||
PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3).cuda()] * 2
|
PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3).cuda()] * 2
|
||||||
PartitionTwoStageDetector.outputs1 = [
|
PartitionTwoStageDetector.outputs1 = [
|
||||||
torch.rand(1, 9, 5).cuda(),
|
torch.rand(1, 9, 5).cuda(),
|
||||||
torch.rand(1, 9).cuda()
|
torch.rand(1, 9).cuda()
|
||||||
]
|
]
|
||||||
PartitionTwoStageDetector.device_id = 0
|
PartitionTwoStageDetector.device_id = 0
|
||||||
PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]
|
PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]
|
||||||
|
|
||||||
from mmdeploy.mmdet.apis.inference import TensorRTPTSDetector
|
from mmdeploy.mmdet.apis.inference import TensorRTPTSDetector
|
||||||
trt_pts_detector = TensorRTPTSDetector(['', ''], ['' for i in range(80)],
|
trt_pts_detector = TensorRTPTSDetector(['', ''],
|
||||||
model_cfg=model_cfg,
|
['' for i in range(80)],
|
||||||
deploy_cfg=deploy_cfg,
|
model_cfg=model_cfg,
|
||||||
device_id=0)
|
deploy_cfg=deploy_cfg,
|
||||||
|
device_id=0)
|
||||||
|
|
||||||
imgs = [torch.rand(1, 3, 32, 32).cuda()]
|
imgs = [torch.rand(1, 3, 32, 32).cuda()]
|
||||||
img_metas = [[{
|
img_metas = [[{
|
||||||
'ori_shape': [32, 32, 3],
|
'ori_shape': [32, 32, 3],
|
||||||
'img_shape': [32, 32, 3],
|
'img_shape': [32, 32, 3],
|
||||||
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
||||||
}]]
|
}]]
|
||||||
results = trt_pts_detector.forward(imgs, img_metas)
|
results = trt_pts_detector.forward(imgs, img_metas)
|
||||||
assert results is not None, 'failed to get output using '
|
assert results is not None, 'failed to get output using '
|
||||||
'TensorRTPTSDetector'
|
'TensorRTPTSDetector'
|
||||||
SwitchBackendWrapper.recover(TRTWrapper)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
@ -429,43 +422,40 @@ def test_ONNXRuntimePTSDetector():
|
|||||||
np.random.rand(1, 12, 80).astype(np.float32),
|
np.random.rand(1, 12, 80).astype(np.float32),
|
||||||
np.random.rand(1, 12, 4).astype(np.float32),
|
np.random.rand(1, 12, 4).astype(np.float32),
|
||||||
] * 2
|
] * 2
|
||||||
SwitchBackendWrapper.set(
|
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||||
ORTWrapper,
|
wrapper.set(
|
||||||
outputs=outputs,
|
outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg)
|
||||||
model_cfg=model_cfg,
|
|
||||||
deploy_cfg=deploy_cfg)
|
|
||||||
|
|
||||||
# replace original function in PartitionTwoStageDetector
|
# replace original function in PartitionTwoStageDetector
|
||||||
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
|
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
|
||||||
PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
|
PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
|
||||||
PartitionTwoStageDetector.partition0_postprocess = \
|
PartitionTwoStageDetector.partition0_postprocess = \
|
||||||
DummyPTSDetector.partition0_postprocess
|
DummyPTSDetector.partition0_postprocess
|
||||||
PartitionTwoStageDetector.partition1_postprocess = \
|
PartitionTwoStageDetector.partition1_postprocess = \
|
||||||
DummyPTSDetector.partition1_postprocess
|
DummyPTSDetector.partition1_postprocess
|
||||||
PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2
|
PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2
|
||||||
PartitionTwoStageDetector.outputs1 = [
|
PartitionTwoStageDetector.outputs1 = [
|
||||||
torch.rand(1, 9, 5), torch.rand(1, 9)
|
torch.rand(1, 9, 5), torch.rand(1, 9)
|
||||||
]
|
]
|
||||||
PartitionTwoStageDetector.device_id = -1
|
PartitionTwoStageDetector.device_id = -1
|
||||||
PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]
|
PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]
|
||||||
|
|
||||||
from mmdeploy.mmdet.apis.inference import ONNXRuntimePTSDetector
|
from mmdeploy.mmdet.apis.inference import ONNXRuntimePTSDetector
|
||||||
ort_pts_detector = ONNXRuntimePTSDetector(['', ''],
|
ort_pts_detector = ONNXRuntimePTSDetector(['', ''],
|
||||||
['' for i in range(80)],
|
['' for i in range(80)],
|
||||||
model_cfg=model_cfg,
|
model_cfg=model_cfg,
|
||||||
deploy_cfg=deploy_cfg,
|
deploy_cfg=deploy_cfg,
|
||||||
device_id=0)
|
device_id=0)
|
||||||
|
|
||||||
imgs = [torch.rand(1, 3, 32, 32)]
|
imgs = [torch.rand(1, 3, 32, 32)]
|
||||||
img_metas = [[{
|
img_metas = [[{
|
||||||
'ori_shape': [32, 32, 3],
|
'ori_shape': [32, 32, 3],
|
||||||
'img_shape': [32, 32, 3],
|
'img_shape': [32, 32, 3],
|
||||||
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
||||||
}]]
|
}]]
|
||||||
results = ort_pts_detector.forward(imgs, img_metas)
|
results = ort_pts_detector.forward(imgs, img_metas)
|
||||||
assert results is not None, 'failed to get output using '
|
assert results is not None, 'failed to get output using '
|
||||||
'ONNXRuntimePTSDetector'
|
'ONNXRuntimePTSDetector'
|
||||||
SwitchBackendWrapper.recover(ORTWrapper)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
@ -487,43 +477,40 @@ def test_NCNNPTSDetector():
|
|||||||
'cls_score': torch.rand(1, 12, 80),
|
'cls_score': torch.rand(1, 12, 80),
|
||||||
'bbox_pred': torch.rand(1, 12, 4)
|
'bbox_pred': torch.rand(1, 12, 4)
|
||||||
}
|
}
|
||||||
SwitchBackendWrapper.set(
|
with SwitchBackendWrapper(NCNNWrapper) as wrapper:
|
||||||
NCNNWrapper,
|
wrapper.set(
|
||||||
outputs=outputs,
|
outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg)
|
||||||
model_cfg=model_cfg,
|
|
||||||
deploy_cfg=deploy_cfg)
|
|
||||||
|
|
||||||
# replace original function in PartitionTwoStageDetector
|
# replace original function in PartitionTwoStageDetector
|
||||||
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
|
from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector
|
||||||
PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
|
PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__
|
||||||
PartitionTwoStageDetector.partition0_postprocess = \
|
PartitionTwoStageDetector.partition0_postprocess = \
|
||||||
DummyPTSDetector.partition0_postprocess
|
DummyPTSDetector.partition0_postprocess
|
||||||
PartitionTwoStageDetector.partition1_postprocess = \
|
PartitionTwoStageDetector.partition1_postprocess = \
|
||||||
DummyPTSDetector.partition1_postprocess
|
DummyPTSDetector.partition1_postprocess
|
||||||
PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2
|
PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2
|
||||||
PartitionTwoStageDetector.outputs1 = [
|
PartitionTwoStageDetector.outputs1 = [
|
||||||
torch.rand(1, 9, 5), torch.rand(1, 9)
|
torch.rand(1, 9, 5), torch.rand(1, 9)
|
||||||
]
|
]
|
||||||
PartitionTwoStageDetector.device_id = -1
|
PartitionTwoStageDetector.device_id = -1
|
||||||
PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]
|
PartitionTwoStageDetector.CLASSES = ['' for i in range(80)]
|
||||||
|
|
||||||
from mmdeploy.mmdet.apis.inference import NCNNPTSDetector
|
from mmdeploy.mmdet.apis.inference import NCNNPTSDetector
|
||||||
ncnn_pts_detector = NCNNPTSDetector(
|
ncnn_pts_detector = NCNNPTSDetector(
|
||||||
[''] * 4, [''] * 80,
|
[''] * 4, [''] * 80,
|
||||||
model_cfg=model_cfg,
|
model_cfg=model_cfg,
|
||||||
deploy_cfg=deploy_cfg,
|
deploy_cfg=deploy_cfg,
|
||||||
device_id=0)
|
device_id=0)
|
||||||
|
|
||||||
imgs = [torch.rand(1, 3, 32, 32)]
|
imgs = [torch.rand(1, 3, 32, 32)]
|
||||||
img_metas = [[{
|
img_metas = [[{
|
||||||
'ori_shape': [32, 32, 3],
|
'ori_shape': [32, 32, 3],
|
||||||
'img_shape': [32, 32, 3],
|
'img_shape': [32, 32, 3],
|
||||||
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
'scale_factor': [2.09, 1.87, 2.09, 1.87],
|
||||||
}]]
|
}]]
|
||||||
results = ncnn_pts_detector.forward(imgs, img_metas)
|
results = ncnn_pts_detector.forward(imgs, img_metas)
|
||||||
assert results is not None, 'failed to get output using '
|
assert results is not None, 'failed to get output using '
|
||||||
'NCNNPTSDetector'
|
'NCNNPTSDetector'
|
||||||
SwitchBackendWrapper.recover(NCNNWrapper)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
@ -541,9 +528,8 @@ def test_build_detector():
|
|||||||
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
|
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
|
||||||
|
|
||||||
# simplify backend inference
|
# simplify backend inference
|
||||||
SwitchBackendWrapper.set(
|
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||||
ORTWrapper, model_cfg=model_cfg, deploy_cfg=deploy_cfg)
|
wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
|
||||||
from mmdeploy.apis.utils import init_backend_model
|
from mmdeploy.apis.utils import init_backend_model
|
||||||
detector = init_backend_model([''], model_cfg, deploy_cfg, -1)
|
detector = init_backend_model([''], model_cfg, deploy_cfg, -1)
|
||||||
assert detector is not None
|
assert detector is not None
|
||||||
SwitchBackendWrapper.recover(ORTWrapper)
|
|
||||||
|
BIN
tests/test_mmedit/data/imgs/blank.jpg
Normal file
BIN
tests/test_mmedit/data/imgs/blank.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 691 B |
110
tests/test_mmedit/data/model.py
Normal file
110
tests/test_mmedit/data/model.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
exp_name = 'srcnn_x4k915_g1_1000k_div2k'
|
||||||
|
|
||||||
|
scale = 1
|
||||||
|
# model settings
|
||||||
|
model = dict(
|
||||||
|
type='BasicRestorer',
|
||||||
|
generator=dict(
|
||||||
|
type='SRCNN',
|
||||||
|
channels=(3, 64, 32, 3),
|
||||||
|
kernel_sizes=(9, 1, 5),
|
||||||
|
upscale_factor=scale),
|
||||||
|
pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'))
|
||||||
|
# model training and testing settings
|
||||||
|
train_cfg = None
|
||||||
|
test_cfg = dict(metrics=['PSNR', 'SSIM'], crop_border=scale)
|
||||||
|
|
||||||
|
# dataset settings
|
||||||
|
train_dataset_type = 'SRAnnotationDataset'
|
||||||
|
val_dataset_type = 'SRFolderDataset'
|
||||||
|
train_pipeline = [
|
||||||
|
dict(
|
||||||
|
type='LoadImageFromFile',
|
||||||
|
io_backend='disk',
|
||||||
|
key='lq',
|
||||||
|
flag='unchanged'),
|
||||||
|
dict(
|
||||||
|
type='LoadImageFromFile',
|
||||||
|
io_backend='disk',
|
||||||
|
key='gt',
|
||||||
|
flag='unchanged'),
|
||||||
|
dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
|
||||||
|
dict(
|
||||||
|
type='Normalize',
|
||||||
|
keys=['lq', 'gt'],
|
||||||
|
mean=[0, 0, 0],
|
||||||
|
std=[1, 1, 1],
|
||||||
|
to_rgb=True),
|
||||||
|
dict(type='PairedRandomCrop', gt_patch_size=128),
|
||||||
|
dict(
|
||||||
|
type='Flip', keys=['lq', 'gt'], flip_ratio=0.5,
|
||||||
|
direction='horizontal'),
|
||||||
|
dict(type='Flip', keys=['lq', 'gt'], flip_ratio=0.5, direction='vertical'),
|
||||||
|
dict(type='RandomTransposeHW', keys=['lq', 'gt'], transpose_ratio=0.5),
|
||||||
|
dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path', 'gt_path']),
|
||||||
|
dict(type='ImageToTensor', keys=['lq', 'gt'])
|
||||||
|
]
|
||||||
|
test_pipeline = [
|
||||||
|
dict(
|
||||||
|
type='LoadImageFromFile',
|
||||||
|
io_backend='disk',
|
||||||
|
key='lq',
|
||||||
|
flag='unchanged'),
|
||||||
|
dict(
|
||||||
|
type='LoadImageFromFile',
|
||||||
|
io_backend='disk',
|
||||||
|
key='gt',
|
||||||
|
flag='unchanged'),
|
||||||
|
dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
|
||||||
|
dict(
|
||||||
|
type='Normalize',
|
||||||
|
keys=['lq', 'gt'],
|
||||||
|
mean=[0, 0, 0],
|
||||||
|
std=[1, 1, 1],
|
||||||
|
to_rgb=True),
|
||||||
|
dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path', 'lq_path']),
|
||||||
|
dict(type='ImageToTensor', keys=['lq', 'gt'])
|
||||||
|
]
|
||||||
|
|
||||||
|
data = dict(
|
||||||
|
workers_per_gpu=8,
|
||||||
|
train_dataloader=dict(samples_per_gpu=16, drop_last=True),
|
||||||
|
val_dataloader=dict(samples_per_gpu=1),
|
||||||
|
test_dataloader=dict(samples_per_gpu=1),
|
||||||
|
test=dict(
|
||||||
|
type=val_dataset_type,
|
||||||
|
lq_folder='tests/test_mmedit/data/imgs',
|
||||||
|
gt_folder='tests/test_mmedit/data/imgs',
|
||||||
|
pipeline=test_pipeline,
|
||||||
|
scale=scale,
|
||||||
|
filename_tmpl='{}'))
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
optimizers = dict(generator=dict(type='Adam', lr=2e-4, betas=(0.9, 0.999)))
|
||||||
|
|
||||||
|
# learning policy
|
||||||
|
total_iters = 1000000
|
||||||
|
lr_config = dict(
|
||||||
|
policy='CosineRestart',
|
||||||
|
by_epoch=False,
|
||||||
|
periods=[250000, 250000, 250000, 250000],
|
||||||
|
restart_weights=[1, 1, 1, 1],
|
||||||
|
min_lr=1e-7)
|
||||||
|
|
||||||
|
checkpoint_config = dict(interval=5000, save_optimizer=True, by_epoch=False)
|
||||||
|
evaluation = dict(interval=5000, save_image=True, gpu_collect=True)
|
||||||
|
log_config = dict(
|
||||||
|
interval=100,
|
||||||
|
hooks=[
|
||||||
|
dict(type='TextLoggerHook', by_epoch=False),
|
||||||
|
dict(type='TensorboardLoggerHook'),
|
||||||
|
])
|
||||||
|
visual_config = None
|
||||||
|
|
||||||
|
# runtime settings
|
||||||
|
dist_params = dict(backend='nccl')
|
||||||
|
log_level = 'INFO'
|
||||||
|
work_dir = f'./work_dirs/{exp_name}'
|
||||||
|
load_from = None
|
||||||
|
resume_from = None
|
||||||
|
workflow = [('train', 1)]
|
219
tests/test_mmedit/test_mmedit_apis.py
Normal file
219
tests/test_mmedit/test_mmedit_apis.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import mmdeploy.apis.onnxruntime as ort_apis
|
||||||
|
import mmdeploy.apis.ppl as ppl_apis
|
||||||
|
import mmdeploy.apis.tensorrt as trt_apis
|
||||||
|
import mmdeploy.apis.test as api_test
|
||||||
|
import mmdeploy.apis.utils as api_utils
|
||||||
|
from mmdeploy.utils.constants import Backend, Codebase
|
||||||
|
from mmdeploy.utils.test import SwitchBackendWrapper
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
|
||||||
|
def test_TensorRTRestorer():
|
||||||
|
# force add backend wrapper regardless of plugins
|
||||||
|
from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
|
||||||
|
trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
|
||||||
|
|
||||||
|
# simplify backend inference
|
||||||
|
outputs = {
|
||||||
|
'output': torch.rand(1, 3, 64, 64).cuda(),
|
||||||
|
}
|
||||||
|
|
||||||
|
with SwitchBackendWrapper(TRTWrapper) as wrapper:
|
||||||
|
wrapper.set(outputs=outputs)
|
||||||
|
|
||||||
|
from mmdeploy.mmedit.apis.inference import TensorRTRestorer
|
||||||
|
trt_restorer = TensorRTRestorer('', 0)
|
||||||
|
imgs = torch.rand(1, 3, 64, 64).cuda()
|
||||||
|
|
||||||
|
results = trt_restorer.forward(imgs)
|
||||||
|
assert results is not None, ('failed to get output using '
|
||||||
|
'TensorRTRestorer')
|
||||||
|
|
||||||
|
results = trt_restorer.forward(imgs, test_mode=True)
|
||||||
|
assert results is not None, ('failed to get output using '
|
||||||
|
'TensorRTRestorer')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
|
||||||
|
def test_ONNXRuntimeRestorer():
|
||||||
|
# force add backend wrapper regardless of plugins
|
||||||
|
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
|
||||||
|
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
|
||||||
|
|
||||||
|
# simplify backend inference
|
||||||
|
outputs = torch.rand(1, 3, 64, 64)
|
||||||
|
|
||||||
|
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||||
|
wrapper.set(outputs=outputs)
|
||||||
|
|
||||||
|
from mmdeploy.mmedit.apis.inference import ONNXRuntimeRestorer
|
||||||
|
ort_restorer = ONNXRuntimeRestorer('', 0)
|
||||||
|
imgs = torch.rand(1, 3, 64, 64)
|
||||||
|
|
||||||
|
results = ort_restorer.forward(imgs)
|
||||||
|
assert results is not None, 'failed to get output using '\
|
||||||
|
'ONNXRuntimeRestorer'
|
||||||
|
|
||||||
|
results = ort_restorer.forward(imgs, test_mode=True)
|
||||||
|
assert results is not None, 'failed to get output using '\
|
||||||
|
'ONNXRuntimeRestorer'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not importlib.util.find_spec('pyppl'), reason='requires pyppl')
|
||||||
|
def test_PPLRestorer():
|
||||||
|
# force add backend wrapper regardless of plugins
|
||||||
|
from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
|
||||||
|
ppl_apis.__dict__.update({'PPLWrapper': PPLWrapper})
|
||||||
|
|
||||||
|
# simplify backend inference
|
||||||
|
outputs = torch.rand(1, 3, 64, 64)
|
||||||
|
|
||||||
|
with SwitchBackendWrapper(PPLWrapper) as wrapper:
|
||||||
|
wrapper.set(outputs=outputs)
|
||||||
|
|
||||||
|
from mmdeploy.mmedit.apis.inference import PPLRestorer
|
||||||
|
ppl_restorer = PPLRestorer('', 0)
|
||||||
|
imgs = torch.rand(1, 3, 64, 64)
|
||||||
|
|
||||||
|
results = ppl_restorer.forward(imgs)
|
||||||
|
assert results is not None, 'failed to get output using PPLRestorer'
|
||||||
|
|
||||||
|
results = ppl_restorer.forward(imgs, test_mode=True)
|
||||||
|
assert results is not None, 'failed to get output using PPLRestorer'
|
||||||
|
|
||||||
|
|
||||||
|
model_cfg = 'tests/test_mmedit/data/model.py'
|
||||||
|
deploy_cfg = mmcv.Config(
|
||||||
|
dict(
|
||||||
|
backend_config=dict(type='onnxruntime'),
|
||||||
|
codebase_config=dict(type='mmedit', task='SuperResolution'),
|
||||||
|
onnx_config=dict(
|
||||||
|
type='onnx',
|
||||||
|
export_params=True,
|
||||||
|
keep_initializers_as_inputs=False,
|
||||||
|
opset_version=11,
|
||||||
|
input_shape=None,
|
||||||
|
input_names=['input'],
|
||||||
|
output_names=['output'])))
|
||||||
|
input_img = torch.rand(1, 3, 64, 64)
|
||||||
|
input = {'lq': input_img}
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_pytorch_model():
|
||||||
|
model = api_utils.init_pytorch_model(
|
||||||
|
Codebase.MMEDIT, model_cfg=model_cfg, device='cpu')
|
||||||
|
assert model is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
|
||||||
|
def create_backend_model():
|
||||||
|
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
|
||||||
|
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
|
||||||
|
|
||||||
|
# simplify backend inference
|
||||||
|
|
||||||
|
wrapper = SwitchBackendWrapper(ORTWrapper)
|
||||||
|
wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
|
||||||
|
model = api_utils.init_backend_model([''], model_cfg, deploy_cfg)
|
||||||
|
|
||||||
|
return model, wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
|
||||||
|
def test_init_backend_model():
|
||||||
|
model, wrapper = create_backend_model()
|
||||||
|
assert model is not None
|
||||||
|
|
||||||
|
# Recovery
|
||||||
|
wrapper.recover()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
|
||||||
|
def test_run_inference():
|
||||||
|
model, wrapper = create_backend_model()
|
||||||
|
result = api_utils.run_inference(Codebase.MMEDIT, input, model)
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
|
||||||
|
# Recovery
|
||||||
|
wrapper.recover()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
|
||||||
|
def test_visualize():
|
||||||
|
model, wrapper = create_backend_model()
|
||||||
|
result = api_utils.run_inference(Codebase.MMEDIT, input, model)
|
||||||
|
with tempfile.TemporaryDirectory() as dir:
|
||||||
|
filename = dir + 'tmp.jpg'
|
||||||
|
api_utils.visualize(Codebase.MMEDIT, input, result, model, filename,
|
||||||
|
Backend.ONNXRUNTIME)
|
||||||
|
assert os.path.exists(filename)
|
||||||
|
|
||||||
|
# Recovery
|
||||||
|
wrapper.recover()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
|
||||||
|
def test_inference_model():
|
||||||
|
numpy_img = np.random.rand(64, 64, 3)
|
||||||
|
with tempfile.TemporaryDirectory() as dir:
|
||||||
|
filename = dir + 'tmp.jpg'
|
||||||
|
model, wrapper = create_backend_model()
|
||||||
|
from mmdeploy.apis.inference import inference_model
|
||||||
|
inference_model(model_cfg, deploy_cfg, model, numpy_img, 'cpu',
|
||||||
|
Backend.ONNXRUNTIME, filename, False)
|
||||||
|
assert os.path.exists(filename)
|
||||||
|
|
||||||
|
# Recovery
|
||||||
|
wrapper.recover()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||||
|
def test_test():
|
||||||
|
from mmcv.parallel import MMDataParallel
|
||||||
|
with tempfile.TemporaryDirectory() as dir:
|
||||||
|
|
||||||
|
# Export a complete model
|
||||||
|
numpy_img = np.random.rand(50, 50, 3)
|
||||||
|
onnx_filename = 'end2end.onnx'
|
||||||
|
onnx_path = os.path.join(dir, onnx_filename)
|
||||||
|
from mmdeploy.apis import torch2onnx
|
||||||
|
torch2onnx(numpy_img, dir, onnx_filename, deploy_cfg, model_cfg)
|
||||||
|
assert os.path.exists(onnx_path)
|
||||||
|
|
||||||
|
# Prepare dataloader
|
||||||
|
dataset = api_utils.build_dataset(
|
||||||
|
Codebase.MMEDIT, model_cfg, dataset_type='test')
|
||||||
|
assert dataset is not None, 'Failed to build dataset'
|
||||||
|
dataloader = api_utils.build_dataloader(Codebase.MMEDIT, dataset, 1, 1)
|
||||||
|
assert dataloader is not None, 'Failed to build dataloader'
|
||||||
|
|
||||||
|
# Prepare model
|
||||||
|
model = api_utils.init_backend_model([onnx_path], model_cfg,
|
||||||
|
deploy_cfg)
|
||||||
|
model = MMDataParallel(model, device_ids=[0])
|
||||||
|
assert model is not None
|
||||||
|
|
||||||
|
# Run test
|
||||||
|
outputs = api_test.single_gpu_test(Codebase.MMEDIT, model, dataloader)
|
||||||
|
assert outputs is not None
|
||||||
|
api_test.post_process_outputs(outputs, dataset, model_cfg,
|
||||||
|
Codebase.MMEDIT)
|
97
tests/test_mmedit/test_mmedit_export.py
Normal file
97
tests/test_mmedit/test_mmedit_export.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mmdeploy.apis.utils import build_dataloader, build_dataset, create_input
|
||||||
|
from mmdeploy.utils.constants import Codebase, Task
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateInput:
|
||||||
|
task = Task.SUPER_RESOLUTION
|
||||||
|
img_test_pipeline = [
|
||||||
|
dict(
|
||||||
|
type='LoadImageFromFile',
|
||||||
|
io_backend='disk',
|
||||||
|
key='lq',
|
||||||
|
flag='unchanged'),
|
||||||
|
dict(
|
||||||
|
type='LoadImageFromFile',
|
||||||
|
io_backend='disk',
|
||||||
|
key='gt',
|
||||||
|
flag='unchanged'),
|
||||||
|
dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
|
||||||
|
dict(
|
||||||
|
type='Normalize',
|
||||||
|
keys=['lq', 'gt'],
|
||||||
|
mean=[0, 0, 0],
|
||||||
|
std=[1, 1, 1],
|
||||||
|
to_rgb=True),
|
||||||
|
dict(
|
||||||
|
type='Collect',
|
||||||
|
keys=['lq', 'gt'],
|
||||||
|
meta_keys=['lq_path', 'lq_path']),
|
||||||
|
dict(type='ImageToTensor', keys=['lq', 'gt'])
|
||||||
|
]
|
||||||
|
|
||||||
|
imgs = np.random.rand(32, 32, 3)
|
||||||
|
img_path = 'tests/test_mmedit/data/imgs/blank.jpg'
|
||||||
|
|
||||||
|
def test_create_input_static(this):
|
||||||
|
data = dict(test=dict(pipeline=TestCreateInput.img_test_pipeline))
|
||||||
|
model_cfg = mmcv.Config(
|
||||||
|
dict(data=data, test_pipeline=TestCreateInput.img_test_pipeline))
|
||||||
|
inputs = create_input(
|
||||||
|
Codebase.MMEDIT,
|
||||||
|
TestCreateInput.task,
|
||||||
|
model_cfg,
|
||||||
|
TestCreateInput.imgs,
|
||||||
|
input_shape=(32, 32),
|
||||||
|
device='cpu')
|
||||||
|
assert inputs is not None, 'Failed to create input'
|
||||||
|
|
||||||
|
def test_create_input_dynamic(this):
|
||||||
|
data = dict(test=dict(pipeline=TestCreateInput.img_test_pipeline))
|
||||||
|
model_cfg = mmcv.Config(
|
||||||
|
dict(data=data, test_pipeline=TestCreateInput.img_test_pipeline))
|
||||||
|
inputs = create_input(
|
||||||
|
Codebase.MMEDIT,
|
||||||
|
TestCreateInput.task,
|
||||||
|
model_cfg,
|
||||||
|
TestCreateInput.imgs,
|
||||||
|
input_shape=None,
|
||||||
|
device='cpu')
|
||||||
|
assert inputs is not None, 'Failed to create input'
|
||||||
|
|
||||||
|
def test_create_input_from_file(this):
|
||||||
|
data = dict(test=dict(pipeline=TestCreateInput.img_test_pipeline))
|
||||||
|
model_cfg = mmcv.Config(
|
||||||
|
dict(data=data, test_pipeline=TestCreateInput.img_test_pipeline))
|
||||||
|
inputs = create_input(
|
||||||
|
Codebase.MMEDIT,
|
||||||
|
TestCreateInput.task,
|
||||||
|
model_cfg,
|
||||||
|
TestCreateInput.img_path,
|
||||||
|
input_shape=None,
|
||||||
|
device='cpu')
|
||||||
|
assert inputs is not None, 'Failed to create input'
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_dataset():
|
||||||
|
data = dict(
|
||||||
|
test={
|
||||||
|
'type': 'SRFolderDataset',
|
||||||
|
'lq_folder': 'tests/test_mmedit/data/imgs',
|
||||||
|
'gt_folder': 'tests/test_mmedit/data/imgs',
|
||||||
|
'scale': 1,
|
||||||
|
'filename_tmpl': '{}',
|
||||||
|
'pipeline': [
|
||||||
|
{
|
||||||
|
'type': 'LoadImageFromFile'
|
||||||
|
},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
dataset_cfg = mmcv.Config(dict(data=data))
|
||||||
|
dataset = build_dataset(
|
||||||
|
Codebase.MMEDIT, dataset_cfg=dataset_cfg, dataset_type='test')
|
||||||
|
assert dataset is not None, 'Failed to build dataset'
|
||||||
|
dataloader = build_dataloader(Codebase.MMEDIT, dataset, 1, 1)
|
||||||
|
assert dataloader is not None, 'Failed to build dataloader'
|
72
tests/test_mmedit/test_mmedit_models.py
Normal file
72
tests/test_mmedit/test_mmedit_models.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
import os.path as osp
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
from mmedit.models.backbones.sr_backbones import SRCNN
|
||||||
|
|
||||||
|
from mmdeploy.core import RewriterContext
|
||||||
|
from mmdeploy.utils import Backend, get_onnx_config
|
||||||
|
|
||||||
|
img = torch.rand(1, 3, 4, 4)
|
||||||
|
model_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
|
||||||
|
|
||||||
|
deploy_cfg = mmcv.Config(
|
||||||
|
dict(
|
||||||
|
codebase_config=dict(
|
||||||
|
type='mmedit',
|
||||||
|
task='SuperResolution',
|
||||||
|
),
|
||||||
|
backend_config=dict(
|
||||||
|
type='tensorrt',
|
||||||
|
common_config=dict(fp16_mode=False, max_workspace_size=1 << 10),
|
||||||
|
model_inputs=[
|
||||||
|
dict(
|
||||||
|
input_shapes=dict(
|
||||||
|
input=dict(
|
||||||
|
min_shape=[1, 3, 4, 4],
|
||||||
|
opt_shape=[1, 3, 4, 4],
|
||||||
|
max_shape=[1, 3, 4, 4])))
|
||||||
|
]),
|
||||||
|
onnx_config=dict(
|
||||||
|
type='onnx',
|
||||||
|
export_params=True,
|
||||||
|
keep_initializers_as_inputs=False,
|
||||||
|
opset_version=11,
|
||||||
|
save_file=model_file,
|
||||||
|
input_shape=None,
|
||||||
|
input_names=['input'],
|
||||||
|
output_names=['output'])))
|
||||||
|
|
||||||
|
|
||||||
|
def test_srcnn():
|
||||||
|
pytorch_model = SRCNN()
|
||||||
|
model_inputs = {'x': img}
|
||||||
|
|
||||||
|
onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name
|
||||||
|
pytorch2onnx_cfg = get_onnx_config(deploy_cfg)
|
||||||
|
input_names = [k for k, v in model_inputs.items() if k != 'ctx']
|
||||||
|
with RewriterContext(
|
||||||
|
cfg=deploy_cfg, backend=Backend.TENSORRT.value), torch.no_grad():
|
||||||
|
torch.onnx.export(
|
||||||
|
pytorch_model,
|
||||||
|
tuple([v for k, v in model_inputs.items()]),
|
||||||
|
onnx_file_path,
|
||||||
|
export_params=True,
|
||||||
|
input_names=input_names,
|
||||||
|
output_names=None,
|
||||||
|
opset_version=11,
|
||||||
|
dynamic_axes=pytorch2onnx_cfg.get('dynamic_axes', None),
|
||||||
|
keep_initializers_as_inputs=False)
|
||||||
|
|
||||||
|
# The result should be different due to the rewrite.
|
||||||
|
# So we only check if the file exists
|
||||||
|
assert osp.exists(onnx_file_path)
|
||||||
|
|
||||||
|
model = onnx.load(onnx_file_path)
|
||||||
|
assert model is not None
|
||||||
|
try:
|
||||||
|
onnx.checker.check_model(model)
|
||||||
|
except onnx.checker.ValidationError:
|
||||||
|
assert False
|
@ -121,7 +121,7 @@ class TestTensorRTExporter:
|
|||||||
backend_config=dict(
|
backend_config=dict(
|
||||||
type='tensorrt',
|
type='tensorrt',
|
||||||
common_config=dict(
|
common_config=dict(
|
||||||
fp16_mode=False, max_workspace_size=1 << 30),
|
fp16_mode=False, max_workspace_size=1 << 28),
|
||||||
model_inputs=[
|
model_inputs=[
|
||||||
dict(
|
dict(
|
||||||
input_shapes=dict(
|
input_shapes=dict(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user