add visualization for mmdet (#13)

* add visualization for mmdet

* resolve comments

* update code, enable visualize on nvidia driver>11

Co-authored-by: grimoire <yaoqian@sensetime.com>
This commit is contained in:
RunningLeon 2021-07-13 17:21:02 +08:00 committed by GitHub
parent 0eca9ebcbf
commit c4b7dad2ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 216 additions and 26 deletions

View File

@ -1,4 +1,5 @@
from .pytorch2onnx import torch2onnx, torch2onnx_impl
from .extract_model import extract_model
from .inference import inference_model
__all__ = ['torch2onnx_impl', 'torch2onnx', 'extract_model']
__all__ = ['torch2onnx_impl', 'torch2onnx', 'extract_model', 'inference_model']

View File

@ -0,0 +1,47 @@
from typing import Optional
from .utils import (check_model_outputs, create_input, get_classes_from_config,
init_backend_model, init_model)
import torch.multiprocessing as mp
def inference_model(model_cfg,
model,
img,
codebase: str,
backend: str,
device: str,
output_file: Optional[str] = None,
show_result=False,
ret_value: Optional[mp.Value] = None):
if ret_value is not None:
ret_value.value = -1
if isinstance(model, str):
model = [model]
if isinstance(model, (list, tuple)):
if backend == 'pytorch':
model = init_model(codebase, model_cfg, model[0], device)
else:
device_id = -1 if device == 'cpu' else 0
model = init_backend_model(
model,
codebase=codebase,
backend=backend,
class_names=get_classes_from_config(codebase, model_cfg),
device_id=device_id)
model_inputs, _ = create_input(codebase, model_cfg, img, device)
check_model_outputs(
codebase,
img,
model_inputs=model_inputs,
model=model,
output_file=output_file,
backend=backend,
show_result=show_result)
if ret_value is not None:
ret_value.value = 0

View File

@ -1,7 +1,9 @@
import importlib
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Sequence, Union
import mmcv
import numpy as np
import torch
def module_exist(module_name: str):
@ -76,3 +78,105 @@ def attribute_to_dict(attr):
value = str(value, 'utf-8')
ret[a.name] = value
return ret
def init_backend_model(model_files: Sequence[str],
codebase: str,
backend: str,
class_names: Sequence[str],
device_id: int = 0):
if codebase == 'mmcls':
if module_exist(codebase):
raise NotImplementedError(f'Unsupported codebase type: {codebase}')
else:
raise ImportError(f'Can not import module: {codebase}')
elif codebase == 'mmdet':
if module_exist(codebase):
if backend == 'onnxruntime':
from mmdeploy.mmdet.export import ONNXRuntimeDetector
backend_model = ONNXRuntimeDetector(
model_files[0],
class_names=class_names,
device_id=device_id)
elif backend == 'tensorrt':
from mmdeploy.mmdet.export import TensorRTDetector
backend_model = TensorRTDetector(
model_files[0],
class_names=class_names,
device_id=device_id)
else:
raise NotImplementedError(
f'Unsupported backend type: {backend}')
return backend_model
else:
raise ImportError(f'Can not import module: {codebase}')
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')
def get_classes_from_config(
codebase: str,
model_cfg: Union[str, mmcv.Config],
):
model_cfg_str = model_cfg
if codebase == 'mmdet':
if module_exist(codebase):
if isinstance(model_cfg, str):
model_cfg = mmcv.Config.fromfile(model_cfg)
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
raise TypeError('config must be a filename or Config object, '
f'but got {type(model_cfg)}')
from mmdet.datasets import DATASETS
module_dict = DATASETS.module_dict
data_cfg = model_cfg.data
if 'train' in data_cfg:
module = module_dict[data_cfg.train.type]
elif 'val' in data_cfg:
module = module_dict[data_cfg.val.type]
elif 'test' in data_cfg:
module = module_dict[data_cfg.test.type]
else:
raise RuntimeError(
f'No dataset config found in: {model_cfg_str}')
return module.CLASSES
else:
raise ImportError(f'Can not import module: {codebase}')
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')
def check_model_outputs(codebase: str,
image: Union[str, np.ndarray],
model_inputs,
model,
output_file: str,
backend: str,
show_result=False):
show_img = mmcv.imread(image) if isinstance(image, str) else image
if codebase == 'mmcls':
if module_exist(codebase):
raise NotImplementedError(f'Unsupported codebase type: {codebase}')
else:
raise ImportError(f'Can not import module: {codebase}')
elif codebase == 'mmdet':
if module_exist(codebase):
output_file = None if show_result else output_file
score_thr = 0.3
with torch.no_grad():
results = model(
**model_inputs, return_loss=False, rescale=True)[0]
model.show_result(
show_img,
results,
score_thr=score_thr,
show=True,
win_name=backend,
out_file=output_file)
else:
raise ImportError(f'Can not import module: {codebase}')
else:
raise NotImplementedError(f'Unknown codebase type: {codebase}')

View File

@ -7,6 +7,7 @@ import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method
from mmdeploy.apis import torch2onnx
from mmdeploy.apis import inference_model
def parse_args():
@ -24,11 +25,27 @@ def parse_args():
help='set log level',
default='INFO',
choices=list(logging._nameToLevel.keys()))
parser.add_argument(
'--show', action='store_true', help='Show detection outputs')
args = parser.parse_args()
return args
def create_process(name, target, args, kwargs, ret_value=None):
logging.info(f'start {name}.')
process = Process(target=target, args=args, kwargs=kwargs)
process.start()
process.join()
if ret_value is not None:
if ret_value.value != 0:
logging.error(f'{name} failed.')
exit()
else:
logging.info(f'{name} success.')
def main():
args = parse_args()
set_start_method('spawn')
@ -52,51 +69,72 @@ def main():
ret_value = mp.Value('d', 0, lock=False)
# convert onnx
logging.info('start torch2onnx conversion.')
onnx_save_file = deploy_cfg['pytorch2onnx']['save_file']
process = Process(
create_process(
'torch2onnx',
target=torch2onnx,
args=(args.img, args.work_dir, onnx_save_file, deploy_cfg_path,
model_cfg_path, checkpoint_path),
kwargs=dict(device=args.device, ret_value=ret_value))
process.start()
process.join()
if ret_value.value != 0:
logging.error('torch2onnx failed.')
exit()
else:
logging.info('torch2onnx success.')
kwargs=dict(device=args.device, ret_value=ret_value),
ret_value=ret_value)
# convert backend
onnx_pathes = [osp.join(args.work_dir, onnx_save_file)]
onnx_files = [osp.join(args.work_dir, onnx_save_file)]
backend_files = onnx_files
backend = deploy_cfg.get('backend', 'default')
if backend == 'tensorrt':
assert hasattr(deploy_cfg, 'tensorrt_params')
tensorrt_params = deploy_cfg['tensorrt_params']
model_params = tensorrt_params.get('model_params', [])
assert len(model_params) == len(onnx_pathes)
assert len(model_params) == len(onnx_files)
logging.info('start onnx2tensorrt conversion.')
from mmdeploy.apis.tensorrt import onnx2tensorrt
backend_files = []
for model_id, model_param, onnx_path in zip(
range(len(onnx_pathes)), model_params, onnx_pathes):
range(len(onnx_files)), model_params, onnx_files):
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
save_file = model_param.get('save_file', onnx_name + '.engine')
process = Process(
create_process(
f'onnx2tensorrt of {onnx_path}',
target=onnx2tensorrt,
args=(args.work_dir, save_file, model_id, deploy_cfg_path,
onnx_path),
kwargs=dict(device=args.device, ret_value=ret_value))
process.start()
process.join()
kwargs=dict(device=args.device, ret_value=ret_value),
ret_value=ret_value)
if ret_value.value != 0:
logging.error('onnx2tensorrt failed.')
exit()
else:
logging.info('onnx2tensorrt success.')
backend_files.append(osp.join(args.work_dir, save_file))
# check model outputs by visualization
codebase = deploy_cfg['codebase']
# visualize tensorrt model
create_process(
f'visualize {backend} model',
target=inference_model,
args=(model_cfg_path, backend_files, args.img),
kwargs=dict(
codebase=codebase,
backend=backend,
device=args.device,
output_file=f'output_{backend}.jpg',
show_result=args.show,
ret_value=ret_value))
# visualize pytorch model
create_process(
'visualize pytorch model',
target=inference_model,
args=(model_cfg_path, [checkpoint_path], args.img),
kwargs=dict(
codebase=codebase,
backend='pytorch',
device=args.device,
output_file='output_pytorch.jpg',
show_result=args.show,
ret_value=ret_value),
ret_value=ret_value)
logging.info('All process success.')