add device backend check (#886)

* add device backend check

* safe check

* only activated for tensorrt and openvino

* resolve comments
This commit is contained in:
AllentDan 2022-08-16 17:20:29 +08:00 committed by GitHub
parent 3fa15822b1
commit 9fbfdd2178
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,25 @@
import mmcv
from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase
from mmdeploy.utils import get_codebase, get_task_type
from mmdeploy.utils import (get_backend, get_codebase, get_task_type,
parse_device_id)
def check_backend_device(deploy_cfg: mmcv.Config, device: str):
"""Check if device is appropriate for the backend.
Args:
deploy_cfg (str | mmcv.Config): Deployment config file.
device (str): A string specifying device type.
"""
backend = get_backend(deploy_cfg).value
device_id = parse_device_id(device)
mismatch = dict(
tensorrt=lambda id: id == -1,
openvino=lambda id: id > -1,
)
if backend in mismatch and mismatch[backend](device_id):
raise ValueError(f'{device} is invalid for the backend {backend}')
def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
@ -17,6 +35,7 @@ def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
Returns:
BaseTask: A task processor.
"""
check_backend_device(deploy_cfg=deploy_cfg, device=device)
codebase_type = get_codebase(deploy_cfg)
import_codebase(codebase_type)
codebase = get_codebase_class(codebase_type)