add visualization for mmdet (#13)

* add visualization for mmdet

* resolve comments

* update code, enable visualize on nvidia driver>11

Co-authored-by: grimoire <yaoqian@sensetime.com>
This commit is contained in:
RunningLeon 2021-07-13 17:21:02 +08:00 committed by GitHub
parent 0eca9ebcbf
commit c4b7dad2ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 216 additions and 26 deletions

View File

@ -1,4 +1,5 @@
from .pytorch2onnx import torch2onnx, torch2onnx_impl from .pytorch2onnx import torch2onnx, torch2onnx_impl
from .extract_model import extract_model from .extract_model import extract_model
from .inference import inference_model
__all__ = ['torch2onnx_impl', 'torch2onnx', 'extract_model'] __all__ = ['torch2onnx_impl', 'torch2onnx', 'extract_model', 'inference_model']

View File

@ -0,0 +1,47 @@
from typing import Optional
from .utils import (check_model_outputs, create_input, get_classes_from_config,
init_backend_model, init_model)
import torch.multiprocessing as mp
def inference_model(model_cfg,
model,
img,
codebase: str,
backend: str,
device: str,
output_file: Optional[str] = None,
show_result=False,
ret_value: Optional[mp.Value] = None):
if ret_value is not None:
ret_value.value = -1
if isinstance(model, str):
model = [model]
if isinstance(model, (list, tuple)):
if backend == 'pytorch':
model = init_model(codebase, model_cfg, model[0], device)
else:
device_id = -1 if device == 'cpu' else 0
model = init_backend_model(
model,
codebase=codebase,
backend=backend,
class_names=get_classes_from_config(codebase, model_cfg),
device_id=device_id)
model_inputs, _ = create_input(codebase, model_cfg, img, device)
check_model_outputs(
codebase,
img,
model_inputs=model_inputs,
model=model,
output_file=output_file,
backend=backend,
show_result=show_result)
if ret_value is not None:
ret_value.value = 0

View File

@ -1,7 +1,9 @@
import importlib import importlib
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Sequence, Union
import mmcv import mmcv
import numpy as np
import torch
def module_exist(module_name: str): def module_exist(module_name: str):
@ -76,3 +78,105 @@ def attribute_to_dict(attr):
value = str(value, 'utf-8') value = str(value, 'utf-8')
ret[a.name] = value ret[a.name] = value
return ret return ret
def init_backend_model(model_files: Sequence[str],
codebase: str,
backend: str,
class_names: Sequence[str],
device_id: int = 0):
if codebase == 'mmcls':
if module_exist(codebase):
raise NotImplementedError(f'Unsupported codebase type: {codebase}')
else:
raise ImportError(f'Can not import module: {codebase}')
elif codebase == 'mmdet':
if module_exist(codebase):
if backend == 'onnxruntime':
from mmdeploy.mmdet.export import ONNXRuntimeDetector
backend_model = ONNXRuntimeDetector(
model_files[0],
class_names=class_names,
device_id=device_id)
elif backend == 'tensorrt':
from mmdeploy.mmdet.export import TensorRTDetector
backend_model = TensorRTDetector(
model_files[0],
class_names=class_names,
device_id=device_id)
else:
raise NotImplementedError(
f'Unsupported backend type: {backend}')
return backend_model
else:
raise ImportError(f'Can not import module: {codebase}')
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')
def get_classes_from_config(
codebase: str,
model_cfg: Union[str, mmcv.Config],
):
model_cfg_str = model_cfg
if codebase == 'mmdet':
if module_exist(codebase):
if isinstance(model_cfg, str):
model_cfg = mmcv.Config.fromfile(model_cfg)
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
raise TypeError('config must be a filename or Config object, '
f'but got {type(model_cfg)}')
from mmdet.datasets import DATASETS
module_dict = DATASETS.module_dict
data_cfg = model_cfg.data
if 'train' in data_cfg:
module = module_dict[data_cfg.train.type]
elif 'val' in data_cfg:
module = module_dict[data_cfg.val.type]
elif 'test' in data_cfg:
module = module_dict[data_cfg.test.type]
else:
raise RuntimeError(
f'No dataset config found in: {model_cfg_str}')
return module.CLASSES
else:
raise ImportError(f'Can not import module: {codebase}')
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')
def check_model_outputs(codebase: str,
image: Union[str, np.ndarray],
model_inputs,
model,
output_file: str,
backend: str,
show_result=False):
show_img = mmcv.imread(image) if isinstance(image, str) else image
if codebase == 'mmcls':
if module_exist(codebase):
raise NotImplementedError(f'Unsupported codebase type: {codebase}')
else:
raise ImportError(f'Can not import module: {codebase}')
elif codebase == 'mmdet':
if module_exist(codebase):
output_file = None if show_result else output_file
score_thr = 0.3
with torch.no_grad():
results = model(
**model_inputs, return_loss=False, rescale=True)[0]
model.show_result(
show_img,
results,
score_thr=score_thr,
show=True,
win_name=backend,
out_file=output_file)
else:
raise ImportError(f'Can not import module: {codebase}')
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')

View File

@ -7,6 +7,7 @@ import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method from torch.multiprocessing import Process, set_start_method
from mmdeploy.apis import torch2onnx from mmdeploy.apis import torch2onnx
from mmdeploy.apis import inference_model
def parse_args(): def parse_args():
@ -24,11 +25,27 @@ def parse_args():
help='set log level', help='set log level',
default='INFO', default='INFO',
choices=list(logging._nameToLevel.keys())) choices=list(logging._nameToLevel.keys()))
parser.add_argument(
'--show', action='store_true', help='Show detection outputs')
args = parser.parse_args() args = parser.parse_args()
return args return args
def create_process(name, target, args, kwargs, ret_value=None):
logging.info(f'start {name}.')
process = Process(target=target, args=args, kwargs=kwargs)
process.start()
process.join()
if ret_value is not None:
if ret_value.value != 0:
logging.error(f'{name} failed.')
exit()
else:
logging.info(f'{name} success.')
def main(): def main():
args = parse_args() args = parse_args()
set_start_method('spawn') set_start_method('spawn')
@ -52,51 +69,72 @@ def main():
ret_value = mp.Value('d', 0, lock=False) ret_value = mp.Value('d', 0, lock=False)
# convert onnx # convert onnx
logging.info('start torch2onnx conversion.')
onnx_save_file = deploy_cfg['pytorch2onnx']['save_file'] onnx_save_file = deploy_cfg['pytorch2onnx']['save_file']
process = Process( create_process(
'torch2onnx',
target=torch2onnx, target=torch2onnx,
args=(args.img, args.work_dir, onnx_save_file, deploy_cfg_path, args=(args.img, args.work_dir, onnx_save_file, deploy_cfg_path,
model_cfg_path, checkpoint_path), model_cfg_path, checkpoint_path),
kwargs=dict(device=args.device, ret_value=ret_value)) kwargs=dict(device=args.device, ret_value=ret_value),
process.start() ret_value=ret_value)
process.join()
if ret_value.value != 0:
logging.error('torch2onnx failed.')
exit()
else:
logging.info('torch2onnx success.')
# convert backend # convert backend
onnx_pathes = [osp.join(args.work_dir, onnx_save_file)] onnx_files = [osp.join(args.work_dir, onnx_save_file)]
backend_files = onnx_files
backend = deploy_cfg.get('backend', 'default') backend = deploy_cfg.get('backend', 'default')
if backend == 'tensorrt': if backend == 'tensorrt':
assert hasattr(deploy_cfg, 'tensorrt_params') assert hasattr(deploy_cfg, 'tensorrt_params')
tensorrt_params = deploy_cfg['tensorrt_params'] tensorrt_params = deploy_cfg['tensorrt_params']
model_params = tensorrt_params.get('model_params', []) model_params = tensorrt_params.get('model_params', [])
assert len(model_params) == len(onnx_pathes) assert len(model_params) == len(onnx_files)
logging.info('start onnx2tensorrt conversion.')
from mmdeploy.apis.tensorrt import onnx2tensorrt from mmdeploy.apis.tensorrt import onnx2tensorrt
backend_files = []
for model_id, model_param, onnx_path in zip( for model_id, model_param, onnx_path in zip(
range(len(onnx_pathes)), model_params, onnx_pathes): range(len(onnx_files)), model_params, onnx_files):
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0] onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
save_file = model_param.get('save_file', onnx_name + '.engine') save_file = model_param.get('save_file', onnx_name + '.engine')
process = Process(
create_process(
f'onnx2tensorrt of {onnx_path}',
target=onnx2tensorrt, target=onnx2tensorrt,
args=(args.work_dir, save_file, model_id, deploy_cfg_path, args=(args.work_dir, save_file, model_id, deploy_cfg_path,
onnx_path), onnx_path),
kwargs=dict(device=args.device, ret_value=ret_value)) kwargs=dict(device=args.device, ret_value=ret_value),
process.start() ret_value=ret_value)
process.join()
if ret_value.value != 0: backend_files.append(osp.join(args.work_dir, save_file))
logging.error('onnx2tensorrt failed.')
exit() # check model outputs by visualization
else: codebase = deploy_cfg['codebase']
logging.info('onnx2tensorrt success.')
# visualize tensorrt model
create_process(
f'visualize {backend} model',
target=inference_model,
args=(model_cfg_path, backend_files, args.img),
kwargs=dict(
codebase=codebase,
backend=backend,
device=args.device,
output_file=f'output_{backend}.jpg',
show_result=args.show,
ret_value=ret_value))
# visualize pytorch model
create_process(
'visualize pytorch model',
target=inference_model,
args=(model_cfg_path, [checkpoint_path], args.img),
kwargs=dict(
codebase=codebase,
backend='pytorch',
device=args.device,
output_file='output_pytorch.jpg',
show_result=args.show,
ret_value=ret_value),
ret_value=ret_value)
logging.info('All process success.') logging.info('All process success.')