diff --git a/mmdeploy/apis/pytorch2onnx.py b/mmdeploy/apis/pytorch2onnx.py index f8999cc30..66f360286 100644 --- a/mmdeploy/apis/pytorch2onnx.py +++ b/mmdeploy/apis/pytorch2onnx.py @@ -3,6 +3,7 @@ from typing import Any, Optional, Union import mmcv import torch +import torch.multiprocessing as mp from mmdeploy.utils import (RewriterContext, patch_model, register_extra_symbolics) @@ -14,7 +15,10 @@ def torch2onnx(img: Any, deploy_cfg: Union[str, mmcv.Config], model_cfg: Union[str, mmcv.Config], model_checkpoint: Optional[str] = None, - device: str = 'cuda:0'): + device: str = 'cuda:0', + ret_value: Optional[mp.Value] = None): + ret_value.value = -1 + # load deploy_cfg if needed if isinstance(deploy_cfg, str): deploy_cfg = mmcv.Config.fromfile(deploy_cfg) @@ -40,8 +44,7 @@ def torch2onnx(img: Any, torch_model = init_model(codebase, model_cfg, model_checkpoint, device) data, model_inputs = create_input(codebase, model_cfg, img, device) - patched_model = patch_model( - torch_model, cfg=deploy_cfg, backend=backend, data=data) + patched_model = patch_model(torch_model, cfg=deploy_cfg, backend=backend) if not isinstance(model_inputs, torch.Tensor): model_inputs = model_inputs[0] @@ -57,3 +60,5 @@ def torch2onnx(img: Any, dynamic_axes=pytorch2onnx_cfg.get('dynamic_axes', None), keep_initializers_as_inputs=pytorch2onnx_cfg[ 'keep_initializers_as_inputs']) + + ret_value.value = 0 diff --git a/mmdeploy/utils/module_rewriter.py b/mmdeploy/utils/module_rewriter.py index ddf5addee..e1c2be5fb 100644 --- a/mmdeploy/utils/module_rewriter.py +++ b/mmdeploy/utils/module_rewriter.py @@ -70,10 +70,11 @@ MODULE_REWRITERS = RewriteModuleRegistry( def patch_model(model: nn.Module, cfg: Dict, backend: str = 'default', + recursive: bool = True, **kwargs) -> nn.Module: def _patch_impl(model, cfg, **kwargs): - if hasattr(model, 'named_children'): + if recursive and hasattr(model, 'named_children'): for name, module in model.named_children(): model._modules[name] = _patch_impl(module, cfg, **kwargs) return MODULE_REWRITERS.build( diff --git a/tools/deploy.py b/tools/deploy.py index 36d1b4e6c..bc3b406a9 100644 --- a/tools/deploy.py +++ b/tools/deploy.py @@ -3,6 +3,7 @@ import logging import os.path as osp import mmcv +import torch.multiprocessing as mp from torch.multiprocessing import Process, set_start_method from mmdeploy.apis import torch2onnx @@ -34,15 +35,23 @@ def main(): # create work_dir if not mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) + ret_value = mp.Value('d', 0, lock=False) + # convert model logging.info('start torch2onnx conversion.') process = Process( target=torch2onnx, args=(args.img, args.work_dir, deploy_cfg, model_cfg, checkpoint), - kwargs=dict(device=args.device)) + kwargs=dict(device=args.device, ret_value=ret_value)) process.start() process.join() + if ret_value.value != 0: + logging.error('torch2onnx failed.') + exit() + else: + logging.info('torch2onnx success.') + if __name__ == '__main__': main()