[Fix] fix device error in dump-info (#912)

* fix device error in dump-info

* fix UT
This commit is contained in:
AllentDan 2022-08-17 18:01:26 +08:00 committed by GitHub
parent 3e7e80bcbc
commit cad0092801
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 23 deletions

View File

@ -42,20 +42,21 @@ def get_task(deploy_cfg: mmcv.Config) -> Dict:
def get_model_name_customs(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, def get_model_name_customs(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
work_dir: str) -> Tuple: work_dir: str, device: str) -> Tuple:
"""Get the model name and dump custom file. """Get the model name and dump custom file.
Args: Args:
deploy_cfg (mmcv.Config): Deploy config dict. deploy_cfg (mmcv.Config): Deploy config dict.
model_cfg (mmcv.Config): The model config dict. model_cfg (mmcv.Config): The model config dict.
work_dir (str): Work dir to save json files. work_dir (str): Work dir to save json files.
device (str): The device passed in.
Return: Return:
tuple(): Composed of the model name and the custom info. tuple(): Composed of the model name and the custom info.
""" """
task = get_task_type(deploy_cfg) task = get_task_type(deploy_cfg)
task_processor = build_task_processor( task_processor = build_task_processor(
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu') model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device)
name = task_processor.get_model_name() name = task_processor.get_model_name()
customs = [] customs = []
if task == Task.TEXT_RECOGNITION: if task == Task.TEXT_RECOGNITION:
@ -75,19 +76,21 @@ def get_model_name_customs(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
def get_models(deploy_cfg: Union[str, mmcv.Config], def get_models(deploy_cfg: Union[str, mmcv.Config],
model_cfg: Union[str, mmcv.Config], work_dir: str) -> List: model_cfg: Union[str, mmcv.Config], work_dir: str,
device: str) -> List:
"""Get the output model informantion for deploy.json. """Get the output model informantion for deploy.json.
Args: Args:
deploy_cfg (mmcv.Config): Deploy config dict. deploy_cfg (mmcv.Config): Deploy config dict.
model_cfg (mmcv.Config): The model config dict. model_cfg (mmcv.Config): The model config dict.
work_dir (str): Work dir to save json files. work_dir (str): Work dir to save json files.
device (str): The device passed in.
Return: Return:
list[dict]: The list contains dicts composed of the model name, net, list[dict]: The list contains dicts composed of the model name, net,
weghts, backend, precision batchsize and dynamic_shape. weghts, backend, precision batchsize and dynamic_shape.
""" """
name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir) name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir, device)
precision = 'FP32' precision = 'FP32'
ir_name = get_ir_config(deploy_cfg)['save_file'] ir_name = get_ir_config(deploy_cfg)['save_file']
net = ir_name net = ir_name
@ -149,19 +152,20 @@ def get_models(deploy_cfg: Union[str, mmcv.Config],
def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
work_dir: str) -> Dict: work_dir: str, device: str) -> Dict:
"""Get the inference information for pipeline.json. """Get the inference information for pipeline.json.
Args: Args:
deploy_cfg (mmcv.Config): Deploy config dict. deploy_cfg (mmcv.Config): Deploy config dict.
model_cfg (mmcv.Config): The model config dict. model_cfg (mmcv.Config): The model config dict.
work_dir (str): Work dir to save json files. work_dir (str): Work dir to save json files.
device (str): The device passed in.
Return: Return:
dict: Composed of the model name, type, module, input, output and dict: Composed of the model name, type, module, input, output and
input_map. input_map.
""" """
name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir) name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir, device)
type = 'Task' type = 'Task'
module = 'Net' module = 'Net'
input = ['prep_output'] input = ['prep_output']
@ -182,9 +186,17 @@ def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
return return_dict return return_dict
def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config): def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
device: str):
"""Get the pre process information for pipeline.json.
Args:
deploy_cfg (mmcv.Config): Deploy config dict.
model_cfg (mmcv.Config): The model config dict.
device (str): The device passed in.
"""
task_processor = build_task_processor( task_processor = build_task_processor(
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu') model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device)
pipeline = task_processor.get_preprocess() pipeline = task_processor.get_preprocess()
type = 'Task' type = 'Task'
module = 'Transform' module = 'Transform'
@ -233,12 +245,14 @@ def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config):
transforms=transforms) transforms=transforms)
def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict: def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
device: str, **kwargs) -> Dict:
"""Get the post process information for pipeline.json. """Get the post process information for pipeline.json.
Args: Args:
deploy_cfg (mmcv.Config): Deploy config dict. deploy_cfg (mmcv.Config): Deploy config dict.
model_cfg (mmcv.Config): The model config dict. model_cfg (mmcv.Config): The model config dict.
device (str): The device passed in.
Return: Return:
dict: Composed of the model name, type, module, input, params and dict: Composed of the model name, type, module, input, params and
@ -249,7 +263,7 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict:
name = 'postprocess' name = 'postprocess'
task = get_task_type(deploy_cfg) task = get_task_type(deploy_cfg)
task_processor = build_task_processor( task_processor = build_task_processor(
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device='cpu') model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device)
params = task_processor.get_postprocess() params = task_processor.get_postprocess()
# TODO remove after adding instance segmentation to task processor # TODO remove after adding instance segmentation to task processor
@ -270,14 +284,15 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config) -> Dict:
output=output) output=output)
def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, work_dir: str,
work_dir: str) -> Dict: device: str) -> Dict:
"""Get the inference information for pipeline.json. """Get the inference information for pipeline.json.
Args: Args:
deploy_cfg (mmcv.Config): Deploy config dict. deploy_cfg (mmcv.Config): Deploy config dict.
model_cfg (mmcv.Config): The model config dict. model_cfg (mmcv.Config): The model config dict.
work_dir (str): Work dir to save json files. work_dir (str): Work dir to save json files.
device (str): The device passed in.
Return: Return:
dict: Composed of version, task, models and customs. dict: Composed of version, task, models and customs.
@ -286,27 +301,28 @@ def get_deploy(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
task = get_task_type(deploy_cfg) task = get_task_type(deploy_cfg)
cls_name = task_map[task]['cls_name'] cls_name = task_map[task]['cls_name']
_, customs = get_model_name_customs( _, customs = get_model_name_customs(
deploy_cfg, model_cfg, work_dir=work_dir) deploy_cfg, model_cfg, work_dir=work_dir, device=device)
version = get_mmdpeloy_version() version = get_mmdpeloy_version()
models = get_models(deploy_cfg, model_cfg, work_dir=work_dir) models = get_models(deploy_cfg, model_cfg, work_dir, device)
return dict(version=version, task=cls_name, models=models, customs=customs) return dict(version=version, task=cls_name, models=models, customs=customs)
def get_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, def get_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
work_dir: str) -> Dict: work_dir: str, device: str) -> Dict:
"""Get the inference information for pipeline.json. """Get the inference information for pipeline.json.
Args: Args:
deploy_cfg (mmcv.Config): Deploy config dict. deploy_cfg (mmcv.Config): Deploy config dict.
model_cfg (mmcv.Config): The model config dict. model_cfg (mmcv.Config): The model config dict.
work_dir (str): Work dir to save json files. work_dir (str): Work dir to save json files.
device (str): The device passed in.
Return: Return:
dict: Composed of input node name, output node name and the tasks. dict: Composed of input node name, output node name and the tasks.
""" """
preprocess = get_preprocess(deploy_cfg, model_cfg) preprocess = get_preprocess(deploy_cfg, model_cfg, device)
infer_info = get_inference_info(deploy_cfg, model_cfg, work_dir=work_dir) infer_info = get_inference_info(deploy_cfg, model_cfg, work_dir, device)
postprocess = get_postprocess(deploy_cfg, model_cfg) postprocess = get_postprocess(deploy_cfg, model_cfg, device)
task = get_task_type(deploy_cfg) task = get_task_type(deploy_cfg)
input_names = preprocess['input'] input_names = preprocess['input']
output_names = postprocess['output'] output_names = postprocess['output']
@ -353,7 +369,8 @@ def get_detail(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
def export2SDK(deploy_cfg: Union[str, mmcv.Config], def export2SDK(deploy_cfg: Union[str, mmcv.Config],
model_cfg: Union[str, mmcv.Config], work_dir: str, pth: str): model_cfg: Union[str, mmcv.Config], work_dir: str, pth: str,
device: str, **kwargs):
"""Export information to SDK. This function dump `deploy.json`, """Export information to SDK. This function dump `deploy.json`,
`pipeline.json` and `detail.json` to work dir. `pipeline.json` and `detail.json` to work dir.
@ -364,8 +381,8 @@ def export2SDK(deploy_cfg: Union[str, mmcv.Config],
pth (str): The path of the model checkpoint weights. pth (str): The path of the model checkpoint weights.
""" """
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
deploy_info = get_deploy(deploy_cfg, model_cfg, work_dir=work_dir) deploy_info = get_deploy(deploy_cfg, model_cfg, work_dir, device)
pipeline_info = get_pipeline(deploy_cfg, model_cfg, work_dir=work_dir) pipeline_info = get_pipeline(deploy_cfg, model_cfg, work_dir, device)
detail_info = get_detail(deploy_cfg, model_cfg, pth=pth) detail_info = get_detail(deploy_cfg, model_cfg, pth=pth)
mmcv.dump( mmcv.dump(
deploy_info, deploy_info,

View File

@ -417,7 +417,7 @@ def test_AdvancedEnum():
not importlib.util.find_spec('mmedit'), reason='requires mmedit') not importlib.util.find_spec('mmedit'), reason='requires mmedit')
def test_export_info(): def test_export_info():
with tempfile.TemporaryDirectory() as dir: with tempfile.TemporaryDirectory() as dir:
export2SDK(correct_deploy_cfg, correct_model_cfg, dir, '') export2SDK(correct_deploy_cfg, correct_model_cfg, dir, '', 'cpu')
deploy_json = os.path.join(dir, 'deploy.json') deploy_json = os.path.join(dir, 'deploy.json')
pipeline_json = os.path.join(dir, 'pipeline.json') pipeline_json = os.path.join(dir, 'pipeline.json')
detail_json = os.path.join(dir, 'detail.json') detail_json = os.path.join(dir, 'detail.json')

View File

@ -122,7 +122,12 @@ def main():
mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
if args.dump_info: if args.dump_info:
export2SDK(deploy_cfg, model_cfg, args.work_dir, pth=checkpoint_path) export2SDK(
deploy_cfg,
model_cfg,
args.work_dir,
pth=checkpoint_path,
device=args.device)
ret_value = mp.Value('d', 0, lock=False) ret_value = mp.Value('d', 0, lock=False)