add return value in deploy.py

pull/12/head
grimoire 2021-06-17 17:26:32 +08:00
parent 34caea27c1
commit b777d27bcf
3 changed files with 20 additions and 5 deletions

View File

@ -3,6 +3,7 @@ from typing import Any, Optional, Union
import mmcv
import torch
import torch.multiprocessing as mp
from mmdeploy.utils import (RewriterContext, patch_model,
register_extra_symbolics)
@ -14,7 +15,10 @@ def torch2onnx(img: Any,
deploy_cfg: Union[str, mmcv.Config],
model_cfg: Union[str, mmcv.Config],
model_checkpoint: Optional[str] = None,
device: str = 'cuda:0'):
device: str = 'cuda:0',
ret_value: Optional[mp.Value] = None):
ret_value.value = -1
# load deploy_cfg if needed
if isinstance(deploy_cfg, str):
deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
@ -40,8 +44,7 @@ def torch2onnx(img: Any,
torch_model = init_model(codebase, model_cfg, model_checkpoint, device)
data, model_inputs = create_input(codebase, model_cfg, img, device)
patched_model = patch_model(
torch_model, cfg=deploy_cfg, backend=backend, data=data)
patched_model = patch_model(torch_model, cfg=deploy_cfg, backend=backend)
if not isinstance(model_inputs, torch.Tensor):
model_inputs = model_inputs[0]
@ -57,3 +60,5 @@ def torch2onnx(img: Any,
dynamic_axes=pytorch2onnx_cfg.get('dynamic_axes', None),
keep_initializers_as_inputs=pytorch2onnx_cfg[
'keep_initializers_as_inputs'])
ret_value.value = 0

View File

@ -70,10 +70,11 @@ MODULE_REWRITERS = RewriteModuleRegistry(
def patch_model(model: nn.Module,
cfg: Dict,
backend: str = 'default',
recursive: bool = True,
**kwargs) -> nn.Module:
def _patch_impl(model, cfg, **kwargs):
if hasattr(model, 'named_children'):
if recursive and hasattr(model, 'named_children'):
for name, module in model.named_children():
model._modules[name] = _patch_impl(module, cfg, **kwargs)
return MODULE_REWRITERS.build(

View File

@ -3,6 +3,7 @@ import logging
import os.path as osp
import mmcv
import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method
from mmdeploy.apis import torch2onnx
@ -34,15 +35,23 @@ def main():
# create work_dir if not
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
ret_value = mp.Value('d', 0, lock=False)
# convert model
logging.info('start torch2onnx conversion.')
process = Process(
target=torch2onnx,
args=(args.img, args.work_dir, deploy_cfg, model_cfg, checkpoint),
kwargs=dict(device=args.device))
kwargs=dict(device=args.device, ret_value=ret_value))
process.start()
process.join()
if ret_value.value != 0:
logging.error('torch2onnx failed.')
exit()
else:
logging.info('torch2onnx success.')
if __name__ == '__main__':
main()