mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] fix device error in dump-info (#912)
* fix device error in dump-info * fix UT
This commit is contained in:
parent
3e7e80bcbc
commit
cad0092801
@ -42,20 +42,21 @@ def get_task(deploy_cfg: mmcv.Config) -> Dict:
|
|||||||
|
|
||||||
|
|
||||||
def get_model_name_customs(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
def get_model_name_customs(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
||||||
work_dir: str) -> Tuple:
|
work_dir: str, device: str) -> Tuple:
|
||||||
"""Get the model name and dump custom file.
|
"""Get the model name and dump custom file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
deploy_cfg (mmcv.Config): Deploy config dict.
|
deploy_cfg (mmcv.Config): Deploy config dict.
|
||||||
model_cfg (mmcv.Config): The model config dict.
|
model_cfg (mmcv.Config): The model config dict.
|
||||||
work_dir (str): Work dir to save json files.
|
work_dir (str): Work dir to save json files.
|
||||||
|
device (str): The device passed in.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
tuple(): Composed of the model name and the custom info.
|
tuple(): Composed of the model name and the custom info.
|
||||||
"""
|
"""
|
||||||
task = get_task_type(deploy_cfg)
|
task = get_task_type(deploy_cfg)
|
||||||
task_processor = build_task_processor(
|
task_processor = build_task_processor(
|
||||||
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu')
|
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device)
|
||||||
name = task_processor.get_model_name()
|
name = task_processor.get_model_name()
|
||||||
customs = []
|
customs = []
|
||||||
if task == Task.TEXT_RECOGNITION:
|
if task == Task.TEXT_RECOGNITION:
|
||||||
@ -75,19 +76,21 @@ def get_model_name_customs(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
|||||||
|
|
||||||
|
|
||||||
def get_models(deploy_cfg: Union[str, mmcv.Config],
|
def get_models(deploy_cfg: Union[str, mmcv.Config],
|
||||||
model_cfg: Union[str, mmcv.Config], work_dir: str) -> List:
|
model_cfg: Union[str, mmcv.Config], work_dir: str,
|
||||||
|
device: str) -> List:
|
||||||
"""Get the output model informantion for deploy.json.
|
"""Get the output model informantion for deploy.json.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
deploy_cfg (mmcv.Config): Deploy config dict.
|
deploy_cfg (mmcv.Config): Deploy config dict.
|
||||||
model_cfg (mmcv.Config): The model config dict.
|
model_cfg (mmcv.Config): The model config dict.
|
||||||
work_dir (str): Work dir to save json files.
|
work_dir (str): Work dir to save json files.
|
||||||
|
device (str): The device passed in.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
list[dict]: The list contains dicts composed of the model name, net,
|
list[dict]: The list contains dicts composed of the model name, net,
|
||||||
weghts, backend, precision batchsize and dynamic_shape.
|
weghts, backend, precision batchsize and dynamic_shape.
|
||||||
"""
|
"""
|
||||||
name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir)
|
name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir, device)
|
||||||
precision = 'FP32'
|
precision = 'FP32'
|
||||||
ir_name = get_ir_config(deploy_cfg)['save_file']
|
ir_name = get_ir_config(deploy_cfg)['save_file']
|
||||||
net = ir_name
|
net = ir_name
|
||||||
@ -149,19 +152,20 @@ def get_models(deploy_cfg: Union[str, mmcv.Config],
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
||||||
work_dir: str) -> Dict:
|
work_dir: str, device: str) -> Dict:
|
||||||
"""Get the inference information for pipeline.json.
|
"""Get the inference information for pipeline.json.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
deploy_cfg (mmcv.Config): Deploy config dict.
|
deploy_cfg (mmcv.Config): Deploy config dict.
|
||||||
model_cfg (mmcv.Config): The model config dict.
|
model_cfg (mmcv.Config): The model config dict.
|
||||||
work_dir (str): Work dir to save json files.
|
work_dir (str): Work dir to save json files.
|
||||||
|
device (str): The device passed in.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
dict: Composed of the model name, type, module, input, output and
|
dict: Composed of the model name, type, module, input, output and
|
||||||
input_map.
|
input_map.
|
||||||
"""
|
"""
|
||||||
name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir)
|
name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir, device)
|
||||||
type = 'Task'
|
type = 'Task'
|
||||||
module = 'Net'
|
module = 'Net'
|
||||||
input = ['prep_output']
|
input = ['prep_output']
|
||||||
@ -182,9 +186,17 @@ def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
|||||||
return return_dict
|
return return_dict
|
||||||
|
|
||||||
|
|
||||||
def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config):
|
def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
||||||
|
device: str):
|
||||||
|
"""Get the pre process information for pipeline.json.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
deploy_cfg (mmcv.Config): Deploy config dict.
|
||||||
|
model_cfg (mmcv.Config): The model config dict.
|
||||||
|
device (str): The device passed in.
|
||||||
|
"""
|
||||||
task_processor = build_task_processor(
|
task_processor = build_task_processor(
|
||||||
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu')
|
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device)
|
||||||
pipeline = task_processor.get_preprocess()
|
pipeline = task_processor.get_preprocess()
|
||||||
type = 'Task'
|
type = 'Task'
|
||||||
module = 'Transform'
|
module = 'Transform'
|
||||||
@ -233,12 +245,14 @@ def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config):
|
|||||||
transforms=transforms)
|
transforms=transforms)
|
||||||
|
|
||||||
|
|
||||||
def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict:
|
def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
||||||
|
device: str, **kwargs) -> Dict:
|
||||||
"""Get the post process information for pipeline.json.
|
"""Get the post process information for pipeline.json.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
deploy_cfg (mmcv.Config): Deploy config dict.
|
deploy_cfg (mmcv.Config): Deploy config dict.
|
||||||
model_cfg (mmcv.Config): The model config dict.
|
model_cfg (mmcv.Config): The model config dict.
|
||||||
|
device (str): The device passed in.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
dict: Composed of the model name, type, module, input, params and
|
dict: Composed of the model name, type, module, input, params and
|
||||||
@ -249,7 +263,7 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict:
|
|||||||
name = 'postprocess'
|
name = 'postprocess'
|
||||||
task = get_task_type(deploy_cfg)
|
task = get_task_type(deploy_cfg)
|
||||||
task_processor = build_task_processor(
|
task_processor = build_task_processor(
|
||||||
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu')
|
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device)
|
||||||
params = task_processor.get_postprocess()
|
params = task_processor.get_postprocess()
|
||||||
|
|
||||||
# TODO remove after adding instance segmentation to task processor
|
# TODO remove after adding instance segmentation to task processor
|
||||||
@ -270,14 +284,15 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict:
|
|||||||
output=output)
|
output=output)
|
||||||
|
|
||||||
|
|
||||||
def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, work_dir: str,
|
||||||
work_dir: str) -> Dict:
|
device: str) -> Dict:
|
||||||
"""Get the inference information for pipeline.json.
|
"""Get the inference information for pipeline.json.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
deploy_cfg (mmcv.Config): Deploy config dict.
|
deploy_cfg (mmcv.Config): Deploy config dict.
|
||||||
model_cfg (mmcv.Config): The model config dict.
|
model_cfg (mmcv.Config): The model config dict.
|
||||||
work_dir (str): Work dir to save json files.
|
work_dir (str): Work dir to save json files.
|
||||||
|
device (str): The device passed in.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
dict: Composed of version, task, models and customs.
|
dict: Composed of version, task, models and customs.
|
||||||
@ -286,27 +301,28 @@ def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
|||||||
task = get_task_type(deploy_cfg)
|
task = get_task_type(deploy_cfg)
|
||||||
cls_name = task_map[task]['cls_name']
|
cls_name = task_map[task]['cls_name']
|
||||||
_, customs = get_model_name_customs(
|
_, customs = get_model_name_customs(
|
||||||
deploy_cfg, model_cfg, work_dir=work_dir)
|
deploy_cfg, model_cfg, work_dir=work_dir, device=device)
|
||||||
version = get_mmdpeloy_version()
|
version = get_mmdpeloy_version()
|
||||||
models = get_models(deploy_cfg, model_cfg, work_dir=work_dir)
|
models = get_models(deploy_cfg, model_cfg, work_dir, device)
|
||||||
return dict(version=version, task=cls_name, models=models, customs=customs)
|
return dict(version=version, task=cls_name, models=models, customs=customs)
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
def get_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
||||||
work_dir: str) -> Dict:
|
work_dir: str, device: str) -> Dict:
|
||||||
"""Get the inference information for pipeline.json.
|
"""Get the inference information for pipeline.json.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
deploy_cfg (mmcv.Config): Deploy config dict.
|
deploy_cfg (mmcv.Config): Deploy config dict.
|
||||||
model_cfg (mmcv.Config): The model config dict.
|
model_cfg (mmcv.Config): The model config dict.
|
||||||
work_dir (str): Work dir to save json files.
|
work_dir (str): Work dir to save json files.
|
||||||
|
device (str): The device passed in.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
dict: Composed of input node name, output node name and the tasks.
|
dict: Composed of input node name, output node name and the tasks.
|
||||||
"""
|
"""
|
||||||
preprocess = get_preprocess(deploy_cfg, model_cfg)
|
preprocess = get_preprocess(deploy_cfg, model_cfg, device)
|
||||||
infer_info = get_inference_info(deploy_cfg, model_cfg, work_dir=work_dir)
|
infer_info = get_inference_info(deploy_cfg, model_cfg, work_dir, device)
|
||||||
postprocess = get_postprocess(deploy_cfg, model_cfg)
|
postprocess = get_postprocess(deploy_cfg, model_cfg, device)
|
||||||
task = get_task_type(deploy_cfg)
|
task = get_task_type(deploy_cfg)
|
||||||
input_names = preprocess['input']
|
input_names = preprocess['input']
|
||||||
output_names = postprocess['output']
|
output_names = postprocess['output']
|
||||||
@ -353,7 +369,8 @@ def get_detail(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
|
|||||||
|
|
||||||
|
|
||||||
def export2SDK(deploy_cfg: Union[str, mmcv.Config],
|
def export2SDK(deploy_cfg: Union[str, mmcv.Config],
|
||||||
model_cfg: Union[str, mmcv.Config], work_dir: str, pth: str):
|
model_cfg: Union[str, mmcv.Config], work_dir: str, pth: str,
|
||||||
|
device: str, **kwargs):
|
||||||
"""Export information to SDK. This function dump `deploy.json`,
|
"""Export information to SDK. This function dump `deploy.json`,
|
||||||
`pipeline.json` and `detail.json` to work dir.
|
`pipeline.json` and `detail.json` to work dir.
|
||||||
|
|
||||||
@ -364,8 +381,8 @@ def export2SDK(deploy_cfg: Union[str, mmcv.Config],
|
|||||||
pth (str): The path of the model checkpoint weights.
|
pth (str): The path of the model checkpoint weights.
|
||||||
"""
|
"""
|
||||||
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
|
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
|
||||||
deploy_info = get_deploy(deploy_cfg, model_cfg, work_dir=work_dir)
|
deploy_info = get_deploy(deploy_cfg, model_cfg, work_dir, device)
|
||||||
pipeline_info = get_pipeline(deploy_cfg, model_cfg, work_dir=work_dir)
|
pipeline_info = get_pipeline(deploy_cfg, model_cfg, work_dir, device)
|
||||||
detail_info = get_detail(deploy_cfg, model_cfg, pth=pth)
|
detail_info = get_detail(deploy_cfg, model_cfg, pth=pth)
|
||||||
mmcv.dump(
|
mmcv.dump(
|
||||||
deploy_info,
|
deploy_info,
|
||||||
|
@ -417,7 +417,7 @@ def test_AdvancedEnum():
|
|||||||
not importlib.util.find_spec('mmedit'), reason='requires mmedit')
|
not importlib.util.find_spec('mmedit'), reason='requires mmedit')
|
||||||
def test_export_info():
|
def test_export_info():
|
||||||
with tempfile.TemporaryDirectory() as dir:
|
with tempfile.TemporaryDirectory() as dir:
|
||||||
export2SDK(correct_deploy_cfg, correct_model_cfg, dir, '')
|
export2SDK(correct_deploy_cfg, correct_model_cfg, dir, '', 'cpu')
|
||||||
deploy_json = os.path.join(dir, 'deploy.json')
|
deploy_json = os.path.join(dir, 'deploy.json')
|
||||||
pipeline_json = os.path.join(dir, 'pipeline.json')
|
pipeline_json = os.path.join(dir, 'pipeline.json')
|
||||||
detail_json = os.path.join(dir, 'detail.json')
|
detail_json = os.path.join(dir, 'detail.json')
|
||||||
|
@ -122,7 +122,12 @@ def main():
|
|||||||
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
|
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
|
||||||
|
|
||||||
if args.dump_info:
|
if args.dump_info:
|
||||||
export2SDK(deploy_cfg, model_cfg, args.work_dir, pth=checkpoint_path)
|
export2SDK(
|
||||||
|
deploy_cfg,
|
||||||
|
model_cfg,
|
||||||
|
args.work_dir,
|
||||||
|
pth=checkpoint_path,
|
||||||
|
device=args.device)
|
||||||
|
|
||||||
ret_value = mp.Value('d', 0, lock=False)
|
ret_value = mp.Value('d', 0, lock=False)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user