1249 lines
46 KiB
Python
1249 lines
46 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
import logging
|
|
import subprocess
|
|
from pathlib import Path
|
|
from typing import Union
|
|
|
|
import mmcv
|
|
import openpyxl
|
|
import pandas as pd
|
|
import yaml
|
|
from torch.hub import download_url_to_file
|
|
from torch.multiprocessing import set_start_method
|
|
|
|
import mmdeploy.version
|
|
from mmdeploy.utils import (get_backend, get_codebase, get_root_logger,
|
|
is_dynamic_shape, load_config)
|
|
|
|
|
|
def parse_args():
    """Define and parse the command line options of the regression test.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``codebase``,
            ``performance``, ``backends``, ``models``, ``work_dir``,
            ``checkpoint_dir``, ``device`` and ``log_level``.
    """
    default_codebases = [
        'mmcls', 'mmdet', 'mmseg', 'mmpose', 'mmocr', 'mmedit', 'mmrotate'
    ]
    # every level name known to the logging module is an accepted choice
    log_level_names = list(logging._nameToLevel.keys())

    parser = argparse.ArgumentParser(description='Regression Test')
    add_option = parser.add_argument

    add_option(
        '--codebase',
        nargs='+',
        help='regression test yaml path.',
        default=default_codebases)
    add_option(
        '-p',
        '--performance',
        default=False,
        action='store_true',
        help='test performance if it set')
    add_option('--backends', nargs='+', help='test specific backend(s)')
    add_option('--models', nargs='+', help='test specific model(s)')
    add_option(
        '--work-dir',
        type=str,
        help='the dir to save logs and models',
        default='../mmdeploy_regression_working_dir')
    add_option(
        '--checkpoint-dir',
        type=str,
        help='the dir to save checkpoint for all model',
        default='../mmdeploy_checkpoints')
    add_option(
        '--device', type=str, help='Device type, cuda or cpu', default='cuda')
    add_option(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=log_level_names)

    return parser.parse_args()
|
|
|
|
|
|
def merge_report(work_dir: str, logger: logging.Logger):
    """Collect every per-codebase report into one multi-sheet workbook.

    Each ``*_report.xlsx`` file found directly in ``work_dir`` becomes one
    sheet of ``mmdeploy_regression_test_<version>.xlsx``; the sheet is named
    after the part of the file stem before the first underscore.

    Args:
        work_dir (str): Work dir that contains all reports.
        logger (logging.Logger): Logger handler.
    """
    report_dir = Path(work_dir)
    res_file = report_dir.joinpath(
        f'mmdeploy_regression_test_{mmdeploy.version.__version__}.xlsx')
    logger.info(f'Whole result report saving in {res_file}')

    # start from scratch: drop the merged report of a previous run
    if res_file.exists():
        res_file.unlink()

    for report_file in report_dir.iterdir():
        file_name = report_file.name
        if file_name.startswith('.~'):
            # temp file of a workbook that is still open elsewhere
            continue

        is_single_report = ('_report.xlsx' in file_name
                            and file_name != res_file.name
                            and report_file.is_file())
        if not is_single_report:
            continue

        logger.info(f'Merging {report_file}')
        df = pd.read_excel(str(report_file))
        df.rename(columns={'Unnamed: 0': 'Index'}, inplace=True)

        header = list(df.keys())
        rows = df.values.tolist()
        sheet_name = report_file.stem.split('_')[0]

        # reopen the merged workbook if an earlier iteration created it
        if res_file.exists():
            wb = openpyxl.load_workbook(str(res_file))
        else:
            wb = openpyxl.Workbook()

        # replace a stale sheet of the same name with a fresh first sheet
        if sheet_name in wb.sheetnames:
            wb.remove(wb[sheet_name])
        wb.create_sheet(title=sheet_name, index=0)

        sheet = wb[sheet_name]
        sheet.append(header)
        for row in rows:
            sheet.append(row)

        # drop sheets whose first cell is empty (e.g. the default sheet)
        for name in wb.sheetnames:
            ws = wb[name]
            if ws.cell(1, 1).value is None:
                wb.remove(ws)

        wb.save(str(res_file))

    logger.info('Report merge successful.')
|
|
|
|
|
|
def get_model_metafile_info(global_info: dict, model_info: dict,
                            logger: logging.Logger):
    """Read the metafile of a model and download its checkpoints.

    Args:
        global_info (dict): global info from deploy yaml. Must provide
            'checkpoint_dir', 'codebase_dir' and 'codebase_name';
            'checkpoint_force_download' is optional.
        model_info (dict): model info from deploy yaml. Must provide a
            non-empty 'model_configs' and a 'metafile'; 'name' is optional.
        logger (logging.Logger): Logger handler.

    Returns:
        tuple: ``(model_meta_info, checkpoint_save_dir, codebase_dir)`` where
            ``model_meta_info`` maps each selected config to its metafile
            entry, ``checkpoint_save_dir`` is the ``Path`` the weights are
            stored in and ``codebase_dir`` is taken from ``global_info``.

    Raises:
        AssertionError: If a required entry is missing.
        FileExistsError: If a weight file is absent after downloading.
    """
    # required settings, checked in a fixed order
    checkpoint_dir = global_info.get('checkpoint_dir', None)
    assert checkpoint_dir is not None
    codebase_dir = global_info.get('codebase_dir', None)
    assert codebase_dir is not None
    codebase_name = global_info.get('codebase_name', None)
    assert codebase_name is not None

    model_config_files = model_info.get('model_configs', [])
    assert len(model_config_files) > 0

    # <checkpoint_dir>/<codebase_name>/<model_name> holds the weights
    model_name = _filter_string(model_info.get('name', 'model'))
    checkpoint_save_dir = Path(checkpoint_dir).joinpath(
        codebase_name, model_name)
    checkpoint_save_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f'Saving checkpoint in {checkpoint_save_dir}')

    metafile_path = Path(codebase_dir).joinpath(model_info.get('metafile'))
    with open(metafile_path) as f:
        # NOTE(review): FullLoader is used here; presumably the metafile
        # comes from a trusted codebase checkout - confirm.
        metafile_info = yaml.load(f, Loader=yaml.FullLoader)

    model_meta_info = dict()
    for meta_model in metafile_info.get('Models'):
        config_name = meta_model.get('Config')
        if str(config_name) not in model_config_files:
            # only the configs requested by the test yaml are handled
            logger.warning(
                f'{config_name} not in {model_config_files}, pls check ! '
                'Skip it...')
            continue

        model_meta_info[config_name] = meta_model

        weights_url = meta_model.get('Weights')
        weights_name = str(weights_url).split('/')[-1]
        weights_save_path = checkpoint_save_dir.joinpath(weights_name)

        already_there = weights_save_path.exists() and \
            not global_info.get('checkpoint_force_download', False)
        if already_there:
            logger.info(f'model {weights_name} exist, skip download it...')
            continue

        logger.info(f'Downloading {weights_url} to {weights_save_path}')
        download_url_to_file(
            weights_url, str(weights_save_path), progress=True)

        # verify that the download really produced the file
        if not weights_save_path.exists():
            raise FileExistsError(f'Weight {weights_name} download fail')

    logger.info('All models had been downloaded successful !')
    return model_meta_info, checkpoint_save_dir, codebase_dir
|
|
|
|
|
|
def update_report(report_dict: dict, model_name: str, model_config: str,
                  task_name: str, checkpoint: str, dataset: str,
                  backend_name: str, deploy_config: str,
                  static_or_dynamic: str, precision_type: str,
                  conversion_result: str, fps: str, metric_info: list,
                  test_pass: str, report_txt_path: Path, codebase_name: str):
    """Append one result row to the report dict and to the report txt.

    The checkpoint path is shortened first: the checkpoint directory is
    replaced by ``${CHECKPOINT_DIR}`` for '.pth' files, otherwise the work
    dir is replaced by ``${WORK_DIR}``. ``fps`` is only written to the txt
    file, it is not stored in ``report_dict``.

    Args:
        report_dict (dict): Report info dict; maps column name to a list.
        model_name (str): Model name.
        model_config (str): Model config name.
        task_name (str): Task name.
        checkpoint (str): Model checkpoint name.
        dataset (str): Dataset name.
        backend_name (str): Backend name.
        deploy_config (str): Deploy config name.
        static_or_dynamic (str): Static or dynamic.
        precision_type (str): Precision type of the model.
        conversion_result (str): Conversion result: Successful or Fail.
        fps (str): Inference speed (ms/im).
        metric_info (list): Metric info list of the ${modelName}.yml.
        test_pass (str): Test result: Pass or Fail.
        report_txt_path (Path): Report txt path.
        codebase_name (str): Codebase name.
    """
    if '.pth' in checkpoint:
        # keep only the part below '/<codebase_name>/'
        resolved = str(Path(checkpoint).absolute().resolve())
        tail = resolved.split(f'/{codebase_name}/')[-1]
        checkpoint = '${CHECKPOINT_DIR}' + f'/{codebase_name}/{tail}'
    else:
        if Path(checkpoint).exists():
            # To invoice the path which is 'A.a B.b' when test sdk.
            checkpoint = Path(checkpoint).absolute().resolve()
        elif backend_name == 'ncnn':
            # ncnn have 2 backend file but only need xxx.param
            checkpoint = checkpoint.split('.param')[0] + '.param'
        work_dir = report_txt_path.parent.absolute().resolve()
        checkpoint = str(checkpoint).replace(str(work_dir), '${WORK_DIR}')

    fixed_columns = (
        ('Model', model_name),
        ('Model Config', model_config),
        ('Task', task_name),
        ('Checkpoint', checkpoint),
        ('Dataset', dataset),
        ('Backend', backend_name),
        ('Deploy Config', deploy_config),
        ('Static or Dynamic', static_or_dynamic),
        ('Precision Type', precision_type),
        ('Conversion Result', conversion_result),
    )

    # cells of the comma separated line written to the txt report
    txt_cells = [f'{value}' for _, value in fixed_columns]
    txt_cells.append(f'{fps}')

    for column, value in fixed_columns:
        report_dict.get(column).append(value)

    for metric in metric_info:
        for metric_name, metric_value in metric.items():
            report_dict.get(str(metric_name)).append(metric_value)
            txt_cells.append(f'{metric_value}')

    report_dict.get('Test Pass').append(test_pass)
    txt_cells.append(f'{test_pass}')

    with open(report_txt_path, 'a+', encoding='utf-8') as f:
        f.write(','.join(txt_cells) + '\n')
|
|
|
|
|
|
def get_pytorch_result(model_name: str, meta_info: dict, checkpoint_path: Path,
                       model_config_path: Path, model_config_name: str,
                       test_yaml_metric_info: dict, report_dict: dict,
                       logger: logging.Logger, report_txt_path: Path,
                       codebase_name: str):
    """Collect the PyTorch reference metrics of a model from its metafile.

    The metrics are filtered by the datasets and tasks named in the test
    yaml, and one 'Pytorch' row is added to the report.

    Args:
        model_name (str): Name of model.
        meta_info (dict): Metafile info from model's metafile.yml.
        checkpoint_path (Path): Checkpoint path.
        model_config_path (Path): Model config path.
        model_config_name (str): Name of model config in meta_info.
        test_yaml_metric_info (dict): Metrics info from test yaml.
        report_dict (dict): Report info dict.
        logger (logging.Logger): Logger.
        report_txt_path (Path): Report txt path.
        codebase_name (str): Codebase name.

    Returns:
        tuple: ``(pytorch_metric, dataset_type)`` - the metric dict of the
            model and the upper-cased, ' | ' joined dataset names.
            NOTE(review): when ``model_config_name`` is not in ``meta_info``
            a bare empty dict is returned instead of a tuple - confirm that
            callers cope with that.
    """
    if model_config_name not in meta_info:
        logger.warning(
            f'{model_config_name} not in meta_info, which is {meta_info}')
        return {}

    model_info = meta_info.get(model_config_name, None)
    metafile_metric_info = model_info.get('Results', None)

    metric_list = []
    pytorch_metric = dict()
    dataset_type = ''
    task_type = ''

    # dataset name -> task names that the test yaml evaluates on it
    using_dataset = dict()
    for yaml_metric in test_yaml_metric_info.values():
        if yaml_metric.get('dataset') is None:
            continue
        dataset_list = yaml_metric.get('dataset', [])
        if not isinstance(dataset_list, list):
            dataset_list = [dataset_list]
        for metric_dataset in dataset_list:
            task_names = using_dataset.setdefault(metric_dataset, [])
            if yaml_metric.get('task_name') not in task_names:
                task_names.append(yaml_metric.get('task_name'))

    # walk through the results listed in the metafile
    for metafile_metric in metafile_metric_info:
        pytorch_meta_metric = metafile_metric.get('Metrics')
        dataset = metafile_metric.get('Dataset', '')
        task_name = metafile_metric.get('Task', '')

        if task_name == 'Restorers':
            # mmedit
            dataset = 'Set5'

        if (len(using_dataset) > 1) and (dataset not in using_dataset):
            logger.info(f'dataset not in {using_dataset}, skip it...')
            continue
        dataset_type += f'{dataset} | '

        if task_name not in using_dataset.get(dataset, []):
            # only add the metric with the correct dataset
            logger.info(f'task_name ({task_name}) is not in'
                        f'{using_dataset.get(dataset, [])}, skip it...')
            continue
        task_type += f'{task_name} | '

        # mmedit metrics are keyed by dataset, others by the yaml metric name
        is_restorer = 'Restorers' in task_type
        for key, value in pytorch_meta_metric.items():
            if is_restorer:
                if key not in dataset_type:
                    continue
            elif key not in test_yaml_metric_info:
                continue

            # a dict value (mmedit) carries several metrics at once
            entries = value if isinstance(value, dict) else {key: value}
            for entry_name, entry_value in entries.items():
                metric_list.append({entry_name: entry_value})
                pytorch_metric[entry_name] = entry_value

    dataset_type = dataset_type[:-3].upper()  # remove the final ' | '
    task_type = task_type[:-3]  # remove the final ' | '

    # metrics of the test yaml that this model does not report get a '-'
    metric_all_list = [str(metric) for metric in test_yaml_metric_info]
    metric_useless = set(metric_all_list) - set(
        [str(metric) for metric in pytorch_metric])
    metric_list.extend({metric: '-'} for metric in metric_useless)

    # get pytorch fps value
    fps_info = model_info.get('Metadata').get('inference time (ms/im)')
    if fps_info is None:
        fps = '-'
    else:
        first_entry = fps_info[0] if isinstance(fps_info, list) else fps_info
        fps = first_entry.get('value')

    logger.info(f'Got metric_list = {metric_list} ')
    logger.info(f'Got pytorch_metric = {pytorch_metric} ')

    # fields that do not apply to the PyTorch baseline are filled with '-'
    update_report(
        report_dict=report_dict,
        model_name=model_name,
        model_config=str(model_config_path),
        task_name=task_type,
        checkpoint=str(checkpoint_path),
        dataset=dataset_type,
        backend_name='Pytorch',
        deploy_config='-',
        static_or_dynamic='-',
        precision_type='-',
        conversion_result='-',
        fps=fps,
        metric_info=metric_list,
        test_pass='-',
        report_txt_path=report_txt_path,
        codebase_name=codebase_name)

    logger.info(f'Got {model_config_path} metric: {pytorch_metric}')
    return pytorch_metric, dataset_type
|
|
|
|
|
|
def get_info_from_log_file(info_type: str, log_path: Path,
                           yaml_metric_key: str, logger: logging.Logger):
    """Get fps and metric result from log file.

    Args:
        info_type (str): Get which type of info: 'FPS' or 'metric'.
        log_path (Path): Logger path.
        yaml_metric_key (str): Name of metric from yaml metric_key. It
            selects both the log line and the rule used to parse it.
        logger (logger.Logger): Logger handler.

    Returns:
        float | str: For 'FPS', the mean of the FPS values found in the
            2nd to 11th log line, formatted with two decimals. For
            'metric', the parsed metric value. 'x' whenever the log file is
            missing, has fewer than two lines, ``info_type`` is neither
            'FPS' nor 'metric', or no value could be extracted.
    """
    if log_path.exists():
        with open(log_path, 'r') as f_log:
            lines = f_log.readlines()
    else:
        logger.warning(f'{log_path} do not exist !!!')
        lines = []

    # nothing can be parsed from an empty or single-line log
    if len(lines) <= 1:
        return 'x'

    if info_type == 'FPS':
        line_count = 0
        fps_sum = 0.00
        for line in lines[1:11]:
            if 'FPS' not in line:
                continue
            line_count += 1
            # the value is the second to last space separated token
            fps_sum += float(line.split(' ')[-2])
        if fps_sum > 0.00:
            return f'{fps_sum / line_count:.2f}'
        return 'x'

    if info_type != 'metric':
        return 'x'

    # index of the last non-blank line (a trailing blank line is skipped)
    if lines[-1] != '' and lines[-1] != '\n':
        line_index = -1
    else:
        line_index = -2

    if yaml_metric_key == 'mIoU':
        # mmseg: parsed straight from the final line. Use line_index so a
        # trailing blank line does not break the lookup.
        metric_line = lines[line_index]
        return float(metric_line.split('mIoU: ')[1].split(' ')[0])

    # how many lines above the final line the metric is logged
    line_offsets = {
        'accuracy_top-1': 1,  # mmcls
        'Eval-PSNR': 1,  # mmedit
        'AP': 9,  # mmpose
        'AR': 4,  # mmpose
    }
    # every other key (e.g. mmdet) is read from the final line
    metric_line = lines[line_index - line_offsets.get(yaml_metric_key, 0)]
    logger.info(f'Got metric_line = {metric_line}')

    metric_str = \
        metric_line.replace('\n', '').replace('\r', '').split(' - ')[-1]
    logger.info(f'Got metric_str = {metric_str}')
    logger.info(f'Got metric_info = {yaml_metric_key}')

    # SECURITY NOTE: eval() below runs on text taken from the log file.
    # That is only acceptable while the log is produced by our own test
    # run; do not feed logs from an untrusted source to this function.

    # 'x' stays the result whenever no rule yields a value
    metric = 'x'
    if 'accuracy_top' in metric_str:
        # mmcls
        metric = eval(metric_str.split(': ')[-1])
        if metric <= 1:
            metric *= 100
    elif yaml_metric_key in ['AP', 'AR']:
        # mmpose
        metric = eval(metric_str.split(': ')[-1])
    elif yaml_metric_key == '0_word_acc_ignore_case' or \
            yaml_metric_key == '0_hmean-iou:hmean':
        # mmocr
        evaluate_result = eval(metric_str)
        if not isinstance(evaluate_result, dict):
            logger.warning(f'Got error metric_dict = {metric_str}')
            return 'x'
        metric = evaluate_result.get(yaml_metric_key, 0.00)
        if yaml_metric_key == '0_word_acc_ignore_case':
            metric *= 100
    elif yaml_metric_key in ['Eval-PSNR', 'Eval-SSIM']:
        # mmedit
        metric = eval(metric_str.split(': ')[-1])
    elif 'bbox' in metric_str:
        # mmdet
        # NOTE(review): the string is split on ' ' and each token is split
        # on ' ' again, so a token that contains the key cannot be parsed
        # by float(). Presumably the separators differ in the real log
        # format - confirm against an actual mmdet log.
        for value in metric_str.split(' '):
            if yaml_metric_key + ':' in value:
                metric = float(value.split(' ')[-1]) * 100
                break

    return metric
|
|
|
|
|
|
def compare_metric(metric_value: float, metric_name: str, pytorch_metric: dict,
                   metric_info: dict):
    """Compare metric value with the pytorch metric value and the tolerance.

    The test passes when the PyTorch reference value lies within
    ``metric_value`` +/- the tolerance configured for ``metric_name``.

    Args:
        metric_value (float): Metric value, or 'x' if it is not available.
        metric_name (str): metric name.
        pytorch_metric (dict): Pytorch metric which get from metafile.
        metric_info (dict): Metric info from test yaml; the tolerance is
            read from ``metric_info[metric_name]['tolerance']`` and
            defaults to 0.

    Returns:
        Bool: If the test pass or not. False is also returned when either
            the metric value or the PyTorch reference value is missing.
    """
    if metric_value == 'x':
        return False

    metric_pytorch = pytorch_metric.get(str(metric_name))
    if metric_pytorch is None:
        # without a reference value there is nothing to compare against;
        # report a failure instead of raising a TypeError below
        return False

    tolerance_value = metric_info.get(metric_name, {}).get('tolerance', 0.00)
    lower_bound = metric_value - tolerance_value
    upper_bound = metric_value + tolerance_value
    return bool(lower_bound <= metric_pytorch <= upper_bound)
|
|
|
|
|
|
def get_fps_metric(shell_res: int, pytorch_metric: dict, metric_key: str,
                   yaml_metric_info_name: str, log_path: Path,
                   metrics_eval_list: dict, metric_info: dict,
                   logger: logging.Logger):
    """Read FPS and metric(s) of a backend test and judge the result.

    Some evaluations log two metrics in one run ('Top 1 Accuracy' with
    'Top 5 Accuracy', 'AP' with 'AR', 'PSNR' with 'SSIM'); for those the
    second metric is read and compared as well.

    Args:
        shell_res (int): Backend convert result: 0 is success.
        pytorch_metric (dict): Metric info of pytorch metafile.
        metric_key (str): Metric key used to locate the value in the log.
        yaml_metric_info_name (str): Name of metric info in test yaml.
        log_path (Path): Logger path.
        metrics_eval_list (dict): Metric list from test yaml.
        metric_info (dict): Metric info.
        logger (logger.Logger): Logger handler.

    Returns:
        Float: fps: FPS of the model ('x' if the test run failed).
        List: metric_list: metric result list.
        Bool: test_pass: If the test pass or not.
    """
    run_succeeded = shell_res == 0

    if run_succeeded:
        fps = get_info_from_log_file('FPS', log_path, metric_key, logger)
        metric_value = get_info_from_log_file('metric', log_path, metric_key,
                                              logger)
        logger.info(f'Got metric = {metric_value}')
    else:
        # nothing was logged, mark both values as unavailable
        fps = 'x'
        metric_value = 'x'

    if yaml_metric_info_name is None:
        logger.error(f'metrics_eval_list: {metrics_eval_list} '
                     'has not metric name')
    assert yaml_metric_info_name is not None

    metric_list = [{yaml_metric_info_name: metric_value}]
    test_pass = compare_metric(metric_value, yaml_metric_info_name,
                               pytorch_metric, metric_info)

    # same eval_name and multi metric output in one test
    paired_metric = {
        'Top 1 Accuracy': 'Top 5 Accuracy',  # mmcls
        'AP': 'AR',  # mmpose
        'PSNR': 'SSIM',  # mmedit
    }
    second_name = paired_metric.get(yaml_metric_info_name)
    if second_name is None:
        return fps, metric_list, test_pass

    second_key = metric_info.get(second_name).get('metric_key')
    if run_succeeded:
        second_value = get_info_from_log_file('metric', log_path, second_key,
                                              logger)
    else:
        second_value = 'x'
    metric_list.append({second_name: second_value})

    # the second metric can only turn a pass into a fail
    if test_pass:
        test_pass = compare_metric(second_value, second_name, pytorch_metric,
                                   metric_info)

    return fps, metric_list, test_pass
|
|
|
|
|
|
def get_backend_fps_metric(deploy_cfg_path: str, model_cfg_path: Path,
                           convert_checkpoint_path: str, device_type: str,
                           eval_name: str, logger: logging.Logger,
                           metrics_eval_list: dict, pytorch_metric: dict,
                           metric_info: dict, backend_name: str,
                           precision_type: str, metric_useless: set,
                           convert_result: bool, report_dict: dict,
                           infer_type: str, log_path: Path, dataset_type: str,
                           report_txt_path: Path, model_name: str):
    """Run ``tools/test.py`` on a converted model and report the result.

    Args:
        deploy_cfg_path (str): Deploy config path.
        model_cfg_path (Path): Model config path.
        convert_checkpoint_path (str): Converted checkpoint path.
        device_type (str): Device for converting.
        eval_name (str): Name of evaluation.
        logger (logging.Logger): Logger handler.
        metrics_eval_list (dict): Evaluation metric info dict.
        pytorch_metric (dict): Pytorch metric info dict get from metafile.
        metric_info (dict): Metric info from test yaml.
        backend_name (str): Backend name.
        precision_type (str): Precision type for evaluation.
        metric_useless (set): Useless metric for specific the model.
        convert_result (bool): Backend convert result.
        report_dict (dict): Report info dict.
        infer_type (str): Infer type.
        log_path (Path): Logger save path.
        dataset_type (str): Dataset type.
        report_txt_path (Path): report txt save path.
        model_name (str): Name of model in test yaml.
    """
    # the command is run through the shell, so it is one string
    cmd_pieces = [
        'python3 tools/test.py',
        f'{deploy_cfg_path}',
        f'{str(model_cfg_path.absolute())}',
        f'--model {convert_checkpoint_path}',
        f'--log2file "{log_path}"',
        '--speed-test',
        f'--device {device_type}',
    ]
    cmd_str = ' '.join(cmd_pieces) + ' '

    codebase_name = get_codebase(str(deploy_cfg_path)).value
    logger.info(f'Process cmd = {cmd_str}')

    # Test backend, executed from the parent of this file's directory
    repo_root = Path(__file__).absolute().parent.parent
    completed = subprocess.run(cmd_str, cwd=str(repo_root), shell=True)
    shell_res = completed.returncode
    logger.info(f'Got shell_res = {shell_res}')

    # first metric of the test yaml that belongs to this evaluation
    matched = next(
        ((key, value) for key, value in metric_info.items()
         if value.get('eval_name', '') == eval_name), None)
    if matched is None:
        metric_name = metric_key = task_name = ''
    else:
        metric_name, matched_info = matched
        metric_key = matched_info.get('metric_key', '')
        task_name = matched_info.get('task_name', '')

    logger.info(f'Got metric_name = {metric_name}')
    logger.info(f'Got metric_key = {metric_key}')

    fps, metric_list, test_pass = get_fps_metric(
        shell_res, pytorch_metric, metric_key, metric_name, log_path,
        metrics_eval_list, metric_info, logger)

    # update useless metric
    metric_list.extend({metric: '-'} for metric in metric_useless)

    if len(metrics_eval_list) > 1 and codebase_name == 'mmdet':
        # one model has more than one task, like Mask R-CNN
        metric_list.extend(
            {name: '-'} for name in pytorch_metric
            if name not in metric_useless and name != metric_name)

    # update the report
    update_report(
        report_dict=report_dict,
        model_name=model_name,
        model_config=str(model_cfg_path),
        task_name=task_name,
        checkpoint=convert_checkpoint_path,
        dataset=dataset_type,
        backend_name=backend_name,
        deploy_config=str(deploy_cfg_path),
        static_or_dynamic=infer_type,
        precision_type=precision_type,
        conversion_result=str(convert_result),
        fps=fps,
        metric_info=metric_list,
        test_pass=str(test_pass),
        report_txt_path=report_txt_path,
        codebase_name=codebase_name)
|
|
|
|
|
|
def get_precision_type(deploy_cfg_name: str):
    """Derive the backend precision type from the deploy config name.

    Args:
        deploy_cfg_name (str): Name of the deploy config.

    Returns:
        Str: precision_type: 'int8' or 'fp16' if the name contains that
            tag ('int8' is checked first), otherwise 'fp32'.
    """
    for precision_tag in ('int8', 'fp16'):
        if precision_tag in deploy_cfg_name:
            return precision_tag
    # no explicit tag in the name means full precision
    return 'fp32'
|
|
|
|
|
|
def replace_top_in_pipeline_json(backend_output_path: Path,
                                 logger: logging.Logger):
    """Replace `topk` with `num_classes` in `pipeline.json`.

    Only tasks named 'postprocess' whose params already contain 'topk' are
    changed; the file is rewritten in place.

    Args:
        backend_output_path (Path): Backend convert result path.
        logger (logger.Logger): Logger handler.
    """
    sdk_pipeline_json_path = backend_output_path.joinpath('pipeline.json')
    sdk_pipeline_json = mmcv.load(sdk_pipeline_json_path)

    pipeline_tasks = sdk_pipeline_json.get('pipeline', {}).get('tasks', [])
    for task in pipeline_tasks:
        if task.get('name', '') != 'postprocess':
            continue
        params = task.get('params', {})
        if 'topk' in params:
            # params is the dict stored in the loaded json, so this edits
            # the document that is dumped below
            params['topk'] = params.get('num_classes', 0)

    logger.info(f'sdk_pipeline_json = {sdk_pipeline_json}')

    mmcv.dump(
        sdk_pipeline_json, sdk_pipeline_json_path, sort_keys=False, indent=4)

    logger.info('replace done')
|
|
|
|
|
|
def gen_log_path(backend_output_path: Path, log_name: str):
    """Build the absolute log path and empty the log if it already exists.

    Args:
        backend_output_path (Path): Directory the log file lives in.
        log_name (str): File name of the log.

    Returns:
        Path: Absolute, resolved path of the log file. The file is not
            created when it does not exist yet.
    """
    log_path = (backend_output_path / log_name).absolute().resolve()
    if log_path.exists():
        # clear the log file
        log_path.write_text('')
    return log_path
|
|
|
|
|
|
def get_backend_result(pipeline_info: dict, model_cfg_path: Path,
|
|
checkpoint_path: Path, work_dir: Path, device_type: str,
|
|
pytorch_metric: dict, metric_info: dict,
|
|
report_dict: dict, test_type: str,
|
|
logger: logging.Logger, backend_file_name: Union[str,
|
|
list],
|
|
report_txt_path: Path, metafile_dataset: str,
|
|
model_name: str):
|
|
"""Convert model to onnx and then get metric.
|
|
|
|
Args:
|
|
pipeline_info (dict): Pipeline info of test yaml.
|
|
model_cfg_path (Path): Model config file path.
|
|
checkpoint_path (Path): Checkpoints path.
|
|
work_dir (Path): A working directory.
|
|
device_type (str): A string specifying device, defaults to 'cuda'.
|
|
pytorch_metric (dict): All pytorch metric info.
|
|
metric_info (dict): Metrics info.
|
|
report_dict (dict): Report info dict.
|
|
test_type (str): Test type. 'precision' or 'convert'.
|
|
logger (logging.Logger): Logger.
|
|
backend_file_name (str | list): backend file save name.
|
|
report_txt_path (Path): report txt path.
|
|
metafile_dataset (str): Dataset type get from metafile.
|
|
model_name (str): Name of model in test yaml.
|
|
"""
|
|
# get backend_test info
|
|
backend_test = pipeline_info.get('backend_test', False)
|
|
|
|
# get convert_image info
|
|
convert_image_info = pipeline_info.get('convert_image', None)
|
|
if convert_image_info is not None:
|
|
input_img_path = \
|
|
convert_image_info.get('input_img', './tests/data/tiger.jpeg')
|
|
test_img_path = convert_image_info.get('test_img', None)
|
|
else:
|
|
input_img_path = './tests/data/tiger.jpeg'
|
|
test_img_path = None
|
|
# get sdk_cfg info
|
|
sdk_config = pipeline_info.get('sdk_config', None)
|
|
if sdk_config is not None:
|
|
sdk_config = Path(sdk_config)
|
|
|
|
# Overwrite metric tolerance
|
|
metric_tolerance = pipeline_info.get('metric_tolerance', None)
|
|
if metric_tolerance is not None:
|
|
for metric, new_tolerance in metric_tolerance.items():
|
|
if metric not in metric_info:
|
|
logger.debug(f'{metric} not in {metric_info},'
|
|
'skip compare it...')
|
|
continue
|
|
if new_tolerance is None:
|
|
logger.debug('new_tolerance is None, skip it ...')
|
|
continue
|
|
metric_info[metric]['tolerance'] = new_tolerance
|
|
if backend_test is False and sdk_config is None:
|
|
test_type = 'convert'
|
|
|
|
metric_name_list = [str(metric) for metric in pytorch_metric]
|
|
assert len(metric_name_list) > 0
|
|
metric_all_list = [str(metric) for metric in metric_info]
|
|
metric_useless = set(metric_all_list) - set(metric_name_list)
|
|
|
|
deploy_cfg_path = Path(pipeline_info.get('deploy_config'))
|
|
backend_name = str(get_backend(str(deploy_cfg_path)).name).lower()
|
|
|
|
# change device_type for special case
|
|
if backend_name in ['ncnn', 'openvino']:
|
|
device_type = 'cpu'
|
|
elif backend_name == 'onnxruntime' and device_type != 'cpu':
|
|
import onnxruntime as ort
|
|
if ort.get_device() != 'GPU':
|
|
device_type = 'cpu'
|
|
logger.warning('Device type is forced to cpu '
|
|
'since onnxruntime-gpu is not installed')
|
|
|
|
infer_type = \
|
|
'dynamic' if is_dynamic_shape(str(deploy_cfg_path)) else 'static'
|
|
|
|
precision_type = get_precision_type(deploy_cfg_path.name)
|
|
codebase_name = get_codebase(str(deploy_cfg_path)).value
|
|
|
|
backend_output_path = Path(work_dir). \
|
|
joinpath(Path(checkpoint_path).parent.parent.name,
|
|
Path(checkpoint_path).parent.name,
|
|
backend_name,
|
|
infer_type,
|
|
precision_type,
|
|
Path(checkpoint_path).stem)
|
|
backend_output_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
# convert cmd string
|
|
cmd_str = 'python3 ./tools/deploy.py ' \
|
|
f'{str(deploy_cfg_path.absolute().resolve())} ' \
|
|
f'{str(model_cfg_path.absolute().resolve())} ' \
|
|
f'"{str(checkpoint_path.absolute().resolve())}" ' \
|
|
f'"{input_img_path}" ' \
|
|
f'--work-dir "{backend_output_path}" ' \
|
|
f'--device {device_type} ' \
|
|
'--log-level INFO'
|
|
|
|
if sdk_config is not None and test_type == 'precision':
|
|
cmd_str += ' --dump-info'
|
|
|
|
if test_img_path is not None:
|
|
cmd_str += f' --test-img {test_img_path}'
|
|
|
|
if precision_type == 'int8':
|
|
calib_dataset_cfg = pipeline_info.get('calib_dataset_cfg', None)
|
|
if calib_dataset_cfg is not None:
|
|
cmd_str += f' --calib-dataset-cfg {calib_dataset_cfg}'
|
|
|
|
logger.info(f'Process cmd = {cmd_str}')
|
|
|
|
convert_result = False
|
|
convert_log_path = backend_output_path.joinpath('convert_log.log')
|
|
logger.info(f'Logging conversion log to {convert_log_path} ...')
|
|
file_handler = open(convert_log_path, 'w', encoding='utf-8')
|
|
try:
|
|
# Convert the model to specific backend
|
|
process_res = subprocess.Popen(
|
|
cmd_str,
|
|
cwd=str(Path(__file__).absolute().parent.parent),
|
|
shell=True,
|
|
stdout=file_handler,
|
|
stderr=file_handler)
|
|
process_res.wait()
|
|
logger.info(f'Got shell_res = {process_res.returncode}')
|
|
|
|
# check if converted successes or not.
|
|
if process_res.returncode == 0:
|
|
convert_result = True
|
|
else:
|
|
convert_result = False
|
|
|
|
except Exception as e:
|
|
print(f'process convert error: {e}')
|
|
finally:
|
|
file_handler.close()
|
|
|
|
logger.info(f'Got convert_result = {convert_result}')
|
|
|
|
if isinstance(backend_file_name, list):
|
|
convert_checkpoint_path = ''
|
|
for backend_file in backend_file_name:
|
|
backend_path = backend_output_path.joinpath(backend_file)
|
|
backend_path = str(backend_path.absolute().resolve())
|
|
convert_checkpoint_path += f'{str(backend_path)} '
|
|
else:
|
|
convert_checkpoint_path = \
|
|
str(backend_output_path.joinpath(backend_file_name))
|
|
# load deploy_cfg
|
|
deploy_cfg, model_cfg = \
|
|
load_config(str(deploy_cfg_path),
|
|
str(model_cfg_path.absolute()))
|
|
# get dataset type
|
|
if 'dataset_type' in model_cfg:
|
|
dataset_type = \
|
|
str(model_cfg.dataset_type).upper().replace('DATASET', '')
|
|
else:
|
|
dataset_type = metafile_dataset
|
|
# Test the model
|
|
if convert_result and test_type == 'precision':
|
|
# Get evaluation metric from model config
|
|
if codebase_name == 'mmseg':
|
|
metrics_eval_list = model_cfg.val_evaluator.iou_metrics
|
|
else:
|
|
metrics_eval_list = model_cfg.test_evaluator.get('metric', [])
|
|
if isinstance(metrics_eval_list, str):
|
|
# some config is using str only
|
|
metrics_eval_list = [metrics_eval_list]
|
|
|
|
# assert len(metrics_eval_list) > 0
|
|
logger.info(f'Got metrics_eval_list = {metrics_eval_list}')
|
|
if len(metrics_eval_list) == 0 and codebase_name == 'mmedit':
|
|
metrics_eval_list = ['PSNR']
|
|
|
|
# test the model metric
|
|
for metric_name in metrics_eval_list:
|
|
if backend_test:
|
|
log_path = \
|
|
gen_log_path(backend_output_path, 'backend_test.log')
|
|
get_backend_fps_metric(
|
|
deploy_cfg_path=str(deploy_cfg_path),
|
|
model_cfg_path=model_cfg_path,
|
|
convert_checkpoint_path=convert_checkpoint_path,
|
|
device_type=device_type,
|
|
eval_name=metric_name,
|
|
logger=logger,
|
|
metrics_eval_list=metrics_eval_list,
|
|
pytorch_metric=pytorch_metric,
|
|
metric_info=metric_info,
|
|
backend_name=backend_name,
|
|
precision_type=precision_type,
|
|
metric_useless=metric_useless,
|
|
convert_result=convert_result,
|
|
report_dict=report_dict,
|
|
infer_type=infer_type,
|
|
log_path=log_path,
|
|
dataset_type=dataset_type,
|
|
report_txt_path=report_txt_path,
|
|
model_name=model_name)
|
|
|
|
if sdk_config is not None:
|
|
|
|
if codebase_name == 'mmcls':
|
|
replace_top_in_pipeline_json(backend_output_path, logger)
|
|
|
|
log_path = gen_log_path(backend_output_path, 'sdk_test.log')
|
|
# sdk test
|
|
get_backend_fps_metric(
|
|
deploy_cfg_path=str(sdk_config),
|
|
model_cfg_path=model_cfg_path,
|
|
convert_checkpoint_path=str(backend_output_path),
|
|
device_type=device_type,
|
|
eval_name=metric_name,
|
|
logger=logger,
|
|
metrics_eval_list=metrics_eval_list,
|
|
pytorch_metric=pytorch_metric,
|
|
metric_info=metric_info,
|
|
backend_name=f'SDK-{backend_name}',
|
|
precision_type=precision_type,
|
|
metric_useless=metric_useless,
|
|
convert_result=convert_result,
|
|
report_dict=report_dict,
|
|
infer_type=infer_type,
|
|
log_path=log_path,
|
|
dataset_type=dataset_type,
|
|
report_txt_path=report_txt_path,
|
|
model_name=model_name)
|
|
else:
|
|
logger.info('Only test convert, saving to report...')
|
|
metric_list = []
|
|
fps = '-'
|
|
|
|
task_name = ''
|
|
for metric in metric_name_list:
|
|
metric_list.append({metric: '-'})
|
|
metric_task_name = metric_info.get(metric, {}).get('task_name', '')
|
|
if metric_task_name in task_name:
|
|
logger.debug('metric_task_name exist, skip for adding it...')
|
|
continue
|
|
task_name += f'{metric_task_name} | '
|
|
if ' | ' == task_name[-3:]:
|
|
task_name = task_name[:-3]
|
|
test_pass = True if convert_result else False
|
|
|
|
# update useless metric
|
|
for metric in metric_useless:
|
|
metric_list.append({metric: '-'})
|
|
|
|
if convert_result:
|
|
report_checkpoint = convert_checkpoint_path
|
|
else:
|
|
report_checkpoint = str(checkpoint_path)
|
|
|
|
# update the report
|
|
update_report(
|
|
report_dict=report_dict,
|
|
model_name=model_name,
|
|
model_config=str(model_cfg_path),
|
|
task_name=task_name,
|
|
checkpoint=report_checkpoint,
|
|
dataset=dataset_type,
|
|
backend_name=backend_name,
|
|
deploy_config=str(deploy_cfg_path),
|
|
static_or_dynamic=infer_type,
|
|
precision_type=precision_type,
|
|
conversion_result=str(convert_result),
|
|
fps=fps,
|
|
metric_info=metric_list,
|
|
test_pass=str(test_pass),
|
|
report_txt_path=report_txt_path,
|
|
codebase_name=codebase_name)
|
|
|
|
|
|
def save_report(report_info: dict, report_save_path: Path,
                logger: logging.Logger):
    """Save the regression test report to an excel file.

    The report is written on a best-effort basis: when building or writing
    the ``pandas.DataFrame`` raises ``ValueError`` (e.g. the columns of
    ``report_info`` have different lengths), the offending ``report_info``
    is logged and the function returns without claiming success.

    Args:
        report_info (dict): Report info dict.
        report_save_path (Path): Report save path.
        logger (logging.Logger): Logger.
    """
    report_location = report_save_path.absolute().resolve()
    logger.info('Saving regression test report to '
                f'{report_location}, pls wait...')
    try:
        df = pd.DataFrame(report_info)
        df.to_excel(report_save_path)
    except ValueError:
        # Do not fall through to the "Saved ..." message below: nothing was
        # written, so reporting success would be misleading.
        logger.error(f'Got error report_info = {report_info}')
        return

    logger.info(f'Saved regression test report to {report_location}.')
|
|
|
|
|
|
def _filter_string(inputs):
|
|
"""Remove non alpha&number character from input string.
|
|
|
|
Args:
|
|
inputs(str): Input string.
|
|
|
|
Returns:
|
|
str: Output of only alpha&number string.
|
|
"""
|
|
outputs = ''.join([i.lower() for i in inputs if i.isalnum()])
|
|
return outputs
|
|
|
|
|
|
def main():
    """Run the regression test for every requested codebase.

    For each ``./tests/regression/<codebase>.yml`` selected by
    ``--codebase`` the yaml is loaded, an empty report (dict plus a txt file
    holding only the header line) is prepared, and every model / model
    config / pipeline combination that passes the ``--models`` and
    ``--backends`` filters is handed to ``get_pytorch_result`` and
    ``get_backend_result``. A per-codebase excel report is saved when its
    ``'Model'`` column is non-empty, and all reports are merged at the end.

    Raises:
        TypeError: If the backend list is not a list.
        ValueError: If ``--work-dir`` or ``--checkpoint-dir`` contains a
            space.
        FileNotFoundError: If a regression yaml, a model config or a
            checkpoint file does not exist.
    """
    args = parse_args()
    set_start_method('spawn')
    logger = get_root_logger(log_level=args.log_level)

    test_type = 'precision' if args.performance else 'convert'
    logger.info(f'Processing regression test in {test_type} mode.')

    # File name(s) of the converted model for each backend; forwarded to
    # get_backend_result as ``backend_file_name``.
    backend_file_info = {
        'onnxruntime': 'end2end.onnx',
        'tensorrt': 'end2end.engine',
        'openvino': 'end2end.xml',
        'ncnn': ['end2end.param', 'end2end.bin'],
        'pplnn': ['end2end.onnx', 'end2end.json'],
        'torchscript': 'end2end.pt'
    }

    backend_list = args.backends
    if backend_list is None:
        backend_list = [
            'onnxruntime', 'tensorrt', 'openvino', 'ncnn', 'pplnn',
            'torchscript'
        ]
    # Explicit checks instead of ``assert``: assertions are removed when
    # Python runs with -O, which would silently disable the validation.
    if not isinstance(backend_list, list):
        raise TypeError('backend list must be a list, '
                        f'got {type(backend_list).__name__}')
    logger.info(f'Regression test backend list = {backend_list}')

    if args.models is None:
        logger.info('Regression test for all models in test yaml.')
    else:
        # Normalise the requested names the same way as the yaml model
        # names below, so both sides of the comparison are filtered.
        args.models = tuple(_filter_string(s) for s in args.models)
        logger.info(f'Regression test models list = {args.models}')

    if ' ' in args.work_dir:
        raise ValueError(
            f'--work-dir must not contain spaces, got {args.work_dir!r}')
    if ' ' in args.checkpoint_dir:
        raise ValueError('--checkpoint-dir must not contain spaces, '
                         f'got {args.checkpoint_dir!r}')

    work_dir = Path(args.work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)

    deploy_yaml_list = [
        f'./tests/regression/{codebase}.yml' for codebase in args.codebase
    ]

    for deploy_yaml in deploy_yaml_list:

        if not Path(deploy_yaml).exists():
            raise FileNotFoundError(f'deploy_yaml {deploy_yaml} not found, '
                                    'please check !')

        # NOTE(review): FullLoader is used on the repository's own test
        # yaml; switch to yaml.safe_load if these files can ever come from
        # an untrusted source.
        with open(deploy_yaml) as f:
            yaml_info = yaml.load(f, Loader=yaml.FullLoader)

        report_save_path = \
            work_dir.joinpath(Path(deploy_yaml).stem + '_report.xlsx')
        report_txt_path = report_save_path.with_suffix('.txt')

        # One list per report column; the metric columns and 'Test Pass'
        # are appended below so that 'Test Pass' is the last column.
        report_dict = {
            'Model': [],
            'Model Config': [],
            'Task': [],
            'Checkpoint': [],
            'Dataset': [],
            'Backend': [],
            'Deploy Config': [],
            'Static or Dynamic': [],
            'Precision Type': [],
            'Conversion Result': [],
            # 'FPS': []
        }

        global_info = yaml_info.get('globals')
        metric_info = global_info.get('metric_info', {})
        for metric_name in metric_info:
            report_dict.update({metric_name: []})
        report_dict.update({'Test Pass': []})

        global_info.update({'checkpoint_dir': args.checkpoint_dir})
        # The codebase name is the part of the yaml stem before the first
        # underscore.
        global_info.update(
            {'codebase_name': Path(deploy_yaml).stem.split('_')[0]})

        # Opening with 'w' clears the report tmp file; only the
        # comma-separated header line is written here.
        with open(report_txt_path, 'w') as f_report:
            title_str = ','.join(f'{key}' for key in report_dict) + '\n'
            f_report.write(title_str)

        models_info = yaml_info.get('models')
        for models in models_info:
            model_name_origin = models.get('name', 'model')
            model_name_new = _filter_string(model_name_origin)
            if 'model_configs' not in models:
                logger.warning('Can not find field "model_configs", '
                               f'skipping {model_name_origin}...')
                continue

            if args.models is not None and model_name_new not in args.models:
                logger.info(
                    f'Test specific model mode, skip {model_name_origin}...')
                continue

            model_metafile_info, checkpoint_save_dir, codebase_dir = \
                get_model_metafile_info(global_info, models, logger)
            for model_config in model_metafile_info:
                logger.info(f'Processing test for {model_config}...')

                # Get backends info
                pipelines_info = models.get('pipelines', None)
                if pipelines_info is None:
                    logger.warning('pipelines_info is None, skip it...')
                    continue

                # Get model config path
                model_cfg_path = Path(codebase_dir).joinpath(model_config)
                if not model_cfg_path.exists():
                    raise FileNotFoundError(
                        f'model config {model_cfg_path} not found, '
                        'please check !')

                # Get checkpoint path
                checkpoint_name = Path(
                    model_metafile_info.get(model_config).get('Weights')).name

                checkpoint_path = Path(checkpoint_save_dir, checkpoint_name)
                if not checkpoint_path.exists():
                    raise FileNotFoundError(
                        f'checkpoint {checkpoint_path} not found, '
                        'please check !')

                # Get pytorch from metafile.yml
                pytorch_metric, metafile_dataset = get_pytorch_result(
                    model_name_origin, model_metafile_info, checkpoint_path,
                    model_cfg_path, model_config, metric_info, report_dict,
                    logger, report_txt_path, global_info.get('codebase_name'))
                for pipeline in pipelines_info:
                    deploy_config = pipeline.get('deploy_config')
                    backend_name = get_backend(deploy_config).name.lower()
                    if backend_name not in backend_list:
                        logger.warning(f'backend_name ({backend_name}) not '
                                       f'in {backend_list}, skip it...')
                        continue

                    backend_file_name = \
                        backend_file_info.get(backend_name, None)
                    if backend_file_name is None:
                        logger.warning('backend_file_name is None, '
                                       'skip it...')
                        continue

                    get_backend_result(pipeline, model_cfg_path,
                                       checkpoint_path, work_dir, args.device,
                                       pytorch_metric, metric_info,
                                       report_dict, test_type, logger,
                                       backend_file_name, report_txt_path,
                                       metafile_dataset, model_name_origin)
        if report_dict['Model']:
            save_report(report_dict, report_save_path, logger)
        else:
            logger.info(f'No model for {deploy_yaml}, not saving report.')

    # merge report
    merge_report(str(work_dir), logger)

    logger.info('All done.')
|
|
|
|
|
|
# Run the regression test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|