diff --git a/mmdeploy/apis/utils.py b/mmdeploy/apis/utils.py index 540c0b144..09d4618d4 100644 --- a/mmdeploy/apis/utils.py +++ b/mmdeploy/apis/utils.py @@ -57,7 +57,7 @@ def init_pytorch_model(codebase: Codebase, def create_input(codebase: Codebase, task: Task, model_cfg: Union[str, mmcv.Config], - imgs: Any, + imgs: Union[str, np.ndarray], input_shape: Sequence[int] = None, device: str = 'cuda:0', **kwargs): @@ -68,8 +68,8 @@ def create_input(codebase: Codebase, task (Task): Specifying task type. model_cfg (str | mmcv.Config): Model config file or loaded Config object. - imgs (Any): Input image(s), accpeted data type are `str`, - `np.ndarray`, `torch.Tensor`. + imgs (str | np.ndarray): Input image(s), accpeted data types are `str`, + `np.ndarray`. input_shape (list[int]): Input shape of image in (width, height) format, defaults to `None`. device (str): A string specifying device type, defaults to 'cuda:0'. diff --git a/mmdeploy/mmedit/export/prepare_input.py b/mmdeploy/mmedit/export/prepare_input.py index 99dfff9dc..c7293eca6 100644 --- a/mmdeploy/mmedit/export/prepare_input.py +++ b/mmdeploy/mmedit/export/prepare_input.py @@ -11,32 +11,65 @@ from torch.utils.data.dataset import Dataset from mmdeploy.utils import Task, load_config -def _preprocess_cfg(config: Union[str, mmcv.Config]): +def _preprocess_cfg(config: Union[str, mmcv.Config], task: Task, + load_from_file: bool, is_static_cfg: bool, + input_shape: Sequence[int]): """Remove unnecessary information in config. Args: model_cfg (str | mmcv.Config): The input model config. + task (Task): Specifying editing task type. + load_from_file (bool): Whether the input is a filename of a numpy + matrix. If this variable is True, extra preprocessing is required. + is_static_cfg (bool): Whether the config specifys a static export. + If this variable if True, the input image will be resize to a fix + resolution. + input_shape (Sequence[int]): A list of two integer in (width, height) + format specifying input shape. Defaults to `None`. """ # TODO: Differentiate the editing tasks (e.g. restorers and mattors # preprocess the data in differenet ways) - keys_to_remove = ['gt', 'gt_path'] + if task == Task.SUPER_RESOLUTION: + keys_to_remove = ['gt', 'gt_path'] + else: + raise NotImplementedError(f'Unknown task type: {task.value}') + + # MMEdit doesn't support LoadImageFromWebcam. + # Remove "LoadImageFromFile" and related metakeys. + if not load_from_file: + config.test_pipeline.pop(0) + if task == Task.SUPER_RESOLUTION: + keys_to_remove.append('lq_path') + + # Fix the input shape by 'Resize' + if is_static_cfg: + if task == Task.SUPER_RESOLUTION: + resize = { + 'type': 'Resize', + 'scale': (input_shape[0], input_shape[1]), + 'keys': ['lq'] + } + config.test_pipeline.insert(1, resize) + for key in keys_to_remove: for pipeline in list(config.test_pipeline): if 'key' in pipeline and key == pipeline['key']: config.test_pipeline.remove(pipeline) - if 'keys' in pipeline and key in pipeline['keys']: - pipeline['keys'].remove(key) + if 'keys' in pipeline: + while key in pipeline['keys']: + pipeline['keys'].remove(key) if len(pipeline['keys']) == 0: config.test_pipeline.remove(pipeline) - if 'meta_keys' in pipeline and key in pipeline['meta_keys']: - pipeline['meta_keys'].remove(key) + if 'meta_keys' in pipeline: + while key in pipeline['meta_keys']: + pipeline['meta_keys'].remove(key) def create_input(task: Task, model_cfg: Union[str, mmcv.Config], - imgs: Union[str, mmcv.Config], + imgs: Union[str, np.ndarray], input_shape: Optional[Sequence[int]] = None, device: Optional[str] = 'cuda:0'): """Create input for editing processor. @@ -61,38 +94,30 @@ def create_input(task: Task, raise AssertionError('imgs must be strings or numpy arrays') cfg = load_config(model_cfg)[0].copy() - _preprocess_cfg(cfg) - if isinstance(imgs[0], np.ndarray): - cfg = cfg.copy() - # set loading pipeline type - cfg.test_pipeline[0].type = 'LoadImageFromWebcam' - - # for static exporting - if input_shape is not None: - if task == Task.SUPER_RESOLUTION: - resize = { - 'type': 'Resize', - 'scale': (input_shape[0], input_shape[1]), - 'keys': ['lq'] - } - cfg.test_pipeline.insert(1, resize) - else: - raise NotImplementedError(f'Unknown task type: {task.value}') + _preprocess_cfg( + cfg, + task=task, + load_from_file=isinstance(imgs[0], str), + is_static_cfg=input_shape is not None, + input_shape=input_shape) test_pipeline = Compose(cfg.test_pipeline) data_arr = [] for img in imgs: - # TODO: This is only for restore. Add condiction statement - data = dict(lq_path=img) + # TODO: This is only for restore. Add condiction statement. + if isinstance(img, np.ndarray): + data = dict(lq=img) + else: + data = dict(lq_path=img) data = test_pipeline(data) data_arr.append(data) data = collate(data_arr, samples_per_gpu=len(imgs)) - # TODO: This is only for restore. Add condiction statement + # TODO: This is only for restore. Add condiction statement. data['img'] = data['lq'] if device != 'cpu': diff --git a/mmdeploy/utils/test.py b/mmdeploy/utils/test.py index 049b1e3e6..ce51c6422 100644 --- a/mmdeploy/utils/test.py +++ b/mmdeploy/utils/test.py @@ -59,13 +59,14 @@ class WrapModel(nn.Module): class SwitchBackendWrapper: """A switcher for backend wrapper for unit tests. - Examples: >>> from mmdeploy.utils.test import SwitchBackendWrapper >>> from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper - >>> SwitchBackendWrapper.set(ORTWrapper, outputs=outputs) + >>> with SwitchBackendWrapper(ORTWrapper) as wrapper: + >>> wrapper.set(ORTWrapper, outputs=outputs) + >>> ... + >>> # ORTWrapper will recover when exiting context >>> ... - >>> SwitchBackendWrapper.recover(ORTWrapper) """ init = None forward = None @@ -83,26 +84,35 @@ class SwitchBackendWrapper: def __call__(self, *args, **kwds): return self.forward(*args, **kwds) - @staticmethod - def set(obj, **kwargs): + def __init__(self, recover_class): + self._recover_class = recover_class + + def __enter__(self): + return self + + def __exit__(self, type, value, trace): + self.recover() + + def set(self, **kwargs): """Replace attributes in backend wrappers with dummy items.""" - SwitchBackendWrapper.init = obj.__init__ - SwitchBackendWrapper.forward = obj.forward - SwitchBackendWrapper.call = obj.__call__ + obj = self._recover_class + self.init = obj.__init__ + self.forward = obj.forward + self.call = obj.__call__ obj.__init__ = SwitchBackendWrapper.BackendWrapper.__init__ obj.forward = SwitchBackendWrapper.BackendWrapper.forward obj.__call__ = SwitchBackendWrapper.BackendWrapper.__call__ for k, v in kwargs.items(): setattr(obj, k, v) - @staticmethod - def recover(obj): - assert SwitchBackendWrapper.init is not None and \ - SwitchBackendWrapper.forward is not None,\ + def recover(self): + assert self.init is not None and \ + self.forward is not None,\ 'recover method must be called after exchange' - obj.__init__ = SwitchBackendWrapper.init - obj.forward = SwitchBackendWrapper.forward - obj.__call__ = SwitchBackendWrapper.call + obj = self._recover_class + obj.__init__ = self.init + obj.forward = self.forward + obj.__call__ = self.call def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]], diff --git a/tests/test_apis/test_calibration.py b/tests/test_apis/test_calibration.py index 26b1594a9..17f6f3689 100644 --- a/tests/test_apis/test_calibration.py +++ b/tests/test_apis/test_calibration.py @@ -1,5 +1,6 @@ import os.path as osp import tempfile +from multiprocessing import Process import h5py import mmcv @@ -7,7 +8,6 @@ import mmcv from mmdeploy.apis import create_calib_table calib_file = tempfile.NamedTemporaryFile(suffix='.h5').name -data_prefix = 'tests/data/tiger' ann_file = 'tests/data/annotation.json' @@ -71,7 +71,7 @@ def get_model_cfg(): dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', - img_scale=(1333, 800), + img_scale=(1, 1), flip=False, transforms=[ dict(type='Resize', keep_ratio=True), @@ -169,7 +169,7 @@ def get_model_cfg(): return model_cfg -def test_create_calib_end2end(): +def run_test_create_calib_end2end(): model_cfg = get_model_cfg() deploy_cfg = get_end2end_deploy_cfg() create_calib_table( @@ -189,7 +189,19 @@ def test_create_calib_end2end(): assert calibrator['calib_data']['end2end']['input']['0'] is not None -def test_create_calib_parittion(): +# Because Faster-RCNN needs too much memory on GPU, we need to run tests in a +# new process. + + +def test_create_calib_end2end(): + p = Process(target=run_test_create_calib_end2end) + try: + p.start() + finally: + p.join() + + +def run_test_create_calib_parittion(): model_cfg = get_model_cfg() deploy_cfg = get_partition_deploy_cfg() create_calib_table( @@ -211,3 +223,11 @@ def test_create_calib_parittion(): assert calib_data[partition_name] is not None assert calib_data[partition_name][input_names[i]] is not None assert calib_data[partition_name][input_names[i]]['0'] is not None + + +def test_create_calib_parittion(): + p = Process(target=run_test_create_calib_parittion) + try: + p.start() + finally: + p.join() diff --git a/tests/test_mmdet/test_mmdet_apis.py b/tests/test_mmdet/test_mmdet_apis.py index 03079291c..cc4770df5 100644 --- a/tests/test_mmdet/test_mmdet_apis.py +++ b/tests/test_mmdet/test_mmdet_apis.py @@ -26,20 +26,21 @@ def test_TensorRTDetector(): 'dets': torch.rand(1, 100, 5).cuda(), 'labels': torch.rand(1, 100).cuda() } - SwitchBackendWrapper.set(TRTWrapper, outputs=outputs) + with SwitchBackendWrapper(TRTWrapper) as wrapper: + wrapper.set(outputs=outputs) - from mmdeploy.mmdet.apis.inference import TensorRTDetector - trt_detector = TensorRTDetector('', ['' for i in range(80)], 0) - imgs = [torch.rand(1, 3, 64, 64).cuda()] - img_metas = [[{ - 'ori_shape': [64, 64, 3], - 'img_shape': [64, 64, 3], - 'scale_factor': [2.09, 1.87, 2.09, 1.87], - }]] + from mmdeploy.mmdet.apis.inference import TensorRTDetector + trt_detector = TensorRTDetector('', ['' for i in range(80)], 0) + imgs = [torch.rand(1, 3, 64, 64).cuda()] + img_metas = [[{ + 'ori_shape': [64, 64, 3], + 'img_shape': [64, 64, 3], + 'scale_factor': [2.09, 1.87, 2.09, 1.87], + }]] - results = trt_detector.forward(imgs, img_metas) - assert results is not None, 'failed to get output using TensorRTDetector' - SwitchBackendWrapper.recover(TRTWrapper) + results = trt_detector.forward(imgs, img_metas) + assert results is not None, ('failed to get output using ' + 'TensorRTDetector') @pytest.mark.skipif( @@ -52,21 +53,21 @@ def test_ONNXRuntimeDetector(): # simplify backend inference outputs = (torch.rand(1, 100, 5), torch.rand(1, 100)) - SwitchBackendWrapper.set(ORTWrapper, outputs=outputs) + with SwitchBackendWrapper(ORTWrapper) as wrapper: + wrapper.set(outputs=outputs) - from mmdeploy.mmdet.apis.inference import ONNXRuntimeDetector - ort_detector = ONNXRuntimeDetector('', ['' for i in range(80)], 0) - imgs = [torch.rand(1, 3, 64, 64)] - img_metas = [[{ - 'ori_shape': [64, 64, 3], - 'img_shape': [64, 64, 3], - 'scale_factor': [2.09, 1.87, 2.09, 1.87], - }]] + from mmdeploy.mmdet.apis.inference import ONNXRuntimeDetector + ort_detector = ONNXRuntimeDetector('', ['' for i in range(80)], 0) + imgs = [torch.rand(1, 3, 64, 64)] + img_metas = [[{ + 'ori_shape': [64, 64, 3], + 'img_shape': [64, 64, 3], + 'scale_factor': [2.09, 1.87, 2.09, 1.87], + }]] - results = ort_detector.forward(imgs, img_metas) - assert results is not None, 'failed to get output using '\ - 'ONNXRuntimeDetector' - SwitchBackendWrapper.recover(ORTWrapper) + results = ort_detector.forward(imgs, img_metas) + assert results is not None, 'failed to get output using '\ + 'ONNXRuntimeDetector' @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') @@ -80,20 +81,20 @@ def test_PPLDetector(): # simplify backend inference outputs = (torch.rand(1, 100, 5), torch.rand(1, 100)) - SwitchBackendWrapper.set(PPLWrapper, outputs=outputs) + with SwitchBackendWrapper(PPLWrapper) as wrapper: + wrapper.set(outputs=outputs) - from mmdeploy.mmdet.apis.inference import PPLDetector - ppl_detector = PPLDetector('', ['' for i in range(80)], 0) - imgs = [torch.rand(1, 3, 64, 64)] - img_metas = [[{ - 'ori_shape': [64, 64, 3], - 'img_shape': [64, 64, 3], - 'scale_factor': [2.09, 1.87, 2.09, 1.87], - }]] + from mmdeploy.mmdet.apis.inference import PPLDetector + ppl_detector = PPLDetector('', ['' for i in range(80)], 0) + imgs = [torch.rand(1, 3, 64, 64)] + img_metas = [[{ + 'ori_shape': [64, 64, 3], + 'img_shape': [64, 64, 3], + 'scale_factor': [2.09, 1.87, 2.09, 1.87], + }]] - results = ppl_detector.forward(imgs, img_metas) - assert results is not None, 'failed to get output using PPLDetector' - SwitchBackendWrapper.recover(PPLWrapper) + results = ppl_detector.forward(imgs, img_metas) + assert results is not None, 'failed to get output using PPLDetector' def get_test_cfg_and_post_processing(): @@ -155,28 +156,26 @@ def test_NCNNPSSDetector(): 'scores': torch.rand(1, 120, 80), 'boxes': torch.rand(1, 120, 4) } - SwitchBackendWrapper.set( - NCNNWrapper, - outputs=outputs, - model_cfg=model_cfg, - deploy_cfg=deploy_cfg) + with SwitchBackendWrapper(NCNNWrapper) as wrapper: + wrapper.set( + outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg) - from mmdeploy.mmdet.apis.inference import NCNNPSSDetector + from mmdeploy.mmdet.apis.inference import NCNNPSSDetector - ncnn_pss_detector = NCNNPSSDetector(['', ''], ['' for i in range(80)], - model_cfg=model_cfg, - deploy_cfg=deploy_cfg, - device_id=0) - imgs = [torch.rand(1, 3, 32, 32)] - img_metas = [[{ - 'ori_shape': [32, 32, 3], - 'img_shape': [32, 32, 3], - 'scale_factor': [2.09, 1.87, 2.09, 1.87], - }]] + ncnn_pss_detector = NCNNPSSDetector(['', ''], ['' for i in range(80)], + model_cfg=model_cfg, + deploy_cfg=deploy_cfg, + device_id=0) + imgs = [torch.rand(1, 3, 32, 32)] + img_metas = [[{ + 'ori_shape': [32, 32, 3], + 'img_shape': [32, 32, 3], + 'scale_factor': [2.09, 1.87, 2.09, 1.87], + }]] - results = ncnn_pss_detector.forward(imgs, img_metas) - assert results is not None, 'failed to get output using NCNNPSSDetector' - SwitchBackendWrapper.recover(NCNNWrapper) + results = ncnn_pss_detector.forward(imgs, img_metas) + assert results is not None, ('failed to get output using ' + 'NCNNPSSDetector') @pytest.mark.skipif( @@ -197,30 +196,27 @@ def test_ONNXRuntimePSSDetector(): np.random.rand(1, 120, 80).astype(np.float32), np.random.rand(1, 120, 4).astype(np.float32) ] - SwitchBackendWrapper.set( - ORTWrapper, - outputs=outputs, - model_cfg=model_cfg, - deploy_cfg=deploy_cfg) + with SwitchBackendWrapper(ORTWrapper) as wrapper: + wrapper.set( + outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg) - from mmdeploy.mmdet.apis.inference import ONNXRuntimePSSDetector + from mmdeploy.mmdet.apis.inference import ONNXRuntimePSSDetector - ort_pss_detector = ONNXRuntimePSSDetector( - '', ['' for i in range(80)], - model_cfg=model_cfg, - deploy_cfg=deploy_cfg, - device_id=0) - imgs = [torch.rand(1, 3, 32, 32)] - img_metas = [[{ - 'ori_shape': [32, 32, 3], - 'img_shape': [32, 32, 3], - 'scale_factor': [2.09, 1.87, 2.09, 1.87], - }]] + ort_pss_detector = ONNXRuntimePSSDetector( + '', ['' for i in range(80)], + model_cfg=model_cfg, + deploy_cfg=deploy_cfg, + device_id=0) + imgs = [torch.rand(1, 3, 32, 32)] + img_metas = [[{ + 'ori_shape': [32, 32, 3], + 'img_shape': [32, 32, 3], + 'scale_factor': [2.09, 1.87, 2.09, 1.87], + }]] - results = ort_pss_detector.forward(imgs, img_metas) - assert results is not None, 'failed to get output using ' - 'ONNXRuntimePSSDetector' - SwitchBackendWrapper.recover(ORTWrapper) + results = ort_pss_detector.forward(imgs, img_metas) + assert results is not None, 'failed to get output using ' + 'ONNXRuntimePSSDetector' @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') @@ -242,30 +238,27 @@ def test_TensorRTPSSDetector(): 'scores': torch.rand(1, 120, 80).cuda(), 'boxes': torch.rand(1, 120, 4).cuda() } - SwitchBackendWrapper.set( - TRTWrapper, - outputs=outputs, - model_cfg=model_cfg, - deploy_cfg=deploy_cfg) + with SwitchBackendWrapper(TRTWrapper) as wrapper: + wrapper.set( + outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg) - from mmdeploy.mmdet.apis.inference import TensorRTPSSDetector + from mmdeploy.mmdet.apis.inference import TensorRTPSSDetector - trt_pss_detector = TensorRTPSSDetector( - '', ['' for i in range(80)], - model_cfg=model_cfg, - deploy_cfg=deploy_cfg, - device_id=0) - imgs = [torch.rand(1, 3, 32, 32).cuda()] - img_metas = [[{ - 'ori_shape': [32, 32, 3], - 'img_shape': [32, 32, 3], - 'scale_factor': [2.09, 1.87, 2.09, 1.87], - }]] + trt_pss_detector = TensorRTPSSDetector( + '', ['' for i in range(80)], + model_cfg=model_cfg, + deploy_cfg=deploy_cfg, + device_id=0) + imgs = [torch.rand(1, 3, 32, 32).cuda()] + img_metas = [[{ + 'ori_shape': [32, 32, 3], + 'img_shape': [32, 32, 3], + 'scale_factor': [2.09, 1.87, 2.09, 1.87], + }]] - results = trt_pss_detector.forward(imgs, img_metas) - assert results is not None, 'failed to get output using ' - 'TensorRTPSSDetector' - SwitchBackendWrapper.recover(TRTWrapper) + results = trt_pss_detector.forward(imgs, img_metas) + assert results is not None, 'failed to get output using ' + 'TensorRTPSSDetector' def prepare_model_deploy_cfgs(): @@ -377,41 +370,41 @@ def test_TensorRTPTSDetector(): 'cls_score': torch.rand(1, 12, 80).cuda(), 'bbox_pred': torch.rand(1, 12, 4).cuda() } - SwitchBackendWrapper.set(TRTWrapper, outputs=outputs) - TRTWrapper.model_cfg = model_cfg - TRTWrapper.deploy_cfg = deploy_cfg + with SwitchBackendWrapper(TRTWrapper) as wrapper: + wrapper.set( + outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg) - # replace original function in PartitionTwoStageDetector - from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector - PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__ - PartitionTwoStageDetector.partition0_postprocess = \ - DummyPTSDetector.partition0_postprocess - PartitionTwoStageDetector.partition1_postprocess = \ - DummyPTSDetector.partition1_postprocess - PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3).cuda()] * 2 - PartitionTwoStageDetector.outputs1 = [ - torch.rand(1, 9, 5).cuda(), - torch.rand(1, 9).cuda() - ] - PartitionTwoStageDetector.device_id = 0 - PartitionTwoStageDetector.CLASSES = ['' for i in range(80)] + # replace original function in PartitionTwoStageDetector + from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector + PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__ + PartitionTwoStageDetector.partition0_postprocess = \ + DummyPTSDetector.partition0_postprocess + PartitionTwoStageDetector.partition1_postprocess = \ + DummyPTSDetector.partition1_postprocess + PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3).cuda()] * 2 + PartitionTwoStageDetector.outputs1 = [ + torch.rand(1, 9, 5).cuda(), + torch.rand(1, 9).cuda() + ] + PartitionTwoStageDetector.device_id = 0 + PartitionTwoStageDetector.CLASSES = ['' for i in range(80)] - from mmdeploy.mmdet.apis.inference import TensorRTPTSDetector - trt_pts_detector = TensorRTPTSDetector(['', ''], ['' for i in range(80)], - model_cfg=model_cfg, - deploy_cfg=deploy_cfg, - device_id=0) + from mmdeploy.mmdet.apis.inference import TensorRTPTSDetector + trt_pts_detector = TensorRTPTSDetector(['', ''], + ['' for i in range(80)], + model_cfg=model_cfg, + deploy_cfg=deploy_cfg, + device_id=0) - imgs = [torch.rand(1, 3, 32, 32).cuda()] - img_metas = [[{ - 'ori_shape': [32, 32, 3], - 'img_shape': [32, 32, 3], - 'scale_factor': [2.09, 1.87, 2.09, 1.87], - }]] - results = trt_pts_detector.forward(imgs, img_metas) - assert results is not None, 'failed to get output using ' - 'TensorRTPTSDetector' - SwitchBackendWrapper.recover(TRTWrapper) + imgs = [torch.rand(1, 3, 32, 32).cuda()] + img_metas = [[{ + 'ori_shape': [32, 32, 3], + 'img_shape': [32, 32, 3], + 'scale_factor': [2.09, 1.87, 2.09, 1.87], + }]] + results = trt_pts_detector.forward(imgs, img_metas) + assert results is not None, 'failed to get output using ' + 'TensorRTPTSDetector' @pytest.mark.skipif( @@ -429,43 +422,40 @@ def test_ONNXRuntimePTSDetector(): np.random.rand(1, 12, 80).astype(np.float32), np.random.rand(1, 12, 4).astype(np.float32), ] * 2 - SwitchBackendWrapper.set( - ORTWrapper, - outputs=outputs, - model_cfg=model_cfg, - deploy_cfg=deploy_cfg) + with SwitchBackendWrapper(ORTWrapper) as wrapper: + wrapper.set( + outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg) - # replace original function in PartitionTwoStageDetector - from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector - PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__ - PartitionTwoStageDetector.partition0_postprocess = \ - DummyPTSDetector.partition0_postprocess - PartitionTwoStageDetector.partition1_postprocess = \ - DummyPTSDetector.partition1_postprocess - PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2 - PartitionTwoStageDetector.outputs1 = [ - torch.rand(1, 9, 5), torch.rand(1, 9) - ] - PartitionTwoStageDetector.device_id = -1 - PartitionTwoStageDetector.CLASSES = ['' for i in range(80)] + # replace original function in PartitionTwoStageDetector + from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector + PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__ + PartitionTwoStageDetector.partition0_postprocess = \ + DummyPTSDetector.partition0_postprocess + PartitionTwoStageDetector.partition1_postprocess = \ + DummyPTSDetector.partition1_postprocess + PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2 + PartitionTwoStageDetector.outputs1 = [ + torch.rand(1, 9, 5), torch.rand(1, 9) + ] + PartitionTwoStageDetector.device_id = -1 + PartitionTwoStageDetector.CLASSES = ['' for i in range(80)] - from mmdeploy.mmdet.apis.inference import ONNXRuntimePTSDetector - ort_pts_detector = ONNXRuntimePTSDetector(['', ''], - ['' for i in range(80)], - model_cfg=model_cfg, - deploy_cfg=deploy_cfg, - device_id=0) + from mmdeploy.mmdet.apis.inference import ONNXRuntimePTSDetector + ort_pts_detector = ONNXRuntimePTSDetector(['', ''], + ['' for i in range(80)], + model_cfg=model_cfg, + deploy_cfg=deploy_cfg, + device_id=0) - imgs = [torch.rand(1, 3, 32, 32)] - img_metas = [[{ - 'ori_shape': [32, 32, 3], - 'img_shape': [32, 32, 3], - 'scale_factor': [2.09, 1.87, 2.09, 1.87], - }]] - results = ort_pts_detector.forward(imgs, img_metas) - assert results is not None, 'failed to get output using ' - 'ONNXRuntimePTSDetector' - SwitchBackendWrapper.recover(ORTWrapper) + imgs = [torch.rand(1, 3, 32, 32)] + img_metas = [[{ + 'ori_shape': [32, 32, 3], + 'img_shape': [32, 32, 3], + 'scale_factor': [2.09, 1.87, 2.09, 1.87], + }]] + results = ort_pts_detector.forward(imgs, img_metas) + assert results is not None, 'failed to get output using ' + 'ONNXRuntimePTSDetector' @pytest.mark.skipif( @@ -487,43 +477,40 @@ def test_NCNNPTSDetector(): 'cls_score': torch.rand(1, 12, 80), 'bbox_pred': torch.rand(1, 12, 4) } - SwitchBackendWrapper.set( - NCNNWrapper, - outputs=outputs, - model_cfg=model_cfg, - deploy_cfg=deploy_cfg) + with SwitchBackendWrapper(NCNNWrapper) as wrapper: + wrapper.set( + outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg) - # replace original function in PartitionTwoStageDetector - from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector - PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__ - PartitionTwoStageDetector.partition0_postprocess = \ - DummyPTSDetector.partition0_postprocess - PartitionTwoStageDetector.partition1_postprocess = \ - DummyPTSDetector.partition1_postprocess - PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2 - PartitionTwoStageDetector.outputs1 = [ - torch.rand(1, 9, 5), torch.rand(1, 9) - ] - PartitionTwoStageDetector.device_id = -1 - PartitionTwoStageDetector.CLASSES = ['' for i in range(80)] + # replace original function in PartitionTwoStageDetector + from mmdeploy.mmdet.apis.inference import PartitionTwoStageDetector + PartitionTwoStageDetector.__init__ = DummyPTSDetector.__init__ + PartitionTwoStageDetector.partition0_postprocess = \ + DummyPTSDetector.partition0_postprocess + PartitionTwoStageDetector.partition1_postprocess = \ + DummyPTSDetector.partition1_postprocess + PartitionTwoStageDetector.outputs0 = [torch.rand(2, 3)] * 2 + PartitionTwoStageDetector.outputs1 = [ + torch.rand(1, 9, 5), torch.rand(1, 9) + ] + PartitionTwoStageDetector.device_id = -1 + PartitionTwoStageDetector.CLASSES = ['' for i in range(80)] - from mmdeploy.mmdet.apis.inference import NCNNPTSDetector - ncnn_pts_detector = NCNNPTSDetector( - [''] * 4, [''] * 80, - model_cfg=model_cfg, - deploy_cfg=deploy_cfg, - device_id=0) + from mmdeploy.mmdet.apis.inference import NCNNPTSDetector + ncnn_pts_detector = NCNNPTSDetector( + [''] * 4, [''] * 80, + model_cfg=model_cfg, + deploy_cfg=deploy_cfg, + device_id=0) - imgs = [torch.rand(1, 3, 32, 32)] - img_metas = [[{ - 'ori_shape': [32, 32, 3], - 'img_shape': [32, 32, 3], - 'scale_factor': [2.09, 1.87, 2.09, 1.87], - }]] - results = ncnn_pts_detector.forward(imgs, img_metas) - assert results is not None, 'failed to get output using ' - 'NCNNPTSDetector' - SwitchBackendWrapper.recover(NCNNWrapper) + imgs = [torch.rand(1, 3, 32, 32)] + img_metas = [[{ + 'ori_shape': [32, 32, 3], + 'img_shape': [32, 32, 3], + 'scale_factor': [2.09, 1.87, 2.09, 1.87], + }]] + results = ncnn_pts_detector.forward(imgs, img_metas) + assert results is not None, 'failed to get output using ' + 'NCNNPTSDetector' @pytest.mark.skipif( @@ -541,9 +528,8 @@ def test_build_detector(): ort_apis.__dict__.update({'ORTWrapper': ORTWrapper}) # simplify backend inference - SwitchBackendWrapper.set( - ORTWrapper, model_cfg=model_cfg, deploy_cfg=deploy_cfg) - from mmdeploy.apis.utils import init_backend_model - detector = init_backend_model([''], model_cfg, deploy_cfg, -1) - assert detector is not None - SwitchBackendWrapper.recover(ORTWrapper) + with SwitchBackendWrapper(ORTWrapper) as wrapper: + wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg) + from mmdeploy.apis.utils import init_backend_model + detector = init_backend_model([''], model_cfg, deploy_cfg, -1) + assert detector is not None diff --git a/tests/test_mmedit/data/imgs/blank.jpg b/tests/test_mmedit/data/imgs/blank.jpg new file mode 100644 index 000000000..ac446f47d Binary files /dev/null and b/tests/test_mmedit/data/imgs/blank.jpg differ diff --git a/tests/test_mmedit/data/model.py b/tests/test_mmedit/data/model.py new file mode 100644 index 000000000..289ece572 --- /dev/null +++ b/tests/test_mmedit/data/model.py @@ -0,0 +1,110 @@ +exp_name = 'srcnn_x4k915_g1_1000k_div2k' + +scale = 1 +# model settings +model = dict( + type='BasicRestorer', + generator=dict( + type='SRCNN', + channels=(3, 64, 32, 3), + kernel_sizes=(9, 1, 5), + upscale_factor=scale), + pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean')) +# model training and testing settings +train_cfg = None +test_cfg = dict(metrics=['PSNR', 'SSIM'], crop_border=scale) + +# dataset settings +train_dataset_type = 'SRAnnotationDataset' +val_dataset_type = 'SRFolderDataset' +train_pipeline = [ + dict( + type='LoadImageFromFile', + io_backend='disk', + key='lq', + flag='unchanged'), + dict( + type='LoadImageFromFile', + io_backend='disk', + key='gt', + flag='unchanged'), + dict(type='RescaleToZeroOne', keys=['lq', 'gt']), + dict( + type='Normalize', + keys=['lq', 'gt'], + mean=[0, 0, 0], + std=[1, 1, 1], + to_rgb=True), + dict(type='PairedRandomCrop', gt_patch_size=128), + dict( + type='Flip', keys=['lq', 'gt'], flip_ratio=0.5, + direction='horizontal'), + dict(type='Flip', keys=['lq', 'gt'], flip_ratio=0.5, direction='vertical'), + dict(type='RandomTransposeHW', keys=['lq', 'gt'], transpose_ratio=0.5), + dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path', 'gt_path']), + dict(type='ImageToTensor', keys=['lq', 'gt']) +] +test_pipeline = [ + dict( + type='LoadImageFromFile', + io_backend='disk', + key='lq', + flag='unchanged'), + dict( + type='LoadImageFromFile', + io_backend='disk', + key='gt', + flag='unchanged'), + dict(type='RescaleToZeroOne', keys=['lq', 'gt']), + dict( + type='Normalize', + keys=['lq', 'gt'], + mean=[0, 0, 0], + std=[1, 1, 1], + to_rgb=True), + dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path', 'lq_path']), + dict(type='ImageToTensor', keys=['lq', 'gt']) +] + +data = dict( + workers_per_gpu=8, + train_dataloader=dict(samples_per_gpu=16, drop_last=True), + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), + test=dict( + type=val_dataset_type, + lq_folder='tests/test_mmedit/data/imgs', + gt_folder='tests/test_mmedit/data/imgs', + pipeline=test_pipeline, + scale=scale, + filename_tmpl='{}')) + +# optimizer +optimizers = dict(generator=dict(type='Adam', lr=2e-4, betas=(0.9, 0.999))) + +# learning policy +total_iters = 1000000 +lr_config = dict( + policy='CosineRestart', + by_epoch=False, + periods=[250000, 250000, 250000, 250000], + restart_weights=[1, 1, 1, 1], + min_lr=1e-7) + +checkpoint_config = dict(interval=5000, save_optimizer=True, by_epoch=False) +evaluation = dict(interval=5000, save_image=True, gpu_collect=True) +log_config = dict( + interval=100, + hooks=[ + dict(type='TextLoggerHook', by_epoch=False), + dict(type='TensorboardLoggerHook'), + ]) +visual_config = None + +# runtime settings +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = f'./work_dirs/{exp_name}' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/tests/test_mmedit/test_mmedit_apis.py b/tests/test_mmedit/test_mmedit_apis.py new file mode 100644 index 000000000..592b60c4c --- /dev/null +++ b/tests/test_mmedit/test_mmedit_apis.py @@ -0,0 +1,219 @@ +import importlib +import os +import tempfile + +import mmcv +import numpy as np +import pytest +import torch + +import mmdeploy.apis.onnxruntime as ort_apis +import mmdeploy.apis.ppl as ppl_apis +import mmdeploy.apis.tensorrt as trt_apis +import mmdeploy.apis.test as api_test +import mmdeploy.apis.utils as api_utils +from mmdeploy.utils.constants import Backend, Codebase +from mmdeploy.utils.test import SwitchBackendWrapper + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') +@pytest.mark.skipif( + not importlib.util.find_spec('tensorrt'), reason='requires tensorrt') +def test_TensorRTRestorer(): + # force add backend wrapper regardless of plugins + from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper + trt_apis.__dict__.update({'TRTWrapper': TRTWrapper}) + + # simplify backend inference + outputs = { + 'output': torch.rand(1, 3, 64, 64).cuda(), + } + + with SwitchBackendWrapper(TRTWrapper) as wrapper: + wrapper.set(outputs=outputs) + + from mmdeploy.mmedit.apis.inference import TensorRTRestorer + trt_restorer = TensorRTRestorer('', 0) + imgs = torch.rand(1, 3, 64, 64).cuda() + + results = trt_restorer.forward(imgs) + assert results is not None, ('failed to get output using ' + 'TensorRTRestorer') + + results = trt_restorer.forward(imgs, test_mode=True) + assert results is not None, ('failed to get output using ' + 'TensorRTRestorer') + + +@pytest.mark.skipif( + not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime') +def test_ONNXRuntimeRestorer(): + # force add backend wrapper regardless of plugins + from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper + ort_apis.__dict__.update({'ORTWrapper': ORTWrapper}) + + # simplify backend inference + outputs = torch.rand(1, 3, 64, 64) + + with SwitchBackendWrapper(ORTWrapper) as wrapper: + wrapper.set(outputs=outputs) + + from mmdeploy.mmedit.apis.inference import ONNXRuntimeRestorer + ort_restorer = ONNXRuntimeRestorer('', 0) + imgs = torch.rand(1, 3, 64, 64) + + results = ort_restorer.forward(imgs) + assert results is not None, 'failed to get output using '\ + 'ONNXRuntimeRestorer' + + results = ort_restorer.forward(imgs, test_mode=True) + assert results is not None, 'failed to get output using '\ + 'ONNXRuntimeRestorer' + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') +@pytest.mark.skipif( + not importlib.util.find_spec('pyppl'), reason='requires pyppl') +def test_PPLRestorer(): + # force add backend wrapper regardless of plugins + from mmdeploy.apis.ppl.ppl_utils import PPLWrapper + ppl_apis.__dict__.update({'PPLWrapper': PPLWrapper}) + + # simplify backend inference + outputs = torch.rand(1, 3, 64, 64) + + with SwitchBackendWrapper(PPLWrapper) as wrapper: + wrapper.set(outputs=outputs) + + from mmdeploy.mmedit.apis.inference import PPLRestorer + ppl_restorer = PPLRestorer('', 0) + imgs = torch.rand(1, 3, 64, 64) + + results = ppl_restorer.forward(imgs) + assert results is not None, 'failed to get output using PPLRestorer' + + results = ppl_restorer.forward(imgs, test_mode=True) + assert results is not None, 'failed to get output using PPLRestorer' + + +model_cfg = 'tests/test_mmedit/data/model.py' +deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type='onnxruntime'), + codebase_config=dict(type='mmedit', task='SuperResolution'), + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + input_shape=None, + input_names=['input'], + output_names=['output']))) +input_img = torch.rand(1, 3, 64, 64) +input = {'lq': input_img} + + +def test_init_pytorch_model(): + model = api_utils.init_pytorch_model( + Codebase.MMEDIT, model_cfg=model_cfg, device='cpu') + assert model is not None + + +@pytest.mark.skipif( + not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime') +def create_backend_model(): + from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper + ort_apis.__dict__.update({'ORTWrapper': ORTWrapper}) + + # simplify backend inference + + wrapper = SwitchBackendWrapper(ORTWrapper) + wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg) + model = api_utils.init_backend_model([''], model_cfg, deploy_cfg) + + return model, wrapper + + +@pytest.mark.skipif( + not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime') +def test_init_backend_model(): + model, wrapper = create_backend_model() + assert model is not None + + # Recovery + wrapper.recover() + + +@pytest.mark.skipif( + not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime') +def test_run_inference(): + model, wrapper = create_backend_model() + result = api_utils.run_inference(Codebase.MMEDIT, input, model) + assert isinstance(result, np.ndarray) + + # Recovery + wrapper.recover() + + +@pytest.mark.skipif( + not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime') +def test_visualize(): + model, wrapper = create_backend_model() + result = api_utils.run_inference(Codebase.MMEDIT, input, model) + with tempfile.TemporaryDirectory() as dir: + filename = dir + 'tmp.jpg' + api_utils.visualize(Codebase.MMEDIT, input, result, model, filename, + Backend.ONNXRUNTIME) + assert os.path.exists(filename) + + # Recovery + wrapper.recover() + + +@pytest.mark.skipif( + not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime') +def test_inference_model(): + numpy_img = np.random.rand(64, 64, 3) + with tempfile.TemporaryDirectory() as dir: + filename = dir + 'tmp.jpg' + model, wrapper = create_backend_model() + from mmdeploy.apis.inference import inference_model + inference_model(model_cfg, deploy_cfg, model, numpy_img, 'cpu', + Backend.ONNXRUNTIME, filename, False) + assert os.path.exists(filename) + + # Recovery + wrapper.recover() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') +def test_test(): + from mmcv.parallel import MMDataParallel + with tempfile.TemporaryDirectory() as dir: + + # Export a complete model + numpy_img = np.random.rand(50, 50, 3) + onnx_filename = 'end2end.onnx' + onnx_path = os.path.join(dir, onnx_filename) + from mmdeploy.apis import torch2onnx + torch2onnx(numpy_img, dir, onnx_filename, deploy_cfg, model_cfg) + assert os.path.exists(onnx_path) + + # Prepare dataloader + dataset = api_utils.build_dataset( + Codebase.MMEDIT, model_cfg, dataset_type='test') + assert dataset is not None, 'Failed to build dataset' + dataloader = api_utils.build_dataloader(Codebase.MMEDIT, dataset, 1, 1) + assert dataloader is not None, 'Failed to build dataloader' + + # Prepare model + model = api_utils.init_backend_model([onnx_path], model_cfg, + deploy_cfg) + model = MMDataParallel(model, device_ids=[0]) + assert model is not None + + # Run test + outputs = api_test.single_gpu_test(Codebase.MMEDIT, model, dataloader) + assert outputs is not None + api_test.post_process_outputs(outputs, dataset, model_cfg, + Codebase.MMEDIT) diff --git a/tests/test_mmedit/test_mmedit_export.py b/tests/test_mmedit/test_mmedit_export.py new file mode 100644 index 000000000..3c59de9e8 --- /dev/null +++ b/tests/test_mmedit/test_mmedit_export.py @@ -0,0 +1,97 @@ +import mmcv +import numpy as np + +from mmdeploy.apis.utils import build_dataloader, build_dataset, create_input +from mmdeploy.utils.constants import Codebase, Task + + +class TestCreateInput: + task = Task.SUPER_RESOLUTION + img_test_pipeline = [ + dict( + type='LoadImageFromFile', + io_backend='disk', + key='lq', + flag='unchanged'), + dict( + type='LoadImageFromFile', + io_backend='disk', + key='gt', + flag='unchanged'), + dict(type='RescaleToZeroOne', keys=['lq', 'gt']), + dict( + type='Normalize', + keys=['lq', 'gt'], + mean=[0, 0, 0], + std=[1, 1, 1], + to_rgb=True), + dict( + type='Collect', + keys=['lq', 'gt'], + meta_keys=['lq_path', 'lq_path']), + dict(type='ImageToTensor', keys=['lq', 'gt']) + ] + + imgs = np.random.rand(32, 32, 3) + img_path = 'tests/test_mmedit/data/imgs/blank.jpg' + + def test_create_input_static(this): + data = dict(test=dict(pipeline=TestCreateInput.img_test_pipeline)) + model_cfg = mmcv.Config( + dict(data=data, test_pipeline=TestCreateInput.img_test_pipeline)) + inputs = create_input( + Codebase.MMEDIT, + TestCreateInput.task, + model_cfg, + TestCreateInput.imgs, + input_shape=(32, 32), + device='cpu') + assert inputs is not None, 'Failed to create input' + + def test_create_input_dynamic(this): + data = dict(test=dict(pipeline=TestCreateInput.img_test_pipeline)) + model_cfg = mmcv.Config( + dict(data=data, test_pipeline=TestCreateInput.img_test_pipeline)) + inputs = create_input( + Codebase.MMEDIT, + TestCreateInput.task, + model_cfg, + TestCreateInput.imgs, + input_shape=None, + device='cpu') + assert inputs is not None, 'Failed to create input' + + def test_create_input_from_file(this): + data = dict(test=dict(pipeline=TestCreateInput.img_test_pipeline)) + model_cfg = mmcv.Config( + dict(data=data, test_pipeline=TestCreateInput.img_test_pipeline)) + inputs = create_input( + Codebase.MMEDIT, + TestCreateInput.task, + model_cfg, + TestCreateInput.img_path, + input_shape=None, + device='cpu') + assert inputs is not None, 'Failed to create input' + + +def test_build_dataset(): + data = dict( + test={ + 'type': 'SRFolderDataset', + 'lq_folder': 'tests/test_mmedit/data/imgs', + 'gt_folder': 'tests/test_mmedit/data/imgs', + 'scale': 1, + 'filename_tmpl': '{}', + 'pipeline': [ + { + 'type': 'LoadImageFromFile' + }, + ] + }) + dataset_cfg = mmcv.Config(dict(data=data)) + dataset = build_dataset( + Codebase.MMEDIT, dataset_cfg=dataset_cfg, dataset_type='test') + assert dataset is not None, 'Failed to build dataset' + dataloader = build_dataloader(Codebase.MMEDIT, dataset, 1, 1) + assert dataloader is not None, 'Failed to build dataloader' diff --git a/tests/test_mmedit/test_mmedit_models.py b/tests/test_mmedit/test_mmedit_models.py new file mode 100644 index 000000000..67ddd5e60 --- /dev/null +++ b/tests/test_mmedit/test_mmedit_models.py @@ -0,0 +1,72 @@ +import os.path as osp +import tempfile + +import mmcv +import onnx +import torch +from mmedit.models.backbones.sr_backbones import SRCNN + +from mmdeploy.core import RewriterContext +from mmdeploy.utils import Backend, get_onnx_config + +img = torch.rand(1, 3, 4, 4) +model_file = tempfile.NamedTemporaryFile(suffix='.onnx').name + +deploy_cfg = mmcv.Config( + dict( + codebase_config=dict( + type='mmedit', + task='SuperResolution', + ), + backend_config=dict( + type='tensorrt', + common_config=dict(fp16_mode=False, max_workspace_size=1 << 10), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 4, 4], + opt_shape=[1, 3, 4, 4], + max_shape=[1, 3, 4, 4]))) + ]), + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file=model_file, + input_shape=None, + input_names=['input'], + output_names=['output']))) + + +def test_srcnn(): + pytorch_model = SRCNN() + model_inputs = {'x': img} + + onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name + pytorch2onnx_cfg = get_onnx_config(deploy_cfg) + input_names = [k for k, v in model_inputs.items() if k != 'ctx'] + with RewriterContext( + cfg=deploy_cfg, backend=Backend.TENSORRT.value), torch.no_grad(): + torch.onnx.export( + pytorch_model, + tuple([v for k, v in model_inputs.items()]), + onnx_file_path, + export_params=True, + input_names=input_names, + output_names=None, + opset_version=11, + dynamic_axes=pytorch2onnx_cfg.get('dynamic_axes', None), + keep_initializers_as_inputs=False) + + # The result should be different due to the rewrite. + # So we only check if the file exists + assert osp.exists(onnx_file_path) + + model = onnx.load(onnx_file_path) + assert model is not None + try: + onnx.checker.check_model(model) + except onnx.checker.ValidationError: + assert False diff --git a/tests/test_ops/utils.py b/tests/test_ops/utils.py index fad5f3572..94b4fbf49 100644 --- a/tests/test_ops/utils.py +++ b/tests/test_ops/utils.py @@ -121,7 +121,7 @@ class TestTensorRTExporter: backend_config=dict( type='tensorrt', common_config=dict( - fp16_mode=False, max_workspace_size=1 << 30), + fp16_mode=False, max_workspace_size=1 << 28), model_inputs=[ dict( input_shapes=dict(