add return value in deploy.py
parent
34caea27c1
commit
b777d27bcf
|
@ -3,6 +3,7 @@ from typing import Any, Optional, Union
|
|||
|
||||
import mmcv
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from mmdeploy.utils import (RewriterContext, patch_model,
|
||||
register_extra_symbolics)
|
||||
|
@ -14,7 +15,10 @@ def torch2onnx(img: Any,
|
|||
deploy_cfg: Union[str, mmcv.Config],
|
||||
model_cfg: Union[str, mmcv.Config],
|
||||
model_checkpoint: Optional[str] = None,
|
||||
device: str = 'cuda:0'):
|
||||
device: str = 'cuda:0',
|
||||
ret_value: Optional[mp.Value] = None):
|
||||
ret_value.value = -1
|
||||
|
||||
# load deploy_cfg if needed
|
||||
if isinstance(deploy_cfg, str):
|
||||
deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
|
||||
|
@ -40,8 +44,7 @@ def torch2onnx(img: Any,
|
|||
|
||||
torch_model = init_model(codebase, model_cfg, model_checkpoint, device)
|
||||
data, model_inputs = create_input(codebase, model_cfg, img, device)
|
||||
patched_model = patch_model(
|
||||
torch_model, cfg=deploy_cfg, backend=backend, data=data)
|
||||
patched_model = patch_model(torch_model, cfg=deploy_cfg, backend=backend)
|
||||
|
||||
if not isinstance(model_inputs, torch.Tensor):
|
||||
model_inputs = model_inputs[0]
|
||||
|
@ -57,3 +60,5 @@ def torch2onnx(img: Any,
|
|||
dynamic_axes=pytorch2onnx_cfg.get('dynamic_axes', None),
|
||||
keep_initializers_as_inputs=pytorch2onnx_cfg[
|
||||
'keep_initializers_as_inputs'])
|
||||
|
||||
ret_value.value = 0
|
||||
|
|
|
@ -70,10 +70,11 @@ MODULE_REWRITERS = RewriteModuleRegistry(
|
|||
def patch_model(model: nn.Module,
|
||||
cfg: Dict,
|
||||
backend: str = 'default',
|
||||
recursive: bool = True,
|
||||
**kwargs) -> nn.Module:
|
||||
|
||||
def _patch_impl(model, cfg, **kwargs):
|
||||
if hasattr(model, 'named_children'):
|
||||
if recursive and hasattr(model, 'named_children'):
|
||||
for name, module in model.named_children():
|
||||
model._modules[name] = _patch_impl(module, cfg, **kwargs)
|
||||
return MODULE_REWRITERS.build(
|
||||
|
|
|
@ -3,6 +3,7 @@ import logging
|
|||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import torch.multiprocessing as mp
|
||||
from torch.multiprocessing import Process, set_start_method
|
||||
|
||||
from mmdeploy.apis import torch2onnx
|
||||
|
@ -34,15 +35,23 @@ def main():
|
|||
# create work_dir if not
|
||||
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
|
||||
|
||||
ret_value = mp.Value('d', 0, lock=False)
|
||||
|
||||
# convert model
|
||||
logging.info('start torch2onnx conversion.')
|
||||
process = Process(
|
||||
target=torch2onnx,
|
||||
args=(args.img, args.work_dir, deploy_cfg, model_cfg, checkpoint),
|
||||
kwargs=dict(device=args.device))
|
||||
kwargs=dict(device=args.device, ret_value=ret_value))
|
||||
process.start()
|
||||
process.join()
|
||||
|
||||
if ret_value.value != 0:
|
||||
logging.error('torch2onnx failed.')
|
||||
exit()
|
||||
else:
|
||||
logging.info('torch2onnx success.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue