[Enhancement] Support ncnn vulkan (#318)

* support ncnn_vulkan

* fix some comments

* avoid bc breaking

* add default value of config
pull/371/head
hanrui1sensetime 2022-04-14 22:12:17 +08:00 committed by GitHub
parent e96077feea
commit c54d574a10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 2 deletions

View File

@ -38,12 +38,14 @@ class NCNNWrapper(BaseWrapper):
param_file: str,
bin_file: str,
output_names: Optional[Sequence[str]] = None,
use_vulkan: bool = False,
**kwargs):
net = ncnn.Net()
if importlib.util.find_spec('mmdeploy.backend.ncnn.ncnn_ext'):
from mmdeploy.backend.ncnn import ncnn_ext
ncnn_ext.register_mmdeploy_custom_layers(net)
net.opt.use_vulkan_compute = use_vulkan
net.load_param(param_file)
net.load_model(bin_file)

View File

@ -5,7 +5,8 @@ from typing import Optional, Sequence, Union
import mmcv
import torch
from mmdeploy.utils import SDK_TASK_MAP, Backend, get_ir_config, get_task_type
from mmdeploy.utils import (SDK_TASK_MAP, Backend, get_backend_config,
get_ir_config, get_task_type)
class BaseBackendModel(torch.nn.Module, metaclass=ABCMeta):
@ -73,10 +74,12 @@ class BaseBackendModel(torch.nn.Module, metaclass=ABCMeta):
output_names=output_names)
elif backend == Backend.NCNN:
from mmdeploy.backend.ncnn import NCNNWrapper
use_vulkan = get_backend_config('use_vulkan', False)
return NCNNWrapper(
param_file=backend_files[0],
bin_file=backend_files[1],
output_names=output_names)
output_names=output_names,
use_vulkan=use_vulkan)
elif backend == Backend.OPENVINO:
from mmdeploy.backend.openvino import OpenVINOWrapper
return OpenVINOWrapper(