mmdeploy/mmdeploy/apis/inference.py

76 lines
2.8 KiB
Python

from typing import Optional, Sequence, Union
import mmcv
import numpy as np
import torch
from mmdeploy.utils import (Backend, get_backend, get_codebase,
get_input_shape, get_task_type, load_config)
from .utils import (create_input, init_backend_model, init_pytorch_model,
run_inference, visualize)
def _parse_device_id(device: str) -> int:
    """Translate a device string into the integer id handed to backends.

    Args:
        device (str): Device string such as ``'cpu'``, ``'cuda'`` or
            ``'cuda:1'``.

    Returns:
        int: ``-1`` for ``'cpu'``. Otherwise the decimal index found after
            the last ``':'``, or ``0`` when no such index is present.
    """
    if device == 'cpu':
        return -1
    # str() keeps a non-str device object from raising on rpartition.
    _, sep, index = str(device).rpartition(':')
    if sep and index.isdecimal():
        return int(index)
    # No usable index: fall back to 0, the id previously used for every
    # non-CPU device string.
    return 0


def inference_model(model_cfg: Union[str, mmcv.Config],
                    deploy_cfg: Union[str, mmcv.Config],
                    model: Union[str, Sequence[str], torch.nn.Module],
                    img: Union[str, np.ndarray],
                    device: str,
                    backend: Optional[Backend] = None,
                    output_file: Optional[str] = None,
                    show_result: bool = False):
    """Run inference with PyTorch or backend model and show results.

    The configs are loaded, the model is initialised when it is given as
    file path(s), inputs are built from ``img``, inference is executed under
    ``torch.no_grad()`` and the result is passed to ``visualize``.

    Args:
        model_cfg (str | mmcv.Config): Model config file or Config object.
        deploy_cfg (str | mmcv.Config): Deployment config file or Config
            object.
        model (str | list[str], torch.nn.Module): Input model or file(s).
            A single path is wrapped into a one-element list. Any value
            that is not a str, list or tuple is used as the model without
            further initialisation.
        img (str | np.ndarray): Input image file or numpy array for
            inference.
        device (str): A string specifying device type. For non-PyTorch
            backends ``'cpu'`` selects device id ``-1``; otherwise the
            index after ``':'`` is used (``0`` if absent), e.g.
            ``'cuda:1'`` selects device id ``1``.
        backend (Backend): Specifying backend type, defaults to `None`, in
            which case it is read from ``deploy_cfg``.
        output_file (str): Output file to save visualized image, defaults to
            `None`. Only valid if `show_result` is set to `False`.
        show_result (bool): Whether to show plotted image in windows,
            defaults to `False`.

    Returns:
        None: The inference result is only forwarded to ``visualize``.

    Raises:
        AssertionError: If ``model`` is an empty list/tuple or contains a
            non-str element.
    """
    deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)

    codebase = get_codebase(deploy_cfg)
    task = get_task_type(deploy_cfg)
    input_shape = get_input_shape(deploy_cfg)
    if backend is None:
        backend = get_backend(deploy_cfg)

    if isinstance(model, str):
        model = [model]

    if isinstance(model, (list, tuple)):
        # NOTE: kept as asserts so callers still see AssertionError; be aware
        # that these checks are skipped when Python runs with -O.
        assert len(model) > 0, 'Model should have at least one element.'
        assert all(isinstance(m, str) for m in model), \
            'All elements in the list should be str'

        if backend == Backend.PYTORCH:
            # Only the first file is used for a PyTorch model.
            model = init_pytorch_model(codebase, model_cfg, model[0], device)
        else:
            model = init_backend_model(
                model,
                model_cfg=model_cfg,
                deploy_cfg=deploy_cfg,
                device_id=_parse_device_id(device))

    model_inputs, _ = create_input(codebase, task, model_cfg, img, input_shape,
                                   device)

    with torch.no_grad():
        result = run_inference(codebase, model_inputs, model)

    visualize(
        codebase,
        img,
        result=result,
        model=model,
        output_file=output_file,
        backend=backend,
        show_result=show_result)