From c54d574a1015aaedb2863302896cb9eb5bbb8e3b Mon Sep 17 00:00:00 2001 From: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com> Date: Thu, 14 Apr 2022 22:12:17 +0800 Subject: [PATCH] [Enhancement] Support ncnn vulkan (#318) * support ncnn_vulkan * fix some comments * avoid bc breaking * add default value of config --- mmdeploy/backend/ncnn/wrapper.py | 2 ++ mmdeploy/codebase/base/backend_model.py | 7 +++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/mmdeploy/backend/ncnn/wrapper.py b/mmdeploy/backend/ncnn/wrapper.py index c1b0c8063..2fa1260ef 100644 --- a/mmdeploy/backend/ncnn/wrapper.py +++ b/mmdeploy/backend/ncnn/wrapper.py @@ -38,12 +38,14 @@ class NCNNWrapper(BaseWrapper): param_file: str, bin_file: str, output_names: Optional[Sequence[str]] = None, + use_vulkan: bool = False, **kwargs): net = ncnn.Net() if importlib.util.find_spec('mmdeploy.backend.ncnn.ncnn_ext'): from mmdeploy.backend.ncnn import ncnn_ext ncnn_ext.register_mmdeploy_custom_layers(net) + net.opt.use_vulkan_compute = use_vulkan net.load_param(param_file) net.load_model(bin_file) diff --git a/mmdeploy/codebase/base/backend_model.py b/mmdeploy/codebase/base/backend_model.py index 93dc2fe74..2db2cc3c5 100644 --- a/mmdeploy/codebase/base/backend_model.py +++ b/mmdeploy/codebase/base/backend_model.py @@ -5,7 +5,8 @@ from typing import Optional, Sequence, Union import mmcv import torch -from mmdeploy.utils import SDK_TASK_MAP, Backend, get_ir_config, get_task_type +from mmdeploy.utils import (SDK_TASK_MAP, Backend, get_backend_config, + get_ir_config, get_task_type) class BaseBackendModel(torch.nn.Module, metaclass=ABCMeta): @@ -73,10 +74,12 @@ class BaseBackendModel(torch.nn.Module, metaclass=ABCMeta): output_names=output_names) elif backend == Backend.NCNN: from mmdeploy.backend.ncnn import NCNNWrapper + use_vulkan = get_backend_config('use_vulkan', False) return NCNNWrapper( param_file=backend_files[0], bin_file=backend_files[1], - output_names=output_names) + output_names=output_names, + use_vulkan=use_vulkan) elif backend == Backend.OPENVINO: from mmdeploy.backend.openvino import OpenVINOWrapper return OpenVINOWrapper(