mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
add device backend check (#886)
* add device backend check * safe check * only activated for tensorrt and openvino * resolve comments
This commit is contained in:
parent
3fa15822b1
commit
9fbfdd2178
@ -2,7 +2,25 @@
|
||||
import mmcv
|
||||
|
||||
from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase
|
||||
from mmdeploy.utils import get_codebase, get_task_type
|
||||
from mmdeploy.utils import (get_backend, get_codebase, get_task_type,
|
||||
parse_device_id)
|
||||
|
||||
|
||||
def check_backend_device(deploy_cfg: mmcv.Config, device: str):
|
||||
"""Check if device is appropriate for the backend.
|
||||
|
||||
Args:
|
||||
deploy_cfg (str | mmcv.Config): Deployment config file.
|
||||
device (str): A string specifying device type.
|
||||
"""
|
||||
backend = get_backend(deploy_cfg).value
|
||||
device_id = parse_device_id(device)
|
||||
mismatch = dict(
|
||||
tensorrt=lambda id: id == -1,
|
||||
openvino=lambda id: id > -1,
|
||||
)
|
||||
if backend in mismatch and mismatch[backend](device_id):
|
||||
raise ValueError(f'{device} is invalid for the backend {backend}')
|
||||
|
||||
|
||||
def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
|
||||
@ -17,6 +35,7 @@ def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
|
||||
Returns:
|
||||
BaseTask: A task processor.
|
||||
"""
|
||||
check_backend_device(deploy_cfg=deploy_cfg, device=device)
|
||||
codebase_type = get_codebase(deploy_cfg)
|
||||
import_codebase(codebase_type)
|
||||
codebase = get_codebase_class(codebase_type)
|
||||
|
Loading…
x
Reference in New Issue
Block a user