fix regression test for 2.0 ()

* fix regression test configs

fix mmedit and mmocr configs

update

update apcnet config of mmseg

catch error when checkpoint download fails

log the error message

* update regression test for Windows

* update yml configs
pull/1169/head
RunningLeon 2022-10-08 19:14:33 +08:00 committed by GitHub
parent 2e459dba6d
commit 51f4e65185
8 changed files with 253 additions and 544 deletions
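
In short: the regression YAMLs switch to the 2.x evaluator key style (`accuracy/top1`, `coco/bbox_mAP`, `Set5/PSNR`, `icdar/hmean`, ...), test commands are run through a shared `run_cmd` helper that redirects output to a log file, and backend metrics are read from the test job's output json and compared against the PyTorch metafile values within a per-metric tolerance. Below is a minimal sketch of that comparison rule as implemented in `get_fps_metric` further down; the standalone function name and signature are chosen here for illustration only.

```python
def check_backend_metrics(pytorch_metric: dict, backend_results: dict,
                          metric_info: dict):
    """Sketch of the pass/fail rule used by get_fps_metric in this patch."""
    output, flags = {}, {}
    for name, torch_value in pytorch_metric.items():
        key = metric_info[name]['metric_key']            # e.g. 'accuracy/top1'
        tolerance = metric_info[name]['tolerance']       # allowed drop vs. the metafile value
        scale = metric_info[name].get('multi_value', 1.0)  # e.g. 100 for COCO mAP in [0, 1]
        output[name] = '-'
        ok = True
        if key in backend_results:
            value = backend_results[key] * scale
            output[name] = value
            ok = value >= torch_value - tolerance
        flags[name] = ok
    # a model passes only if every metric that was actually compared is within tolerance
    test_pass = bool(flags) and all(flags.values())
    return output, test_pass
```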

View File

@@ -3,25 +3,15 @@ globals:
checkpoint_force_download: False
images:
img_snake: &img_snake ../mmclassification/demo/demo.JPEG
img_bird: &img_bird ../mmclassification/demo/bird.JPEG
img_cat_dog: &img_cat_dog ../mmclassification/demo/cat-dog.png
img_dog: &img_dog ../mmclassification/demo/dog.jpg
img_color_cat: &img_color_cat ../mmclassification/tests/data/color.jpg
img_gray_cat: &img_gray_cat ../mmclassification/tests/data/gray.jpg
metric_info: &metric_info
Top 1 Accuracy: # named after metafile.Results.Metrics
eval_name: accuracy # test.py --metrics args
metric_key: accuracy_top-1 # eval Dict key name
metric_key: accuracy/top1 # key name in output json
tolerance: 1 # metric ±n%
task_name: Image Classification # metafile.Results.Task
dataset: ImageNet-1k # metafile.Results.Dataset
Top 5 Accuracy:
eval_name: accuracy
metric_key: accuracy_top-5
metric_key: accuracy/top5
tolerance: 1 # metric ±n%
task_name: Image Classification
dataset: ImageNet-1k
convert_image: &convert_image
input_img: *img_snake
test_img: *img_color_cat

View File

@@ -4,26 +4,18 @@ globals:
images:
input_img: &input_img ../mmdetection/demo/demo.jpg
test_img: &test_img ./tests/data/tiger.jpeg
img_blank: &img_blank
metric_info: &metric_info
box AP: # named after metafile.Results.Metrics
eval_name: bbox # test.py --metrics args
metric_key: bbox_mAP # eval OrderedDict key name
metric_key: coco/bbox_mAP # eval OrderedDict key name
tolerance: 0.2 # metric ±n%
task_name: Object Detection # metafile.Results.Task
dataset: COCO # metafile.Results.Dataset
multi_value: 100
mask AP:
eval_name: segm
metric_key: segm_mAP
metric_key: coco/segm_mAP
tolerance: 1 # metric ±n%
task_name: Instance Segmentation
dataset: COCO
multi_value: 100
PQ:
eval_name: proposal
metric_key: '?'
tolerance: 0.1 # metric ±n%
task_name: Panoptic Segmentation
dataset: COCO
convert_image: &convert_image
input_img: *input_img
test_img: *test_img
@@ -303,7 +295,7 @@ models:
- name: Mask R-CNN
metafile: configs/mask_rcnn/metafile.yml
model_configs:
- configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py
- configs/mask_rcnn/mask-rcnn_r50_fpn_1x_coco.py
pipelines:
- *pipeline_seg_ts_fp32
- *pipeline_seg_ort_dynamic_fp32

View File

@@ -5,42 +5,18 @@ globals:
img_face: &img_face ../mmediting/tests/data/image/face/000001.png
img_bg: &img_bg ../mmediting/tests/data/image/gt/baboon.png
metric_info: &metric_info
DIV2K PSNR: # named after metafile.Results.Metrics
eval_name: PSNR # test.py --metrics args
metric_key: DIV2K PSNR # eval log key name
tolerance: 4 # metric ±n%
task_name: Srcnn # metafile.Results.Task
dataset: DIV2K # metafile.Results.Dataset
DIV2K SSIM:
eval_name: SSIM
metric_key: DIV2K SSIM
tolerance: 0.02 # metric ±n
task_name: Srcnn
dataset: DIV2K
Set14 PSNR: # named after metafile.Results.Metrics
eval_name: PSNR # test.py --metrics args
metric_key: Set14 PSNR # eval log key name
metric_key: Set14/PSNR # eval log key name
tolerance: 4 # metric ±n%
task_name: Srcnn # metafile.Results.Task
dataset: Set14 # metafile.Results.Dataset
Set14 SSIM:
eval_name: SSIM
metric_key: Set14 SSIM
metric_key: Set14/SSIM
tolerance: 0.02 # metric ±n
task_name: Srcnn
dataset: Set14
Set5 PSNR: # named after metafile.Results.Metrics
eval_name: PSNR # test.py --metrics args
metric_key: Set5 PSNR # eval log key name
metric_key: Set5/PSNR # eval log key name
tolerance: 4 # metric ±n%
task_name: Srcnn # metafile.Results.Task
dataset: Set5 # metafile.Results.Dataset
Set5 SSIM:
eval_name: SSIM
metric_key: Set5 SSIM
metric_key: Set5/SSIM
tolerance: 0.02 # metric ±n
task_name: Srcnn
dataset: Set5
convert_image: &convert_image
input_img: *img_face
test_img: *img_bg

View File

@@ -4,21 +4,14 @@ globals:
images:
img_densetext_det: &img_densetext_det ../mmocr/demo/demo_densetext_det.jpg
img_demo_text_det: &img_demo_text_det ../mmocr/demo/demo_text_det.jpg
img_demo_text_ocr: &img_demo_text_ocr ../mmocr/demo/demo_text_ocr.jpg
img_demo_text_recog: &img_demo_text_recog ../mmocr/demo/demo_text_recog.jpg
metric_info: &metric_info
hmean-iou: # named after metafile.Results.Metrics
eval_name: hmean-iou # test.py --metrics args
metric_key: 0_hmean-iou:hmean # eval key name
metric_key: icdar/hmean # eval key name
tolerance: 0.15 # metric ±n%
task_name: Text Detection # metafile.Results.Task
dataset: ICDAR2015 # metafile.Results.Dataset
word_acc:
eval_name: acc
metric_key: 0_word_acc_ignore_case
metric_key: IIIT5K/recog/word_acc_ignore_case_symbol
tolerance: 0.05 # metric ±n%
task_name: Text Recognition
dataset: IIIT5K
convert_image_det: &convert_image_det
input_img: *img_densetext_det
test_img: *img_demo_text_det

View File

@@ -6,17 +6,11 @@ globals:
img_human_pose_256x192: &img_human_pose_256x192 ./demo/resources/human-pose.jpg
metric_info: &metric_info
AP: # named after metafile.Results.Metrics
eval_name: mAP # test.py --metrics args
metric_key: AP # eval key name
tolerance: 0.10 # metric ±n
task_name: Body 2D Keypoint # metafile.Results.Task
dataset: COCO # metafile.Results.Dataset
metric_key: coco/AP # eval key name
tolerance: 0.02 # metric ±n
AR:
eval_name: mAP
metric_key: AR
tolerance: 0.08 # metric ±n
task_name: Body 2D Keypoint
dataset: COCO
metric_key: coco/AR
tolerance: 0.02 # metric ±n
convert_image: &convert_image
input_img: *img_human_pose
test_img: *img_human_pose_256x192

View File

@@ -6,11 +6,8 @@ globals:
img_dota_demo: &img_dota_demo ../mmrotate/demo/dota_demo.jpg
metric_info: &metric_info
mAP: # named after metafile.Results.Metrics
eval_name: mAP # test.py --metrics args
metric_key: AP # eval key name
tolerance: 0.10 # metric ±n%
task_name: Oriented Object Detection # metafile.Results.Task
dataset: DOTAv1.0 # metafile.Results.Dataset
convert_image_det: &convert_image_det
input_img: *img_demo
test_img: *img_dota_demo

View File

@@ -7,11 +7,8 @@ globals:
metric_info: &metric_info
mIoU: # named after metafile.Results.Metrics
eval_name: mIoU # test.py --metrics args
metric_key: mIoU # eval OrderedDict key name
tolerance: 5 # metric ±n%
task_name: Semantic Segmentation # metafile.Results.Task
dataset: [Cityscapes, ADE20K] # metafile.Results.Dataset
tolerance: 1 # metric ±n%
convert_image: &convert_image
input_img: *img_leftImg8bit
test_img: *img_loveda_0
@@ -208,8 +205,8 @@ models:
model_configs:
- configs/apcnet/apcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16
- *pipeline_ort_static_fp32
- *pipeline_trt_static_fp16
- *pipeline_ncnn_static_fp32
- *pipeline_ts_fp32
@@ -272,7 +269,7 @@ models:
- configs/erfnet/erfnet_fcn_4xb4-160k_cityscapes-512x1024.py
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16
- *pipeline_trt_dynamic_fp32
- *pipeline_ncnn_static_fp32
- *pipeline_openvino_dynamic_fp32
- *pipeline_ts_fp32
@@ -370,7 +367,6 @@ models:
metafile: configs/segmenter/segmenter.yml
model_configs:
- configs/segmenter/segmenter_vit-s_fcn_8xb1-160k_ade20k-512x512.py
- configs/segmenter/segmenter_vit-s_mask_8xb1-160k_ade20k-512x512.py
pipelines:
- *pipeline_ort_static_fp32_512x512
- *pipeline_trt_static_fp32_512x512

View File

@@ -1,9 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Union
from typing import List, Union
import mmengine
import openpyxl
@@ -14,7 +17,7 @@ from torch.multiprocessing import set_start_method
import mmdeploy.version
from mmdeploy.utils import (get_backend, get_codebase, get_root_logger,
is_dynamic_shape, load_config)
is_dynamic_shape)
def parse_args():
@@ -297,84 +300,33 @@ def get_pytorch_result(model_name: str, meta_info: dict, checkpoint_path: Path,
return {}
# get metric
model_info = meta_info.get(model_config_name, None)
metafile_metric_info = model_info.get('Results', None)
metric_list = []
model_info = meta_info[model_config_name]
metafile_metric_info = model_info['Results']
pytorch_metric = dict()
dataset_type = ''
task_type = ''
# Get dataset
using_dataset = dict()
for _, v in test_yaml_metric_info.items():
if v.get('dataset') is None:
continue
dataset_list = v.get('dataset', [])
if not isinstance(dataset_list, list):
dataset_list = [dataset_list]
for metric_dataset in dataset_list:
dataset_tmp = using_dataset.get(metric_dataset, [])
if isinstance(v.get('task_name'), list):
tasklist = v.get('task_name')
for task in tasklist:
if task not in dataset_tmp:
dataset_tmp.append(task)
else:
if v.get('task_name') not in dataset_tmp:
dataset_tmp.append(v.get('task_name'))
using_dataset.update({metric_dataset: dataset_tmp})
using_dataset = set()
using_task = set()
# Get metrics info from metafile
for metafile_metric in metafile_metric_info:
pytorch_meta_metric = metafile_metric.get('Metrics')
dataset = metafile_metric.get('Dataset', '')
task_name = metafile_metric.get('Task', '')
if task_name == 'Restorers':
# mmedit
dataset = 'Set5'
pytorch_metric.update(metafile_metric['Metrics'])
dataset = metafile_metric['Dataset']
task_name = metafile_metric['Task']
if task_name not in using_task:
using_task.add(task_name)
if dataset not in using_dataset:
using_dataset.add(dataset)
if (len(using_dataset) > 1) and (dataset not in using_dataset):
logger.info(f'dataset not in {using_dataset}, skip it...')
continue
dataset_type += f'{dataset} | '
if task_name not in using_dataset.get(dataset, []):
# only add the metric with the correct dataset
logger.info(f'task_name ({task_name}) is not in'
f'{using_dataset.get(dataset, [])}, skip it...')
continue
task_type += f'{task_name} | '
# remove some metric which not in metric_info from test yaml
for k, v in pytorch_meta_metric.items():
if k not in test_yaml_metric_info and \
'Restorers' not in task_type:
continue
if 'Restorers' in task_type and k not in dataset_type:
# mmedit
continue
if isinstance(v, dict):
# mmedit
for sub_k, sub_v in v.items():
use_metric = {sub_k: sub_v}
metric_list.append(use_metric)
pytorch_metric.update(use_metric)
else:
use_metric = {k: v}
metric_list.append(use_metric)
pytorch_metric.update(use_metric)
dataset_type = dataset_type[:-3].upper() # remove the final ' | '
task_type = task_type[:-3] # remove the final ' | '
# update useless metric
metric_all_list = [str(metric) for metric in test_yaml_metric_info]
metric_useless = set(metric_all_list) - set(
[str(metric) for metric in pytorch_metric])
for metric in metric_useless:
metric_list.append({metric: '-'})
dataset_type = '|'.join(list(using_dataset))
task_type = '|'.join(list(using_task))
metric_list = []
for metric in test_yaml_metric_info:
value = '-'
if metric in pytorch_metric:
value = pytorch_metric[metric]
metric_list.append({metric: value})
valid_pytorch_metric = {
k: v
for k, v in pytorch_metric.items() if k in test_yaml_metric_info
}
# get pytorch fps value
fps_info = model_info.get('Metadata').get('inference time (ms/im)')
@@ -407,232 +359,97 @@ def get_pytorch_result(model_name: str, meta_info: dict, checkpoint_path: Path,
report_txt_path=report_txt_path,
codebase_name=codebase_name)
logger.info(f'Got {model_config_path} metric: {pytorch_metric}')
return pytorch_metric, dataset_type
logger.info(f'Got {model_config_path} metric: {valid_pytorch_metric}')
dataset_info = dict(dataset=dataset_type, task=task_type)
return valid_pytorch_metric, dataset_info
def get_info_from_log_file(info_type: str, log_path: Path,
yaml_metric_key: str, logger: logging.Logger):
"""Get fps and metric result from log file.
def parse_test_log(work_dir: str) -> dict:
"""Parse metrics result from output json file.
Args:
info_type (str): Get which type of info: 'FPS' or 'metric'.
log_path (Path): Logger path.
yaml_metric_key (str): Name of metric from yaml metric_key.
logger (logger.Logger): Logger handler.
work_dir (str): Work directory that contains the output json file.
Returns:
Float: Info value which get from logger file.
dict: metric results
"""
if log_path.exists():
with open(log_path, 'r') as f_log:
lines = f_log.readlines()
logger = get_root_logger()
json_files = glob.glob(os.path.join(work_dir, '*', '*.json'))
json_path = None
newest_date = None
# filter json and get latest json file
for f in json_files:
fname = os.path.split(f)[1].strip('.json')
try:
date = datetime.strptime(fname, '%Y%m%d_%H%M%S')
if newest_date is None:
newest_date = date
json_path = f
elif date > newest_date:
newest_date = date
json_path = f
except Exception:
pass
if (not os.path.exists(work_dir)) or json_path is None:
logger.warning(f'No json files found in {work_dir}')
result = {}
else:
logger.warning(f'{log_path} do not exist !!!')
lines = []
if info_type == 'FPS' and len(lines) > 1:
# Get FPS
line_count = 0
fps_sum = 0.00
fps_lines = lines[1:11]
for line in fps_lines:
if 'FPS' not in line:
continue
line_count += 1
fps_sum += float(line.split(' ')[-2])
if fps_sum > 0.00:
info_value = f'{fps_sum / line_count:.2f}'
else:
info_value = 'x'
elif info_type == 'metric' and len(lines) > 1:
# To calculate the final line index
if lines[-1] != '' and lines[-1] != '\n':
line_index = -1
else:
line_index = -2
if yaml_metric_key == 'mIoU':
metric_line = lines[-1]
info_value = metric_line.split('mIoU: ')[1].split(' ')[0]
info_value = float(info_value)
return info_value
elif yaml_metric_key in ['accuracy_top-1', 'Eval-PSNR']:
# info in last second line
# mmcls, mmeg, mmedit
metric_line = lines[line_index - 1]
elif yaml_metric_key == 'AP':
# info in last tenth line
# mmpose
metric_line = lines[line_index - 9]
elif yaml_metric_key == 'AR':
# info in last fifth line
# mmpose
metric_line = lines[line_index - 4]
else:
# info in final line
# mmdet
metric_line = lines[line_index]
logger.info(f'Got metric_line = {metric_line}')
metric_str = \
metric_line.replace('\n', '').replace('\r', '').split(' - ')[-1]
logger.info(f'Got metric_str = {metric_str}')
logger.info(f'Got metric_info = {yaml_metric_key}')
if 'accuracy_top' in metric_str:
# mmcls
metric = eval(metric_str.split(': ')[-1])
if metric <= 1:
metric *= 100
elif yaml_metric_key == 'mIoU' and '|' in metric_str:
# mmseg
metric = eval(metric_str.strip().split('|')[2])
if metric <= 1:
metric *= 100
elif yaml_metric_key in ['AP', 'AR']:
# mmpose
metric = eval(metric_str.split(': ')[-1])
elif yaml_metric_key == '0_word_acc_ignore_case' or \
yaml_metric_key == '0_hmean-iou:hmean':
# mmocr
evaluate_result = eval(metric_str)
if not isinstance(evaluate_result, dict):
logger.warning(f'Got error metric_dict = {metric_str}')
return 'x'
metric = evaluate_result.get(yaml_metric_key, 0.00)
if yaml_metric_key == '0_word_acc_ignore_case':
metric *= 100
elif 'PSNR' in yaml_metric_key or 'SSIM' in yaml_metric_key:
# mmedit
metric = eval(metric_str.split(': ')[-1])
elif 'bbox' in metric_str:
# mmdet
value_list = metric_str.split(' ')
for value in value_list:
if yaml_metric_key + ':' in value:
metric = float(value.split(' ')[-1]) * 100
break
else:
metric = 'x'
info_value = metric
else:
info_value = 'x'
return info_value
logger.info(f'Parse test result from {json_path}')
result = mmengine.load(json_path)
return result
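
The hunk above interleaves the removed `get_info_from_log_file` (which scraped metrics out of the text log) with the new `parse_test_log`. Read on its own, the new helper amounts to the condensed sketch below; it is not a verbatim copy of the patch (for example it uses `os.path.splitext` where the patch strips the `.json` suffix), but the behavior is the same: pick the newest timestamped json under the work dir and load it.

```python
import glob
import os
from datetime import datetime

import mmengine

from mmdeploy.utils import get_root_logger


def parse_test_log(work_dir: str) -> dict:
    """Condensed sketch: load the newest timestamped json under work_dir."""
    logger = get_root_logger()
    candidates = []
    for path in glob.glob(os.path.join(work_dir, '*', '*.json')):
        name = os.path.splitext(os.path.basename(path))[0]
        try:
            # tools/test.py names its result file after the run time,
            # e.g. 20221008_191433.json
            candidates.append((datetime.strptime(name, '%Y%m%d_%H%M%S'), path))
        except ValueError:
            continue
    if not candidates:
        logger.warning(f'No json files found in {work_dir}')
        return {}
    newest_path = max(candidates)[1]
    logger.info(f'Parse test result from {newest_path}')
    return mmengine.load(newest_path)
```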
def compare_metric(metric_value: float, metric_name: str, pytorch_metric: dict,
metric_info: dict):
"""Compare metric value with the pytorch metric value and the tolerance.
Args:
metric_value (float): Metric value.
metric_name (str): metric name.
pytorch_metric (dict): Pytorch metric which get from metafile.
metric_info (dict): Metric info from test yaml.
Returns:
Bool: If the test pass or not.
"""
if metric_value == 'x':
return False
metric_pytorch = pytorch_metric.get(str(metric_name))
tolerance_value = metric_info.get(metric_name, {}).get('tolerance', 0.00)
if (metric_value - tolerance_value) <= metric_pytorch <= \
(metric_value + tolerance_value):
test_pass = True
else:
test_pass = False
return test_pass
def get_fps_metric(shell_res: int, pytorch_metric: dict, metric_key: str,
yaml_metric_info_name: str, log_path: Path,
metrics_eval_list: dict, metric_info: dict,
logger: logging.Logger):
def get_fps_metric(shell_res: int, pytorch_metric: dict, metric_info: dict,
work_path: Path):
"""Get fps and metric.
Args:
shell_res (int): Backend convert result: 0 is success.
pytorch_metric (dict): Metric info of pytorch metafile.
metric_key (str):Metric info.
yaml_metric_info_name (str): Name of metric info in test yaml.
log_path (Path): Logger path.
metrics_eval_list (dict): Metric list from test yaml.
work_path (Path): Work directory that contains the backend test output json.
metric_info (dict): Metric info.
logger (logger.Logger): Logger handler.
Returns:
Float: fps: FPS of the model.
List: metric_list: metric result list.
Bool: test_pass: If the test pass or not.
"""
metric_list = []
# check if converted successes or not.
fps = 'x'
if shell_res != 0:
fps = 'x'
metric_value = 'x'
backend_results = {}
else:
# Got fps from log file
fps = get_info_from_log_file('FPS', log_path, metric_key, logger)
# logger.info(f'Got fps = {fps}')
backend_results = parse_test_log(work_path)
compare_results = {}
output_result = {}
for metric_name, metric_value in pytorch_metric.items():
metric_key = metric_info[metric_name]['metric_key']
tolerance = metric_info[metric_name]['tolerance']
multi_value = metric_info[metric_name].get('multi_value', 1.0)
compare_flag = True
output_result[metric_name] = '-'
if metric_key in backend_results:
backend_value = backend_results[metric_key] * multi_value
output_result[metric_name] = backend_value
if backend_value < metric_value - tolerance:
compare_flag = False
compare_results[metric_name] = compare_flag
# Got metric from log file
metric_value = get_info_from_log_file('metric', log_path, metric_key,
logger)
logger.info(f'Got metric = {metric_value}')
if yaml_metric_info_name is None:
logger.error(f'metrics_eval_list: {metrics_eval_list} '
'has not metric name')
assert yaml_metric_info_name is not None
metric_list.append({yaml_metric_info_name: metric_value})
test_pass = compare_metric(metric_value, yaml_metric_info_name,
pytorch_metric, metric_info)
# same eval_name and multi metric output in one test
if yaml_metric_info_name == 'Top 1 Accuracy':
# mmcls
yaml_metric_info_name = 'Top 5 Accuracy'
second_get_metric = True
elif yaml_metric_info_name == 'AP':
# mmpose
yaml_metric_info_name = 'AR'
second_get_metric = True
elif yaml_metric_info_name == 'PSNR':
# mmedit
yaml_metric_info_name = 'SSIM'
second_get_metric = True
if len(compare_results):
test_pass = all(list(compare_results.values()))
else:
second_get_metric = False
if second_get_metric:
metric_key = metric_info.get(yaml_metric_info_name).get('metric_key')
if shell_res != 0:
metric_value = 'x'
else:
metric_value = get_info_from_log_file('metric', log_path,
metric_key, logger)
metric_list.append({yaml_metric_info_name: metric_value})
if test_pass:
test_pass = compare_metric(metric_value, yaml_metric_info_name,
pytorch_metric, metric_info)
return fps, metric_list, test_pass
test_pass = False
return fps, output_result, test_pass
def get_backend_fps_metric(deploy_cfg_path: str, model_cfg_path: Path,
convert_checkpoint_path: str, device_type: str,
eval_name: str, logger: logging.Logger,
metrics_eval_list: dict, pytorch_metric: dict,
logger: logging.Logger, pytorch_metric: dict,
metric_info: dict, backend_name: str,
precision_type: str, metric_useless: set,
convert_result: bool, report_dict: dict,
infer_type: str, log_path: Path, dataset_type: str,
report_txt_path: Path, model_name: str):
precision_type: str, convert_result: bool,
report_dict: dict, infer_type: str, log_path: Path,
dataset_info: dict, report_txt_path: Path,
model_name: str):
"""Get backend fps and metric.
Args:
@@ -640,66 +457,44 @@ def get_backend_fps_metric(deploy_cfg_path: str, model_cfg_path: Path,
model_cfg_path (Path): Model config path.
convert_checkpoint_path (str): Converted checkpoint path.
device_type (str): Device for converting.
eval_name (str): Name of evaluation.
logger (logging.Logger): Logger handler.
metrics_eval_list (dict): Evaluation metric info dict.
pytorch_metric (dict): Pytorch metric info dict get from metafile.
metric_info (dict): Metric info from test yaml.
backend_name (str): Backend name.
precision_type (str): Precision type for evaluation.
metric_useless (set): Useless metric for specific the model.
convert_result (bool): Backend convert result.
report_dict (dict): Backend convert result.
infer_type (str): Infer type.
log_path (Path): Logger save path.
dataset_type (str): Dataset type.
dataset_info (dict): Dataset info.
report_txt_path (Path): report txt save path.
model_name (str): Name of model in test yaml.
"""
cmd_str = 'python3 tools/test.py ' \
f'{deploy_cfg_path} ' \
f'{str(model_cfg_path.absolute())} ' \
f'--model {convert_checkpoint_path} ' \
f'--log2file "{log_path}" ' \
f'--speed-test ' \
f'--device {device_type} '
work_dir = log_path.parent.joinpath('test_logs')
if not work_dir.exists():
work_dir.mkdir(parents=True, exist_ok=True)
cmd_lines = [
'python3 tools/test.py', f'{deploy_cfg_path}',
f'{str(model_cfg_path.absolute())}',
f'--model {convert_checkpoint_path}', f'--work-dir "{work_dir}"',
'--speed-test', f'--device {device_type}'
]
codebase_name = get_codebase(str(deploy_cfg_path)).value
logger.info(f'Process cmd = {cmd_str}')
# Test backend
shell_res = subprocess.run(
cmd_str, cwd=str(Path(__file__).absolute().parent.parent),
shell=True).returncode
logger.info(f'Got shell_res = {shell_res}')
metric_key = ''
metric_name = ''
task_name = ''
for key, value in metric_info.items():
if value.get('eval_name', '') == eval_name:
metric_name = key
metric_key = value.get('metric_key', '')
task_name = value.get('task_name', '')
break
logger.info(f'Got metric_name = {metric_name}')
logger.info(f'Got metric_key = {metric_key}')
fps, metric_list, test_pass = \
get_fps_metric(shell_res, pytorch_metric, metric_key, metric_name,
log_path, metrics_eval_list, metric_info, logger)
# update useless metric
for metric in metric_useless:
metric_list.append({metric: '-'})
if len(metrics_eval_list) > 1 and codebase_name == 'mmdet':
# one model has more than one task, like Mask R-CNN
for name in pytorch_metric:
if name in metric_useless or name == metric_name:
continue
metric_list.append({name: '-'})
return_code = run_cmd(cmd_lines, log_path)
fps, backend_metric, test_pass = get_fps_metric(return_code,
pytorch_metric,
metric_info, work_dir)
logger.info(f'test_pass={test_pass}, results={backend_metric}')
metric_list = []
for metric in metric_info:
value = '-'
if metric in backend_metric:
value = backend_metric[metric]
metric_list.append({metric: value})
dataset_type = dataset_info['dataset']
task_name = dataset_info['task']
# update the report
update_report(
report_dict=report_dict,
@@ -770,15 +565,61 @@ def replace_top_in_pipeline_json(backend_output_path: Path,
def gen_log_path(backend_output_path: Path, log_name: str):
if not backend_output_path.exists():
backend_output_path.mkdir(parents=True, exist_ok=True)
log_path = backend_output_path.joinpath(log_name).absolute().resolve()
if log_path.exists():
# clear the log file
with open(log_path, 'w') as f_log:
f_log.write('')
os.remove(str(log_path))
return log_path
def run_cmd(cmd_lines: List[str], log_path: Path):
"""
Args:
cmd_lines: (list[str]): A command in multiple line style.
log_path (Path): Path to log file.
Returns:
int: error code.
"""
import platform
system = platform.system().lower()
if system == 'windows':
sep = r'`'
else: # 'Linux', 'Darwin'
sep = '\\'
cmd_for_run = ' '.join(cmd_lines)
cmd_for_log = f' {sep}\n'.join(cmd_lines) + '\n'
parent_path = log_path.parent
if not parent_path.exists():
parent_path.mkdir(parents=True, exist_ok=True)
logger = get_root_logger()
logger.info(100 * '-')
logger.info(f'Start running cmd\n{cmd_for_log}')
logger.info(f'Logging log to \n{log_path}')
with open(log_path, 'w', encoding='utf-8') as file_handler:
# write cmd
file_handler.write(f'Command:\n{cmd_for_log}\n')
file_handler.flush()
process_res = subprocess.Popen(
cmd_for_run,
cwd=str(Path(__file__).absolute().parent.parent),
shell=True,
stdout=file_handler,
stderr=file_handler)
process_res.wait()
return_code = process_res.returncode
if return_code != 0:
logger.error(f'Got shell return code={return_code}')
with open(log_path, 'r') as f:
content = f.read()
logger.error(f'Log error message\n{content}')
return return_code
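
For reference, both the conversion step and the backend test step now funnel through this helper. A hedged usage sketch follows; the config and checkpoint paths are placeholders, not files from this repo, and it assumes `run_cmd` from this module is in scope (the real call sites assemble `cmd_lines` from the deploy and model configs, as shown elsewhere in this diff).

```python
from pathlib import Path

# Placeholder command, mirroring how the patch builds cmd_lines for deploy.py.
cmd_lines = [
    'python3 ./tools/deploy.py',
    'configs/example_deploy_cfg.py',
    'configs/example_model_cfg.py',
    '"checkpoints/example.pth"',
    '"demo/demo.jpg"',
    '--work-dir "work_dirs/example"',
    '--device cpu',
    '--log-level INFO',
]
return_code = run_cmd(cmd_lines, Path('work_dirs/example/convert_log.log'))
# run_cmd writes the joined command plus stdout/stderr to the log file and
# returns the shell return code; 0 means the conversion succeeded.
print('convert ok' if return_code == 0 else 'convert failed, check the log')
```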
def get_backend_result(pipeline_info: dict, model_cfg_path: Path,
checkpoint_path: Path, work_dir: Path, device_type: str,
pytorch_metric: dict, metric_info: dict,
@@ -837,11 +678,6 @@ def get_backend_result(pipeline_info: dict, model_cfg_path: Path,
if backend_test is False and sdk_config is None:
test_type = 'convert'
metric_name_list = [str(metric) for metric in pytorch_metric]
assert len(metric_name_list) > 0
metric_all_list = [str(metric) for metric in metric_info]
metric_useless = set(metric_all_list) - set(metric_name_list)
deploy_cfg_path = Path(pipeline_info.get('deploy_config'))
backend_name = str(get_backend(str(deploy_cfg_path)).name).lower()
@@ -870,55 +706,30 @@ def get_backend_result(pipeline_info: dict, model_cfg_path: Path,
Path(checkpoint_path).stem)
backend_output_path.mkdir(parents=True, exist_ok=True)
# convert cmd string
cmd_str = 'python3 ./tools/deploy.py ' \
f'{str(deploy_cfg_path.absolute().resolve())} ' \
f'{str(model_cfg_path.absolute().resolve())} ' \
f'"{str(checkpoint_path.absolute().resolve())}" ' \
f'"{input_img_path}" ' \
f'--work-dir "{backend_output_path}" ' \
f'--device {device_type} ' \
'--log-level INFO'
# convert cmd lines
cmd_lines = [
'python3 ./tools/deploy.py',
f'{str(deploy_cfg_path.absolute().resolve())}',
f'{str(model_cfg_path.absolute().resolve())}',
f'"{str(checkpoint_path.absolute().resolve())}"',
f'"{input_img_path}" ', f'--work-dir "{backend_output_path}" ',
f'--device {device_type} ', '--log-level INFO'
]
if sdk_config is not None and test_type == 'precision':
cmd_str += ' --dump-info'
cmd_lines += ['--dump-info']
if test_img_path is not None:
cmd_str += f' --test-img {test_img_path}'
cmd_lines += [f'--test-img {test_img_path}']
if precision_type == 'int8':
calib_dataset_cfg = pipeline_info.get('calib_dataset_cfg', None)
if calib_dataset_cfg is not None:
cmd_str += f' --calib-dataset-cfg {calib_dataset_cfg}'
cmd_lines += [f'--calib-dataset-cfg {calib_dataset_cfg}']
logger.info(f'Process cmd = {cmd_str}')
convert_result = False
convert_log_path = backend_output_path.joinpath('convert_log.log')
logger.info(f'Logging conversion log to {convert_log_path} ...')
file_handler = open(convert_log_path, 'w', encoding='utf-8')
try:
# Convert the model to specific backend
process_res = subprocess.Popen(
cmd_str,
cwd=str(Path(__file__).absolute().parent.parent),
shell=True,
stdout=file_handler,
stderr=file_handler)
process_res.wait()
logger.info(f'Got shell_res = {process_res.returncode}')
# check if converted successes or not.
if process_res.returncode == 0:
convert_result = True
else:
convert_result = False
except Exception as e:
print(f'process convert error: {e}')
finally:
file_handler.close()
return_code = run_cmd(cmd_lines, convert_log_path)
convert_result = return_code == 0
logger.info(f'Got convert_result = {convert_result}')
if isinstance(backend_file_name, list):
@@ -930,113 +741,70 @@ def get_backend_result(pipeline_info: dict, model_cfg_path: Path,
else:
convert_checkpoint_path = \
str(backend_output_path.joinpath(backend_file_name))
# load deploy_cfg
deploy_cfg, model_cfg = \
load_config(str(deploy_cfg_path),
str(model_cfg_path.absolute()))
# get dataset type
if 'dataset_type' in model_cfg:
dataset_type = \
str(model_cfg.dataset_type).upper().replace('DATASET', '')
else:
dataset_type = metafile_dataset
# Test the model
if convert_result and test_type == 'precision':
# Get evaluation metric from model config
if codebase_name == 'mmseg':
metrics_eval_list = model_cfg.val_evaluator.iou_metrics
else:
metrics_eval_list = model_cfg.test_evaluator.get('metric', [])
if isinstance(metrics_eval_list, str):
# some config is using str only
metrics_eval_list = [metrics_eval_list]
# assert len(metrics_eval_list) > 0
logger.info(f'Got metrics_eval_list = {metrics_eval_list}')
if len(metrics_eval_list) == 0 and codebase_name == 'mmedit':
metrics_eval_list = ['PSNR']
# test the model metric
for metric_name in metrics_eval_list:
if backend_test:
log_path = \
gen_log_path(backend_output_path, 'backend_test.log')
get_backend_fps_metric(
deploy_cfg_path=str(deploy_cfg_path),
model_cfg_path=model_cfg_path,
convert_checkpoint_path=convert_checkpoint_path,
device_type=device_type,
eval_name=metric_name,
logger=logger,
metrics_eval_list=metrics_eval_list,
pytorch_metric=pytorch_metric,
metric_info=metric_info,
backend_name=backend_name,
precision_type=precision_type,
metric_useless=metric_useless,
convert_result=convert_result,
report_dict=report_dict,
infer_type=infer_type,
log_path=log_path,
dataset_type=dataset_type,
report_txt_path=report_txt_path,
model_name=model_name)
if backend_test:
log_path = \
gen_log_path(backend_output_path.joinpath('backend'),
'test.log')
get_backend_fps_metric(
deploy_cfg_path=str(deploy_cfg_path),
model_cfg_path=model_cfg_path,
convert_checkpoint_path=convert_checkpoint_path,
device_type=device_type,
logger=logger,
pytorch_metric=pytorch_metric,
metric_info=metric_info,
backend_name=backend_name,
precision_type=precision_type,
convert_result=convert_result,
report_dict=report_dict,
infer_type=infer_type,
log_path=log_path,
dataset_info=metafile_dataset,
report_txt_path=report_txt_path,
model_name=model_name)
if sdk_config is not None:
if sdk_config is not None:
if codebase_name == 'mmcls':
replace_top_in_pipeline_json(backend_output_path, logger)
if codebase_name == 'mmcls':
replace_top_in_pipeline_json(backend_output_path, logger)
log_path = gen_log_path(backend_output_path, 'sdk_test.log')
if backend_name == 'onnxruntime':
# sdk only support onnxruntime of cpu
device_type = 'cpu'
# sdk test
get_backend_fps_metric(
deploy_cfg_path=str(sdk_config),
model_cfg_path=model_cfg_path,
convert_checkpoint_path=str(backend_output_path),
device_type=device_type,
eval_name=metric_name,
logger=logger,
metrics_eval_list=metrics_eval_list,
pytorch_metric=pytorch_metric,
metric_info=metric_info,
backend_name=f'SDK-{backend_name}',
precision_type=precision_type,
metric_useless=metric_useless,
convert_result=convert_result,
report_dict=report_dict,
infer_type=infer_type,
log_path=log_path,
dataset_type=dataset_type,
report_txt_path=report_txt_path,
model_name=model_name)
log_path = gen_log_path(
backend_output_path.joinpath('sdk'), 'test.log')
if backend_name == 'onnxruntime':
# sdk only supports onnxruntime on cpu
device_type = 'cpu'
# sdk test
get_backend_fps_metric(
deploy_cfg_path=str(sdk_config),
model_cfg_path=model_cfg_path,
convert_checkpoint_path=str(backend_output_path),
device_type=device_type,
logger=logger,
pytorch_metric=pytorch_metric,
metric_info=metric_info,
backend_name=f'SDK-{backend_name}',
precision_type=precision_type,
convert_result=convert_result,
report_dict=report_dict,
infer_type=infer_type,
log_path=log_path,
dataset_info=metafile_dataset,
report_txt_path=report_txt_path,
model_name=model_name)
else:
logger.info('Only test convert, saving to report...')
metric_list = []
metric_list = [{metric: '-'} for metric in metric_info]
fps = '-'
task_name = ''
for metric in metric_name_list:
metric_list.append({metric: '-'})
metric_task_name = metric_info.get(metric, {}).get('task_name', '')
if metric_task_name in task_name:
logger.debug('metric_task_name exist, skip for adding it...')
continue
task_name += f'{metric_task_name} | '
if ' | ' == task_name[-3:]:
task_name = task_name[:-3]
test_pass = True if convert_result else False
# update useless metric
for metric in metric_useless:
metric_list.append({metric: '-'})
test_pass = convert_result
if convert_result:
report_checkpoint = convert_checkpoint_path
else:
report_checkpoint = str(checkpoint_path)
report_checkpoint = 'x'
dataset_type = metafile_dataset['dataset']
task_name = metafile_dataset['task']
# update the report
update_report(
report_dict=report_dict,
@@ -1191,9 +959,12 @@ def main():
logger.info(
f'Test specific model mode, skip {model_name_origin}...')
continue
model_metafile_info, checkpoint_save_dir, codebase_dir = \
get_model_metafile_info(global_info, models, logger)
try:
model_metafile_info, checkpoint_save_dir, codebase_dir = \
get_model_metafile_info(global_info, models, logger)
except Exception as e:
logger.error(f'Failed to get meta info {e}')
continue
for model_config in model_metafile_info:
logger.info(f'Processing test for {model_config}...')