mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
add device backend check (#886)
* add device backend check * safe check * only activated for tensorrt and openvino * resolve comments
This commit is contained in:
parent
3fa15822b1
commit
9fbfdd2178
@ -2,7 +2,25 @@
|
|||||||
import mmcv
|
import mmcv
|
||||||
|
|
||||||
from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase
|
from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase
|
||||||
from mmdeploy.utils import get_codebase, get_task_type
|
from mmdeploy.utils import (get_backend, get_codebase, get_task_type,
|
||||||
|
parse_device_id)
|
||||||
|
|
||||||
|
|
||||||
|
def check_backend_device(deploy_cfg: mmcv.Config, device: str):
|
||||||
|
"""Check if device is appropriate for the backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
deploy_cfg (str | mmcv.Config): Deployment config file.
|
||||||
|
device (str): A string specifying device type.
|
||||||
|
"""
|
||||||
|
backend = get_backend(deploy_cfg).value
|
||||||
|
device_id = parse_device_id(device)
|
||||||
|
mismatch = dict(
|
||||||
|
tensorrt=lambda id: id == -1,
|
||||||
|
openvino=lambda id: id > -1,
|
||||||
|
)
|
||||||
|
if backend in mismatch and mismatch[backend](device_id):
|
||||||
|
raise ValueError(f'{device} is invalid for the backend {backend}')
|
||||||
|
|
||||||
|
|
||||||
def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
|
def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
|
||||||
@ -17,6 +35,7 @@ def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
|
|||||||
Returns:
|
Returns:
|
||||||
BaseTask: A task processor.
|
BaseTask: A task processor.
|
||||||
"""
|
"""
|
||||||
|
check_backend_device(deploy_cfg=deploy_cfg, device=device)
|
||||||
codebase_type = get_codebase(deploy_cfg)
|
codebase_type = get_codebase(deploy_cfg)
|
||||||
import_codebase(codebase_type)
|
import_codebase(codebase_type)
|
||||||
codebase = get_codebase_class(codebase_type)
|
codebase = get_codebase_class(codebase_type)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user