add return value in deploy.py
parent
34caea27c1
commit
b777d27bcf
|
@ -3,6 +3,7 @@ from typing import Any, Optional, Union
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import torch
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
from mmdeploy.utils import (RewriterContext, patch_model,
|
from mmdeploy.utils import (RewriterContext, patch_model,
|
||||||
register_extra_symbolics)
|
register_extra_symbolics)
|
||||||
|
@ -14,7 +15,10 @@ def torch2onnx(img: Any,
|
||||||
deploy_cfg: Union[str, mmcv.Config],
|
deploy_cfg: Union[str, mmcv.Config],
|
||||||
model_cfg: Union[str, mmcv.Config],
|
model_cfg: Union[str, mmcv.Config],
|
||||||
model_checkpoint: Optional[str] = None,
|
model_checkpoint: Optional[str] = None,
|
||||||
device: str = 'cuda:0'):
|
device: str = 'cuda:0',
|
||||||
|
ret_value: Optional[mp.Value] = None):
|
||||||
|
ret_value.value = -1
|
||||||
|
|
||||||
# load deploy_cfg if needed
|
# load deploy_cfg if needed
|
||||||
if isinstance(deploy_cfg, str):
|
if isinstance(deploy_cfg, str):
|
||||||
deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
|
deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
|
||||||
|
@ -40,8 +44,7 @@ def torch2onnx(img: Any,
|
||||||
|
|
||||||
torch_model = init_model(codebase, model_cfg, model_checkpoint, device)
|
torch_model = init_model(codebase, model_cfg, model_checkpoint, device)
|
||||||
data, model_inputs = create_input(codebase, model_cfg, img, device)
|
data, model_inputs = create_input(codebase, model_cfg, img, device)
|
||||||
patched_model = patch_model(
|
patched_model = patch_model(torch_model, cfg=deploy_cfg, backend=backend)
|
||||||
torch_model, cfg=deploy_cfg, backend=backend, data=data)
|
|
||||||
|
|
||||||
if not isinstance(model_inputs, torch.Tensor):
|
if not isinstance(model_inputs, torch.Tensor):
|
||||||
model_inputs = model_inputs[0]
|
model_inputs = model_inputs[0]
|
||||||
|
@ -57,3 +60,5 @@ def torch2onnx(img: Any,
|
||||||
dynamic_axes=pytorch2onnx_cfg.get('dynamic_axes', None),
|
dynamic_axes=pytorch2onnx_cfg.get('dynamic_axes', None),
|
||||||
keep_initializers_as_inputs=pytorch2onnx_cfg[
|
keep_initializers_as_inputs=pytorch2onnx_cfg[
|
||||||
'keep_initializers_as_inputs'])
|
'keep_initializers_as_inputs'])
|
||||||
|
|
||||||
|
ret_value.value = 0
|
||||||
|
|
|
@ -70,10 +70,11 @@ MODULE_REWRITERS = RewriteModuleRegistry(
|
||||||
def patch_model(model: nn.Module,
|
def patch_model(model: nn.Module,
|
||||||
cfg: Dict,
|
cfg: Dict,
|
||||||
backend: str = 'default',
|
backend: str = 'default',
|
||||||
|
recursive: bool = True,
|
||||||
**kwargs) -> nn.Module:
|
**kwargs) -> nn.Module:
|
||||||
|
|
||||||
def _patch_impl(model, cfg, **kwargs):
|
def _patch_impl(model, cfg, **kwargs):
|
||||||
if hasattr(model, 'named_children'):
|
if recursive and hasattr(model, 'named_children'):
|
||||||
for name, module in model.named_children():
|
for name, module in model.named_children():
|
||||||
model._modules[name] = _patch_impl(module, cfg, **kwargs)
|
model._modules[name] = _patch_impl(module, cfg, **kwargs)
|
||||||
return MODULE_REWRITERS.build(
|
return MODULE_REWRITERS.build(
|
||||||
|
|
|
@ -3,6 +3,7 @@ import logging
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
|
import torch.multiprocessing as mp
|
||||||
from torch.multiprocessing import Process, set_start_method
|
from torch.multiprocessing import Process, set_start_method
|
||||||
|
|
||||||
from mmdeploy.apis import torch2onnx
|
from mmdeploy.apis import torch2onnx
|
||||||
|
@ -34,15 +35,23 @@ def main():
|
||||||
# create work_dir if not
|
# create work_dir if not
|
||||||
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
|
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
|
||||||
|
|
||||||
|
ret_value = mp.Value('d', 0, lock=False)
|
||||||
|
|
||||||
# convert model
|
# convert model
|
||||||
logging.info('start torch2onnx conversion.')
|
logging.info('start torch2onnx conversion.')
|
||||||
process = Process(
|
process = Process(
|
||||||
target=torch2onnx,
|
target=torch2onnx,
|
||||||
args=(args.img, args.work_dir, deploy_cfg, model_cfg, checkpoint),
|
args=(args.img, args.work_dir, deploy_cfg, model_cfg, checkpoint),
|
||||||
kwargs=dict(device=args.device))
|
kwargs=dict(device=args.device, ret_value=ret_value))
|
||||||
process.start()
|
process.start()
|
||||||
process.join()
|
process.join()
|
||||||
|
|
||||||
|
if ret_value.value != 0:
|
||||||
|
logging.error('torch2onnx failed.')
|
||||||
|
exit()
|
||||||
|
else:
|
||||||
|
logging.info('torch2onnx success.')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in New Issue