mmdeploy/mmdeploy/apis/visualize.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union

import mmengine
import numpy as np
import torch

from mmdeploy.utils import Backend, get_backend, get_input_shape, load_config


def visualize_model(model_cfg: Union[str, mmengine.Config],
                    deploy_cfg: Union[str, mmengine.Config],
                    model: Union[str, Sequence[str]],
                    img: Union[str, np.ndarray],
                    device: str,
                    backend: Optional[Backend] = None,
                    output_file: Optional[str] = None,
                    show_result: bool = False,
                    **kwargs):
"""Run inference with PyTorch or backend model and show results.
Examples:
>>> from mmdeploy.apis import visualize_model
>>> model_cfg = ('mmdetection/configs/fcos/'
'fcos_r50_caffe_fpn_gn-head_1x_coco.py')
>>> deploy_cfg = ('configs/mmdet/detection/'
'detection_onnxruntime_dynamic.py')
>>> model = 'work_dir/fcos.onnx'
>>> img = 'demo.jpg'
>>> device = 'cpu'
>>> visualize_model(model_cfg, deploy_cfg, model, \
img, device, show_result=True)
Args:
model_cfg (str | mmengine.Config): Model config file or Config object.
deploy_cfg (str | mmengine.Config): Deployment config file or Config
object.
model (str | list[str], BaseSubtask): Input model or file(s).
img (str | np.ndarray): Input image file or numpy array for inference.
device (str): A string specifying device type.
backend (Backend): Specifying backend type, defaults to `None`.
output_file (str): Output file to save visualized image, defaults to
`None`. Only valid if `show_result` is set to `False`.
show_result (bool): Whether to show plotted image in windows, defaults
to `False`.
"""
    deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
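
    # build_task_processor is imported lazily here; it pulls in
    # task-specific code that is not needed until this function runs.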
    from mmdeploy.apis.utils import build_task_processor
    task_processor = build_task_processor(model_cfg, deploy_cfg, device)
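
    # The static input shape (if any) and the backend are both read from the
    # deploy config; an explicitly passed `backend` takes precedence.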
    input_shape = get_input_shape(deploy_cfg)
    if backend is None:
        backend = get_backend(deploy_cfg)
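
    # Normalize a single model file path to a list so that single- and
    # multi-file models share one code path below.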
    if isinstance(model, str):
        model = [model]

    if isinstance(model, (list, tuple)):
        assert len(model) > 0, 'Model should have at least one element.'
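
        # For the PyTorch backend the first entry is treated as the
        # checkpoint; other backends wrap the converted model file(s).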
        if backend == Backend.PYTORCH:
            model = task_processor.build_pytorch_model(model[0])
        else:
            model = task_processor.build_backend_model(
                model,
                data_preprocessor_updater=task_processor.
                update_data_preprocessor)
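
    # Preprocess the image into model inputs, then run one inference step
    # with gradients disabled.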
    model_inputs, _ = task_processor.create_input(img, input_shape)
    with torch.no_grad():
        result = model.test_step(model_inputs)[0]
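
    # Probe for a display by creating a tkinter window; on a headless
    # machine this raises, and on-screen rendering is skipped with a
    # warning.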
    visualize = True
    try:
        # check headless
        import tkinter
        tkinter.Tk()
    except Exception as e:
        from mmdeploy.utils import get_root_logger
        logger = get_root_logger()
        logger.warning(
            f'render and display result skipped for headless device, exception {e}'  # noqa: E501
        )
        visualize = False
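
    # Render the result on the input image; show it in a window and/or save
    # it to `output_file`, using the backend name as the window title.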
    if visualize:
        task_processor.visualize(
            image=img,
            model=model,
            result=result,
            output_file=output_file,
            window_name=backend.value,
            show_result=show_result)