mirror of https://github.com/open-mmlab/mmdeploy.git
add visualization for mmdet (#13)
* add visualization for mmdet
* resolve comments
* update code, enable visualize on nvidia driver>11

Co-authored-by: grimoire <yaoqian@sensetime.com>
parent 0eca9ebcbf
commit c4b7dad2ec
mmdeploy/apis/__init__.py
@@ -1,4 +1,5 @@
 from .pytorch2onnx import torch2onnx, torch2onnx_impl
 from .extract_model import extract_model
+from .inference import inference_model
 
-__all__ = ['torch2onnx_impl', 'torch2onnx', 'extract_model']
+__all__ = ['torch2onnx_impl', 'torch2onnx', 'extract_model', 'inference_model']
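The new entry point is re-exported from the package root, so callers (including the tool script further below) can simply write: from mmdeploy.apis import inference_model.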
mmdeploy/apis/inference.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+from typing import Optional
+
+from .utils import (check_model_outputs, create_input, get_classes_from_config,
+                    init_backend_model, init_model)
+import torch.multiprocessing as mp
+
+
+def inference_model(model_cfg,
+                    model,
+                    img,
+                    codebase: str,
+                    backend: str,
+                    device: str,
+                    output_file: Optional[str] = None,
+                    show_result=False,
+                    ret_value: Optional[mp.Value] = None):
+
+    if ret_value is not None:
+        ret_value.value = -1
+
+    if isinstance(model, str):
+        model = [model]
+    if isinstance(model, (list, tuple)):
+        if backend == 'pytorch':
+            model = init_model(codebase, model_cfg, model[0], device)
+        else:
+            device_id = -1 if device == 'cpu' else 0
+            model = init_backend_model(
+                model,
+                codebase=codebase,
+                backend=backend,
+                class_names=get_classes_from_config(codebase, model_cfg),
+                device_id=device_id)
+
+    model_inputs, _ = create_input(codebase, model_cfg, img, device)
+
+    check_model_outputs(
+        codebase,
+        img,
+        model_inputs=model_inputs,
+        model=model,
+        output_file=output_file,
+        backend=backend,
+        show_result=show_result)
+
+    if ret_value is not None:
+        ret_value.value = 0
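As a usage sketch (not part of the commit; the config, model, and image paths are placeholders), the entry point can also be driven directly instead of through the tool script:

    import torch.multiprocessing as mp

    from mmdeploy.apis import inference_model

    if __name__ == '__main__':
        # Shared status flag, following the convention above:
        # -1 while running, 0 once the visualization succeeded.
        ret_value = mp.Value('d', 0, lock=False)
        inference_model(
            'configs/retinanet_r50_fpn_1x_coco.py',  # placeholder config
            ['work_dir/end2end.onnx'],               # placeholder model file
            'demo.jpg',                              # placeholder image
            codebase='mmdet',
            backend='onnxruntime',
            device='cpu',
            output_file='output_onnxruntime.jpg',
            ret_value=ret_value)
        assert ret_value.value == 0, 'inference_model reported failure'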
mmdeploy/apis/utils.py
@@ -1,7 +1,9 @@
 import importlib
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Sequence, Union
 
 import mmcv
+import numpy as np
+import torch
 
 
 def module_exist(module_name: str):
@@ -76,3 +78,105 @@ def attribute_to_dict(attr):
             value = str(value, 'utf-8')
         ret[a.name] = value
     return ret
+
+
+def init_backend_model(model_files: Sequence[str],
+                       codebase: str,
+                       backend: str,
+                       class_names: Sequence[str],
+                       device_id: int = 0):
+    if codebase == 'mmcls':
+        if module_exist(codebase):
+            raise NotImplementedError(f'Unsupported codebase type: {codebase}')
+        else:
+            raise ImportError(f'Can not import module: {codebase}')
+    elif codebase == 'mmdet':
+        if module_exist(codebase):
+            if backend == 'onnxruntime':
+                from mmdeploy.mmdet.export import ONNXRuntimeDetector
+                backend_model = ONNXRuntimeDetector(
+                    model_files[0],
+                    class_names=class_names,
+                    device_id=device_id)
+            elif backend == 'tensorrt':
+                from mmdeploy.mmdet.export import TensorRTDetector
+                backend_model = TensorRTDetector(
+                    model_files[0],
+                    class_names=class_names,
+                    device_id=device_id)
+            else:
+                raise NotImplementedError(
+                    f'Unsupported backend type: {backend}')
+            return backend_model
+        else:
+            raise ImportError(f'Can not import module: {codebase}')
+    else:
+        raise NotImplementedError(f'Unknown codebase type: {codebase}')
+
+
+def get_classes_from_config(
+    codebase: str,
+    model_cfg: Union[str, mmcv.Config],
+):
+    model_cfg_str = model_cfg
+    if codebase == 'mmdet':
+        if module_exist(codebase):
+            if isinstance(model_cfg, str):
+                model_cfg = mmcv.Config.fromfile(model_cfg)
+            elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
+                raise TypeError('config must be a filename or Config object, '
+                                f'but got {type(model_cfg)}')
+
+            from mmdet.datasets import DATASETS
+            module_dict = DATASETS.module_dict
+            data_cfg = model_cfg.data
+
+            if 'train' in data_cfg:
+                module = module_dict[data_cfg.train.type]
+            elif 'val' in data_cfg:
+                module = module_dict[data_cfg.val.type]
+            elif 'test' in data_cfg:
+                module = module_dict[data_cfg.test.type]
+            else:
+                raise RuntimeError(
+                    f'No dataset config found in: {model_cfg_str}')
+
+            return module.CLASSES
+        else:
+            raise ImportError(f'Can not import module: {codebase}')
+    else:
+        raise NotImplementedError(f'Unknown codebase type: {codebase}')
+
+
+def check_model_outputs(codebase: str,
+                        image: Union[str, np.ndarray],
+                        model_inputs,
+                        model,
+                        output_file: str,
+                        backend: str,
+                        show_result=False):
+    show_img = mmcv.imread(image) if isinstance(image, str) else image
+    if codebase == 'mmcls':
+        if module_exist(codebase):
+            raise NotImplementedError(f'Unsupported codebase type: {codebase}')
+        else:
+            raise ImportError(f'Can not import module: {codebase}')
+    elif codebase == 'mmdet':
+        if module_exist(codebase):
+            output_file = None if show_result else output_file
+            score_thr = 0.3
+            with torch.no_grad():
+                results = model(
+                    **model_inputs, return_loss=False, rescale=True)[0]
+                model.show_result(
+                    show_img,
+                    results,
+                    score_thr=score_thr,
+                    show=True,
+                    win_name=backend,
+                    out_file=output_file)
+
+        else:
+            raise ImportError(f'Can not import module: {codebase}')
+    else:
+        raise NotImplementedError(f'Unknown codebase type: {codebase}')
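Taken together these helpers form a small pipeline: resolve class names from the config, wrap the converted model behind an mmdet-style interface, build the preprocessed input, and render the detections. A minimal hand-driven sketch (paths, file names, and the CPU device choice are placeholders, not from the commit):

    from mmdeploy.apis.utils import (check_model_outputs, create_input,
                                     get_classes_from_config, init_backend_model)

    model_cfg = 'configs/retinanet_r50_fpn_1x_coco.py'  # placeholder config
    img = 'demo.jpg'                                    # placeholder image

    # data.{train,val,test}.type -> registered dataset class -> CLASSES
    class_names = get_classes_from_config('mmdet', model_cfg)

    # Wrap the converted model so it can be called like an mmdet detector.
    backend_model = init_backend_model(
        ['work_dir/end2end.onnx'],  # placeholder model file
        codebase='mmdet',
        backend='onnxruntime',
        class_names=class_names,
        device_id=-1)  # -1 is the CPU convention used by inference_model

    # Preprocess the image, run the model, and draw the detections.
    model_inputs, _ = create_input('mmdet', model_cfg, img, 'cpu')
    check_model_outputs(
        'mmdet',
        img,
        model_inputs=model_inputs,
        model=backend_model,
        output_file='output_onnxruntime.jpg',
        backend='onnxruntime',
        show_result=False)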
@@ -7,6 +7,7 @@ import torch.multiprocessing as mp
 from torch.multiprocessing import Process, set_start_method
 
 from mmdeploy.apis import torch2onnx
+from mmdeploy.apis import inference_model
 
 
 def parse_args():
@@ -24,11 +25,27 @@ def parse_args():
         help='set log level',
         default='INFO',
         choices=list(logging._nameToLevel.keys()))
+    parser.add_argument(
+        '--show', action='store_true', help='Show detection outputs')
     args = parser.parse_args()
 
     return args
 
 
+def create_process(name, target, args, kwargs, ret_value=None):
+    logging.info(f'start {name}.')
+    process = Process(target=target, args=args, kwargs=kwargs)
+    process.start()
+    process.join()
+
+    if ret_value is not None:
+        if ret_value.value != 0:
+            logging.error(f'{name} failed.')
+            exit()
+        else:
+            logging.info(f'{name} success.')
+
+
 def main():
     args = parse_args()
     set_start_method('spawn')
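create_process centralizes the spawn/join/status-check boilerplate that main() otherwise repeats for each conversion step (visible as the removed lines in the next hunk). The status channel is the unsynchronized shared double allocated in main(); because a worker writes -1 on entry and 0 only at the very end, any crash in between leaves a nonzero value for the parent to detect after join(). A standalone sketch of the convention (the worker body is illustrative):

    import torch.multiprocessing as mp


    def worker(ret_value=None):
        if ret_value is not None:
            ret_value.value = -1  # pessimistic: failed until proven otherwise
        # ... the real work (conversion, inference, ...) goes here ...
        if ret_value is not None:
            ret_value.value = 0   # reached only if nothing raised


    if __name__ == '__main__':
        mp.set_start_method('spawn')
        ret_value = mp.Value('d', 0, lock=False)
        proc = mp.Process(target=worker, kwargs=dict(ret_value=ret_value))
        proc.start()
        proc.join()
        print('success' if ret_value.value == 0 else 'failed')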
@@ -52,51 +69,72 @@ def main():
     ret_value = mp.Value('d', 0, lock=False)
 
     # convert onnx
     logging.info('start torch2onnx conversion.')
     onnx_save_file = deploy_cfg['pytorch2onnx']['save_file']
-    process = Process(
+    create_process(
+        'torch2onnx',
         target=torch2onnx,
         args=(args.img, args.work_dir, onnx_save_file, deploy_cfg_path,
               model_cfg_path, checkpoint_path),
-        kwargs=dict(device=args.device, ret_value=ret_value))
-    process.start()
-    process.join()
-
-    if ret_value.value != 0:
-        logging.error('torch2onnx failed.')
-        exit()
-    else:
-        logging.info('torch2onnx success.')
+        kwargs=dict(device=args.device, ret_value=ret_value),
+        ret_value=ret_value)
 
     # convert backend
-    onnx_pathes = [osp.join(args.work_dir, onnx_save_file)]
+    onnx_files = [osp.join(args.work_dir, onnx_save_file)]
+    backend_files = onnx_files
 
     backend = deploy_cfg.get('backend', 'default')
     if backend == 'tensorrt':
         assert hasattr(deploy_cfg, 'tensorrt_params')
         tensorrt_params = deploy_cfg['tensorrt_params']
         model_params = tensorrt_params.get('model_params', [])
-        assert len(model_params) == len(onnx_pathes)
+        assert len(model_params) == len(onnx_files)
 
         logging.info('start onnx2tensorrt conversion.')
         from mmdeploy.apis.tensorrt import onnx2tensorrt
+        backend_files = []
         for model_id, model_param, onnx_path in zip(
-                range(len(onnx_pathes)), model_params, onnx_pathes):
+                range(len(onnx_files)), model_params, onnx_files):
             onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
             save_file = model_param.get('save_file', onnx_name + '.engine')
-            process = Process(
+            create_process(
+                f'onnx2tensorrt of {onnx_path}',
                 target=onnx2tensorrt,
                 args=(args.work_dir, save_file, model_id, deploy_cfg_path,
                       onnx_path),
-                kwargs=dict(device=args.device, ret_value=ret_value))
-            process.start()
-            process.join()
-
-            if ret_value.value != 0:
-                logging.error('onnx2tensorrt failed.')
-                exit()
-            else:
-                logging.info('onnx2tensorrt success.')
+                kwargs=dict(device=args.device, ret_value=ret_value),
+                ret_value=ret_value)
+            backend_files.append(osp.join(args.work_dir, save_file))
+
+    # check model outputs by visualization
+    codebase = deploy_cfg['codebase']
+
+    # visualize tensorrt model
+    create_process(
+        f'visualize {backend} model',
+        target=inference_model,
+        args=(model_cfg_path, backend_files, args.img),
+        kwargs=dict(
+            codebase=codebase,
+            backend=backend,
+            device=args.device,
+            output_file=f'output_{backend}.jpg',
+            show_result=args.show,
+            ret_value=ret_value))
+
+    # visualize pytorch model
+    create_process(
+        'visualize pytorch model',
+        target=inference_model,
+        args=(model_cfg_path, [checkpoint_path], args.img),
+        kwargs=dict(
+            codebase=codebase,
+            backend='pytorch',
+            device=args.device,
+            output_file='output_pytorch.jpg',
+            show_result=args.show,
+            ret_value=ret_value),
+        ret_value=ret_value)
 
     logging.info('All process success.')
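A successful run therefore leaves two renderings to compare side by side: output_<backend>.jpg from the converted model and output_pytorch.jpg from the original checkpoint. With --show, check_model_outputs drops the output files and displays the detection windows instead.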